diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 8e1d5bc198548c3e24bb3d2bd5ac2d1f39650923..e9a65865c46b786005c01b4ff9d19d418baaa988 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -6,7 +6,6 @@ import pygad import pygad.torchga as torchga import copy import torch -from MetaAugment.controller_networks.evo_controller import Evo_learner from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space import MetaAugment.child_networks as cn diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index c06edec316eed6982272abc685d6e02735e92adf..db8205d5f335f056f82b0e40557a73031ad72b1a 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -47,7 +47,6 @@ class gru_learner(aa_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False, @@ -78,10 +77,10 @@ class gru_learner(aa_learner): print('Warning: Incompatible discrete_p_m=True input into gru_learner. \ discrete_p_m=False will be used') - super().__init__(sp_num, - fun_num, - p_bins, - m_bins, + super().__init__( + sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, discrete_p_m=True, batch_size=batch_size, toy_flag=toy_flag, diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index 6541cd3f54980254d0001c969bf2eb90d57b0ad2..09f6626f8a42a35e5006c79188fef3d2947c6418 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -38,7 +38,6 @@ class randomsearch_learner(aa_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, @@ -51,10 +50,9 @@ class randomsearch_learner(aa_learner): early_stop_num=30, ): - super().__init__(sp_num, - fun_num, - p_bins, - m_bins, + super().__init__(sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, discrete_p_m=discrete_p_m, batch_size=batch_size, toy_flag=toy_flag, diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py index e22f32ff11cce285d3baf4a680ad044afd43045d..dc82c2ee75d22dd503f46212dd7251c79bb271db 100644 --- a/MetaAugment/autoaugment_learners/ucb_learner.py +++ b/MetaAugment/autoaugment_learners/ucb_learner.py @@ -20,7 +20,6 @@ class ucb_learner(randomsearch_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, @@ -36,7 +35,6 @@ class ucb_learner(randomsearch_learner): ): super().__init__(sp_num=sp_num, - fun_num=14, p_bins=p_bins, m_bins=m_bins, discrete_p_m=discrete_p_m, diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py index 3e2808702a04746e625acd5b463cfe01f56687bd..29af4f6da149a9619bafe30ba03cabe6b77064a7 100644 --- a/test/MetaAugment/test_aa_learner.py +++ b/test/MetaAugment/test_aa_learner.py @@ -25,13 +25,12 @@ def test_translate_operation_tensor(): softmax = torch.nn.Softmax(dim=0) - fun_num = random.randint(1, 14) + fun_num=14 p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) - + agent = aal.aa_learner( sp_num=5, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=True @@ -54,13 +53,12 @@ def test_translate_operation_tensor(): for i in range(2000): - fun_num = random.randint(1, 14) + fun_num = 14 p_bins = random.randint(1, 15) m_bins = random.randint(1, 15) agent = aal.aa_learner( sp_num=5, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=False @@ -81,7 +79,6 @@ def test_translate_operation_tensor(): def test_test_autoaugment_policy(): agent = aal.aa_learner( sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py index 6ad8204f9b8473482f00d5c5d6a9d1e391cf9e0b..b5c695cfdf2d988408d70d1379af4fbf7738ae15 100644 --- a/test/MetaAugment/test_gru_learner.py +++ b/test/MetaAugment/test_gru_learner.py @@ -14,13 +14,11 @@ def test_generate_new_policy(): """ for _ in range(40): sp_num = random.randint(1,20) - fun_num = random.randint(1, 14) p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) agent = aal.gru_learner( sp_num=sp_num, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, cont_mb_size=2 diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py index 5b67d98e1f2e40d56b3aac2445f041f1372bbe9f..29cd812b1d428441d405556b53db3e65e2ab7bc6 100644 --- a/test/MetaAugment/test_randomsearch_learner.py +++ b/test/MetaAugment/test_randomsearch_learner.py @@ -16,13 +16,12 @@ def test_generate_new_policy(): def my_test(discrete_p_m): for _ in range(40): sp_num = random.randint(1,20) - fun_num = random.randint(1, 14) + p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) agent = aal.randomsearch_learner( sp_num=sp_num, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=discrete_p_m