From f2e6f1592ca683f6aae7bb30b03478fba18088a1 Mon Sep 17 00:00:00 2001 From: Max Ramsay King <maxramsayking@gmail.com> Date: Wed, 27 Apr 2022 11:47:19 +0100 Subject: [PATCH] added the policy-accuracy record --- .../autoaugment_learners/aa_learner.py | 20 +++++++++++++++++++ .../autoaugment_learners/evo_learner.py | 11 +--------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 48c05b95..cc452550 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -99,6 +99,8 @@ class aa_learner: self.fun_num = len(self.augmentation_space) self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2 + self.num_pols_tested = 0 + self.policy_record = {} def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False): @@ -329,6 +331,8 @@ class aa_learner: accuracy (float): best accuracy reached in any """ + + # we create an instance of the child network that we're going # to train. The method of creation depends on the type of # input we got for child_network_architecture @@ -378,8 +382,24 @@ class aa_learner: early_stop_num = self.early_stop_num, logging = logging, print_every_epoch=print_every_epoch) + + curr_pol = f'pol{self.num_pols_tested}' + pol_dict = {} + for subpol in policy: + subpol = subpol[0] + first_trans, first_prob, first_mag = subpol[0] + second_trans, second_prob, second_mag = subpol[1] + components = (first_prob, first_mag, second_prob, second_mag) + if second_trans in pol_dict[first_trans]: + pol_dict[first_trans][second_trans].append(components) + else: + pol_dict[first_trans]= {second_trans: [components]} + self.policy_record[curr_pol] = (pol_dict, accuracy) + + self.num_pols_tested += 1 # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log) + return accuracy diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 6bf682c1..c3dd315d 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -244,18 +244,9 @@ class evo_learner(aa_learner): if new_set == test_pol: return True self.policy_dict[trans1][trans2].append(new_set) - return False else: self.policy_dict[trans1] = {trans2: [new_set]} - if trans2 in self.policy_dict: - if trans1 in self.policy_dict[trans2]: - for test_pol in self.policy_dict[trans2][trans1]: - if new_set == test_pol: - return True - self.policy_dict[trans2][trans1].append(new_set) - return False - else: - self.policy_dict[trans2] = {trans1: [new_set]} + return False def set_up_instance(self, train_dataset, test_dataset, child_network_architecture): -- GitLab