diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 00d38aab9b292b1a3e88aae9d76310fbc3fdc370..8378299133a122f970939df031d430f97445f4db 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -407,38 +407,23 @@ class aa_learner: return accuracy - # def demo_plot(self, train_dataset, test_dataset, child_network_architecture, n=5): - # """ - # I made this to plot a couple of accuracy graphs to help manually tune my gradient - # optimizer hyperparameters. + def get_mega_policy(self, number_policies): + """ + Produces a mega policy, based on the n best subpolicies (evo learner)/policies + (other learners) - # Saves a plot of `n` training accuracy graphs overlapped. - # """ - # acc_lists = [] - - # # This is dummy code - # # test out `n` random policies - # for _ in range(n): - # policy = self._generate_new_policy() - - # pprint(policy) - # reward, acc_list = self._test_autoaugment_policy(policy, - # child_network_architecture, - # train_dataset, - # test_dataset, - # logging=True) - - # self.history.append((policy, reward)) - # acc_lists.append(acc_list) - - # for acc_list in acc_lists: - # plt.plot(acc_list) - # plt.title('I ran 5 random policies to see if there is any sign of \ - # catastrophic failure during training. If there are \ - # any lines which reach significantly lower (>10%) \ - # accuracies, you might want to tune the hyperparameters') - # plt.xlabel('epoch') - # plt.ylabel('accuracy') - # plt.show() - # plt.savefig('training_graphs_without_policies') \ No newline at end of file + Args: + number_policies -> int: Number of (sub)policies to be included in the mega + policy + + Returns: + megapolicy -> [subpolicy, subpolicy, ...] + """ + inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies] + + megapol = [] + for pol in inter_pol: + megapol += pol[0] + + return megapol