From c6136f11b5871ebaf274795bceeaa85ba4649166 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Mon, 25 Apr 2022 17:23:37 +0100 Subject: [PATCH] finish removing fun_num from aa_learners also update tests accordingly --- MetaAugment/autoaugment_learners/evo_learner.py | 1 - MetaAugment/autoaugment_learners/gru_learner.py | 9 ++++----- MetaAugment/autoaugment_learners/randomsearch_learner.py | 8 +++----- MetaAugment/autoaugment_learners/ucb_learner.py | 2 -- test/MetaAugment/test_aa_learner.py | 9 +++------ test/MetaAugment/test_gru_learner.py | 2 -- test/MetaAugment/test_randomsearch_learner.py | 3 +-- 7 files changed, 11 insertions(+), 23 deletions(-) diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 8e1d5bc1..e9a65865 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -6,7 +6,6 @@ import pygad import pygad.torchga as torchga import copy import torch -from MetaAugment.controller_networks.evo_controller import Evo_learner from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space import MetaAugment.child_networks as cn diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index c06edec3..db8205d5 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -47,7 +47,6 @@ class gru_learner(aa_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False, @@ -78,10 +77,10 @@ class gru_learner(aa_learner): print('Warning: Incompatible discrete_p_m=True input into gru_learner. \ discrete_p_m=False will be used') - super().__init__(sp_num, - fun_num, - p_bins, - m_bins, + super().__init__( + sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, discrete_p_m=True, batch_size=batch_size, toy_flag=toy_flag, diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index 6541cd3f..09f6626f 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -38,7 +38,6 @@ class randomsearch_learner(aa_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, @@ -51,10 +50,9 @@ class randomsearch_learner(aa_learner): early_stop_num=30, ): - super().__init__(sp_num, - fun_num, - p_bins, - m_bins, + super().__init__(sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, discrete_p_m=discrete_p_m, batch_size=batch_size, toy_flag=toy_flag, diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py index e22f32ff..dc82c2ee 100644 --- a/MetaAugment/autoaugment_learners/ucb_learner.py +++ b/MetaAugment/autoaugment_learners/ucb_learner.py @@ -20,7 +20,6 @@ class ucb_learner(randomsearch_learner): def __init__(self, # parameters that define the search space sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, @@ -36,7 +35,6 @@ class ucb_learner(randomsearch_learner): ): super().__init__(sp_num=sp_num, - fun_num=14, p_bins=p_bins, m_bins=m_bins, discrete_p_m=discrete_p_m, diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py index 3e280870..29af4f6d 100644 --- a/test/MetaAugment/test_aa_learner.py +++ b/test/MetaAugment/test_aa_learner.py @@ -25,13 +25,12 @@ def test_translate_operation_tensor(): softmax = torch.nn.Softmax(dim=0) - fun_num = random.randint(1, 14) + fun_num=14 p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) - + agent = aal.aa_learner( sp_num=5, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=True @@ -54,13 +53,12 @@ def test_translate_operation_tensor(): for i in range(2000): - fun_num = random.randint(1, 14) + fun_num = 14 p_bins = random.randint(1, 15) m_bins = random.randint(1, 15) agent = aal.aa_learner( sp_num=5, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=False @@ -81,7 +79,6 @@ def test_translate_operation_tensor(): def test_test_autoaugment_policy(): agent = aal.aa_learner( sp_num=5, - fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py index 6ad8204f..b5c695cf 100644 --- a/test/MetaAugment/test_gru_learner.py +++ b/test/MetaAugment/test_gru_learner.py @@ -14,13 +14,11 @@ def test_generate_new_policy(): """ for _ in range(40): sp_num = random.randint(1,20) - fun_num = random.randint(1, 14) p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) agent = aal.gru_learner( sp_num=sp_num, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, cont_mb_size=2 diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py index 5b67d98e..29cd812b 100644 --- a/test/MetaAugment/test_randomsearch_learner.py +++ b/test/MetaAugment/test_randomsearch_learner.py @@ -16,13 +16,12 @@ def test_generate_new_policy(): def my_test(discrete_p_m): for _ in range(40): sp_num = random.randint(1,20) - fun_num = random.randint(1, 14) + p_bins = random.randint(2, 15) m_bins = random.randint(2, 15) agent = aal.randomsearch_learner( sp_num=sp_num, - fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=discrete_p_m -- GitLab