From 1de0f23de30ff045555f141b5221502286dddca4 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Wed, 27 Apr 2022 15:40:42 +0100
Subject: [PATCH] corrected megapolicy

---
 .../autoaugment_learners/aa_learner.py | 51 +++++++------------
 1 file changed, 18 insertions(+), 33 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 00d38aab..83782991 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -407,38 +407,23 @@ class aa_learner:
 
         return accuracy
 
-    # def demo_plot(self, train_dataset, test_dataset, child_network_architecture, n=5):
-    #     """
-    #     I made this to plot a couple of accuracy graphs to help manually tune my gradient
-    #     optimizer hyperparameters.
+    def get_mega_policy(self, number_policies):
+        """
+        Produces a mega policy, based on the n best subpolicies (evo learner)/policies
+        (other learners)
 
 
-    #     Saves a plot of `n` training accuracy graphs overlapped.
-    #     """
-    #     acc_lists = []
-
-    #     # This is dummy code
-    #     # test out `n` random policies
-    #     for _ in range(n):
-    #         policy = self._generate_new_policy()
-
-    #         pprint(policy)
-    #         reward, acc_list = self._test_autoaugment_policy(policy,
-    #                                                 child_network_architecture,
-    #                                                 train_dataset,
-    #                                                 test_dataset,
-    #                                                 logging=True)
-
-    #         self.history.append((policy, reward))
-    #         acc_lists.append(acc_list)
-
-    #     for acc_list in acc_lists:
-    #         plt.plot(acc_list)
-    #     plt.title('I ran 5 random policies to see if there is any sign of \
-    #                 catastrophic failure during training. If there are \
-    #                 any lines which reach significantly lower (>10%) \
-    #                 accuracies, you might want to tune the hyperparameters')
-    #     plt.xlabel('epoch')
-    #     plt.ylabel('accuracy')
-    #     plt.show()
-    #     plt.savefig('training_graphs_without_policies')
\ No newline at end of file
+        Args:
+            number_policies -> int: Number of (sub)policies to be included in the mega
+            policy
+
+        Returns:
+            megapolicy -> [subpolicy, subpolicy, ...]
+        """
+        inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies]
+
+        megapol = []
+        for pol in inter_pol:
+            megapol += pol[0]
+
+        return megapol
--
GitLab
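
For context, here is a minimal standalone sketch of what the added get_mega_policy does. It relies only on what the patch shows: self.history holds (policy, reward) tuples (the removed code appends exactly that), the best entries are selected by reward, and their policies are concatenated. The example history values and the (name, probability, magnitude) operation tuples are hypothetical, illustrative data, not taken from the repository.

# Illustrative data: each entry is (policy, reward), where a policy is a
# list of subpolicies; the operation tuples below are made up for the demo.
history = [
    ([(("Invert", 0.8, None), ("Contrast", 0.2, 6))], 0.61),
    ([(("Rotate", 0.7, 2), ("Invert", 0.8, None))], 0.74),
    ([(("ShearY", 0.5, 8), ("Posterize", 0.3, 7))], 0.58),
]

number_policies = 2

# Same logic as the patched method: sort by reward (descending),
# keep the top `number_policies` entries...
inter_pol = sorted(history, key=lambda x: x[1], reverse=True)[:number_policies]

# ...then concatenate their (sub)policies into one mega policy.
megapol = []
for pol in inter_pol:
    megapol += pol[0]

print(megapol)
# -> the subpolicies of the two highest-reward entries, merged into one list
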