diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 950b7e11b8223b22208414594176b7a2911501dd..f67e99c78c6e9c963aadae9ba4b5e4ccf8961ea6 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -406,7 +406,7 @@ class aa_learner: return accuracy - def get_mega_policy(self, number_policies): + def get_mega_policy(self, number_policies=5): """ Produces a mega policy, based on the n best subpolicies (evo learner)/policies (other learners) @@ -419,6 +419,9 @@ class aa_learner: Returns: megapolicy -> [subpolicy, subpolicy, ...] """ + + number_policies = min(number_policies, len(self.history)) + inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies] megapol = [] diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py index 053492b87626b13efe723e17b7cf63e3655c8967..5d9a32e7182851de1b03b670569894c722f4178a 100644 --- a/MetaAugment/autoaugment_learners/ucb_learner.py +++ b/MetaAugment/autoaugment_learners/ucb_learner.py @@ -147,7 +147,7 @@ class ucb_learner(randomsearch_learner): print(self.cnts) - def get_mega_policy(self, number_policies): + def get_mega_policy(self, number_policies=5): """ Produces a mega policy, based on the n best subpolicies (evo learner)/policies (other learners) diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py index 96322b6d4ddf805a3182d8b3d9eedcc349245e95..cbb8d952b4f68899bf58557731f7bb7f023bfdd1 100644 --- a/test/MetaAugment/test_aa_learner.py +++ b/test/MetaAugment/test_aa_learner.py @@ -82,7 +82,7 @@ def test__test_autoaugment_policy(): p_bins=11, m_bins=10, discrete_p_m=True, - toy_size=0.004, + toy_size=0.002, max_epochs=20, early_stop_num=10 ) @@ -157,7 +157,15 @@ def test_exclude_method(): def test_get_mega_policy(): - agent = aal.randomsearch_learner() + agent = aal.randomsearch_learner( + sp_num=5, + 
p_bins=11, + m_bins=10, + discrete_p_m=True, + toy_size=0.002, + max_epochs=20, + early_stop_num=10 + ) child_network_architecture = cn.SimpleNet train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', @@ -167,7 +175,9 @@ def test_get_mega_policy(): transform=torchvision.transforms.ToTensor()) agent.learn(train_dataset, test_dataset, child_network_architecture, 10) - mega_pol = agent.get_mega_policy() + mega_pol = agent.get_mega_policy(number_policies=30) + mega_pol = agent.get_mega_policy(number_policies=3) + mega_pol = agent.get_mega_policy(number_policies=1) print("megapol: ", mega_pol)