Skip to content
Snippets Groups Projects
Commit 53d0d751 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents 9605f53b 1de0f23d
No related branches found
No related tags found
No related merge requests found
Pipeline #272322 passed
...@@ -403,4 +403,26 @@ class aa_learner: ...@@ -403,4 +403,26 @@ class aa_learner:
self.num_pols_tested += 1 self.num_pols_tested += 1
self.history.append((policy,accuracy)) self.history.append((policy,accuracy))
return accuracy return accuracy
\ No newline at end of file
def get_mega_policy(self, number_policies):
    """
    Produces a mega policy, based on the n best subpolicies (evo learner)/policies
    (other learners)

    Entries of self.history are (policy, accuracy) tuples. They are ranked by
    accuracy, highest first; entries with equal accuracy keep the order in
    which they were recorded. The policy part of each of the top entries is
    then concatenated, element by element, into a single flat list.

    Args:
        number_policies -> int: Number of (sub)policies to be included in the mega
                                policy. Values larger than the history length
                                include every entry; zero or negative values
                                give an empty mega policy.
    Returns:
        megapolicy -> [subpolicy, subpolicy, ...]
    """
    # A negative count would otherwise make the slice below mean
    # "everything except the last k entries", pulling in low-accuracy
    # policies instead of returning nothing.
    if number_policies is not None and number_policies < 0:
        number_policies = 0

    ranked = sorted(self.history, key=lambda entry: entry[1], reverse=True)

    megapol = []
    for entry in ranked[:number_policies]:
        # entry[0] is itself a sequence; extend() flattens it into the result.
        megapol.extend(entry[0])
    return megapol
...@@ -221,21 +221,21 @@ class evo_learner(aa_learner): ...@@ -221,21 +221,21 @@ class evo_learner(aa_learner):
self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size) self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
for idx, (test_x, label_x) in enumerate(self.train_loader): for idx, (test_x, label_x) in enumerate(self.train_loader):
full_policy = self._get_single_policy_cov(test_x) sub_pol = self._get_single_policy_cov(test_x)
while self._in_pol_dict(full_policy): while self._in_pol_dict(sub_pol):
full_policy = self._get_single_policy_cov(test_x)[0] sub_pol = self._get_single_policy_cov(test_x)[0]
fit_val = self._test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) fit_val = self._test_autoaugment_policy(sub_pol,child_network_architecture,train_dataset,test_dataset)
self.running_policy.append((full_policy, fit_val))
self.running_policy.append((sub_pol, fit_val))
if len(self.running_policy) > self.sp_num: if len(self.running_policy) > self.sp_num:
self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True) self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
self.running_policy = self.running_policy[:self.sp_num] self.running_policy = self.running_policy[:self.sp_num]
print("appended policy: ", self.running_policy)
if len(self.history_best) == 0: if len(self.history_best) == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment