Skip to content
Snippets Groups Projects
Commit 9916ea79 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

fix mega_pol

parent 12e06a51
No related branches found
No related tags found
No related merge requests found
Pipeline #272330 canceled
...@@ -406,7 +406,7 @@ class aa_learner: ...@@ -406,7 +406,7 @@ class aa_learner:
return accuracy return accuracy
def get_mega_policy(self, number_policies): def get_mega_policy(self, number_policies=5):
""" """
Produces a mega policy, based on the n best subpolicies (evo learner)/policies Produces a mega policy, based on the n best subpolicies (evo learner)/policies
(other learners) (other learners)
...@@ -419,6 +419,9 @@ class aa_learner: ...@@ -419,6 +419,9 @@ class aa_learner:
Returns: Returns:
megapolicy -> [subpolicy, subpolicy, ...] megapolicy -> [subpolicy, subpolicy, ...]
""" """
number_policies = max(number_policies, len(self.history))
inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies] inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies]
megapol = [] megapol = []
......
...@@ -147,7 +147,7 @@ class ucb_learner(randomsearch_learner): ...@@ -147,7 +147,7 @@ class ucb_learner(randomsearch_learner):
print(self.cnts) print(self.cnts)
def get_mega_policy(self, number_policies): def get_mega_policy(self, number_policies=5):
""" """
Produces a mega policy, based on the n best subpolicies (evo learner)/policies Produces a mega policy, based on the n best subpolicies (evo learner)/policies
(other learners) (other learners)
......
...@@ -82,7 +82,7 @@ def test__test_autoaugment_policy(): ...@@ -82,7 +82,7 @@ def test__test_autoaugment_policy():
p_bins=11, p_bins=11,
m_bins=10, m_bins=10,
discrete_p_m=True, discrete_p_m=True,
toy_size=0.004, toy_size=0.002,
max_epochs=20, max_epochs=20,
early_stop_num=10 early_stop_num=10
) )
...@@ -157,7 +157,15 @@ def test_exclude_method(): ...@@ -157,7 +157,15 @@ def test_exclude_method():
def test_get_mega_policy(): def test_get_mega_policy():
agent = aal.randomsearch_learner() agent = aal.randomsearch_learner(
sp_num=5,
p_bins=11,
m_bins=10,
discrete_p_m=True,
toy_size=0.002,
max_epochs=20,
early_stop_num=10
)
child_network_architecture = cn.SimpleNet child_network_architecture = cn.SimpleNet
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
...@@ -167,7 +175,9 @@ def test_get_mega_policy(): ...@@ -167,7 +175,9 @@ def test_get_mega_policy():
transform=torchvision.transforms.ToTensor()) transform=torchvision.transforms.ToTensor())
agent.learn(train_dataset, test_dataset, child_network_architecture, 10) agent.learn(train_dataset, test_dataset, child_network_architecture, 10)
mega_pol = agent.get_mega_policy() mega_pol = agent.get_mega_policy(number_policies=30)
mega_pol = agent.get_mega_policy(number_policies=3)
mega_pol = agent.get_mega_policy(number_policies=1)
print("megapol: ", mega_pol) print("megapol: ", mega_pol)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment