Skip to content
Snippets Groups Projects
Commit 023eff7d authored by tbuckworth's avatar tbuckworth
Browse files

created solution_test.py

parent b0c27334
No related branches found
No related tags found
No related merge requests found
......@@ -10,7 +10,7 @@ from task import Task
class TaskTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.task_dict = load_task("data/training/d4f3cd78.json")
cls.task_dict = load_task()#"data/training/d4f3cd78.json")
cls.task = Task(cls.task_dict)
# Example of accessing input/output grids for the first example
cls.output_grid = cls.task.train_examples[0].output_grid.grid
......@@ -29,13 +29,13 @@ class TaskTester(unittest.TestCase):
self.assertTrue((out_grid == self.output_grid).all())
def test_task_class(self):
solution = "solution_2"
solution = "solution_1"
res = self.task.try_solution(solution)
self.assertTrue(res)
def test_task_empty(self):
res = self.task.try_solution("empty_solution")
self.assertFalse(res)
# def test_task_empty(self):
# res = self.task.try_solution("empty_solution")
# self.assertFalse(res)
if __name__ == '__main__':
unittest.main()
......@@ -8,9 +8,10 @@ import matplotlib.pyplot as plt
import subprocess
tasks = {
"lines": "0b148d64.json",
"grids": "90f3ed37.json",
# "lines": "0b148d64.json",
"lines": "0a938d79.json",
"pour": "d4f3cd78.json",
"grids": "90f3ed37.json",
"cross": "e21d9049.json",
"stripes": "f8c80d96.json"
}
......
all_rows(Rs):-
setof(R, C^Colour^input_colour(R,C,Colour), Rs).
all_cols(Cs):-
setof(C, R^Colour^input_colour(R,C,Colour), Cs).
row(R):-
all_rows(Rs),
member(R,Rs).
......
import copy
import unittest
import numpy as np
from main import load_task, FOL2grid, FOL2prolog, prolog2FOL_array, tasks
from task import Task
class TaskTester(unittest.TestCase):
# @classmethod
# def setUpClass(cls):
# cls.task_dict = load_task("data/training/d4f3cd78.json")
# cls.task = Task(cls.task_dict)
# # Example of accessing input/output grids for the first example
# cls.output_grid = cls.task.train_examples[0].output_grid.grid
# cls.out_preds = cls.task.train_examples[0].output_grid.preds
def try_solution(self, task_file, solution):
task_dict = load_task(f"data/training/{task_file}")
task = Task(task_dict)
res = task.try_solution(solution)
self.assertTrue(res)
def test_solution_1(self):
self.try_solution(tasks["lines"], 'solution_1')
def test_solution_2(self):
self.try_solution(tasks["pour"], 'solution_2')
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment