diff --git a/.gitignore b/.gitignore index 0ffad1fd82a5435de017aacbe88c1336683916f3..2c930906cc18ef7e51eceb7d1b0663058736ea02 100644 --- a/.gitignore +++ b/.gitignore @@ -202,8 +202,11 @@ celerybeat.pid # SageMath parsed files *.sage.py +# we don't want dataset directories **/test **/train +# but there is a unit test folder in main that we DO want to track +!test # user uplaod /react_backend/child_networks diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 561222a5fac35d5348f99da6a9fba31657afe133..756ae33e3f166017dcf7e14ea531555785c6b805 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -329,7 +329,9 @@ class aa_learner: accuracy (float): best accuracy reached in any """ - + # we create an instance of the child network that we're going + # to train. The method of creation depends on the type of + # input we got for child_network_architecture if isinstance(child_network_architecture, types.FunctionType): child_network = child_network_architecture() elif isinstance(child_network_architecture, type): diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py index b1524988939e9adac0b45285921b8d058f087887..8960746c654e6be3f14652d017babcd172bf819c 100644 --- a/test/MetaAugment/test_aa_learner.py +++ b/test/MetaAugment/test_aa_learner.py @@ -116,4 +116,41 @@ def test_test_autoaugment_policy(): ) assert isinstance(acc, float) + + +def test_exclude_method(): + """ + we want to see if the exclude_methods + parameter is working properly in aa_learners + """ + + exclude_method = [ + 'ShearX', + 'Color', + 'Brightness', + 'Contrast' + ] + agent = aal.gru_learner( + exclude_method=exclude_method + ) + for _ in range(200): + new_pol, _ = agent.generate_new_policy() + print(new_pol) + for (op1, op2) in new_pol: + image_function_1 = op1[0] + image_function_2 = op2[0] + assert image_function_1 not in exclude_method + assert image_function_2 not in exclude_method + + agent = aal.randomsearch_learner( + exclude_method=exclude_method + ) + for _ in range(200): + new_pol, _ = agent.generate_new_policy() + print(new_pol) + for (op1, op2) in new_pol: + image_function_1 = op1[0] + image_function_2 = op2[0] + assert image_function_1 not in exclude_method + assert image_function_2 not in exclude_method \ No newline at end of file diff --git a/test/MetaAugment/test_evo_learner.py b/test/MetaAugment/test_evo_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..b917fb3934555585c2b95a412580ab27f517c081 --- /dev/null +++ b/test/MetaAugment/test_evo_learner.py @@ -0,0 +1,44 @@ +import MetaAugment.autoaugment_learners as aal +import MetaAugment.child_networks as cn +import torchvision +import torchvision.datasets as datasets +from pprint import pprint + +def test_evo_learner(): + child_network_architecture = cn.SimpleNet + train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', + train=True, download=True, transform=None) + test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', + train=False, download=True, + transform=torchvision.transforms.ToTensor()) + + + learner = aal.evo_learner( + # parameters that define the search space + sp_num=5, + p_bins=11, + m_bins=10, + discrete_p_m=True, + exclude_method=['ShearX'], + # hyperparameters for when training the child_network + batch_size=8, + toy_size=0.0001, + learning_rate=1e-1, + max_epochs=float('inf'), + early_stop_num=30, + # evolutionary learner specific settings + num_solutions=3, + num_parents_mating=2, + ) + + # learn on the 3 policies we generated + learner.learn( + train_dataset=train_dataset, + test_dataset=test_dataset, + child_network_architecture=child_network_architecture, + iterations=2 + ) + + +if __name__=="__main__": + test_evo_learner()