diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index a562a3dcedab4dc7e6b86a921e87ef53a03f572b..00d38aab9b292b1a3e88aae9d76310fbc3fdc370 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -394,15 +394,16 @@ class aa_learner: first_trans, first_prob, first_mag = subpol[0] second_trans, second_prob, second_mag = subpol[1] components = (first_prob, first_mag, second_prob, second_mag) - if second_trans in pol_dict[first_trans]: - pol_dict[first_trans][second_trans].append(components) + if first_trans in pol_dict: + if second_trans in pol_dict[first_trans]: + pol_dict[first_trans][second_trans].append(components) + else: + pol_dict[first_trans]= {second_trans: [components]} else: pol_dict[first_trans]= {second_trans: [components]} self.policy_record[curr_pol] = (pol_dict, accuracy) self.num_pols_tested += 1 - - return accuracy diff --git a/MetaAugment/autoaugment_learners/rand_augment_learner.py b/MetaAugment/autoaugment_learners/rand_augment_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..b6974befe7438f91db36c37cc1a6b5c15801813f --- /dev/null +++ b/MetaAugment/autoaugment_learners/rand_augment_learner.py @@ -0,0 +1,8 @@ +import torch +import numpy as np +from MetaAugment.autoaugment_learners.randomsearch_learner import randomsearch_learner + +class rand_augment_learner(randomsearch_learner): + + def __init__(self): + pass \ No newline at end of file