From 75083fc3c3053db7e06a73d0c5963a974cbaef37 Mon Sep 17 00:00:00 2001 From: Max Ramsay King <maxramsayking@gmail.com> Date: Wed, 27 Apr 2022 15:15:23 +0100 Subject: [PATCH] pol_dict bug --- MetaAugment/autoaugment_learners/aa_learner.py | 9 +++++---- MetaAugment/autoaugment_learners/rand_augment_learner.py | 8 ++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 MetaAugment/autoaugment_learners/rand_augment_learner.py diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index a562a3dc..00d38aab 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -394,15 +394,16 @@ class aa_learner: first_trans, first_prob, first_mag = subpol[0] second_trans, second_prob, second_mag = subpol[1] components = (first_prob, first_mag, second_prob, second_mag) - if second_trans in pol_dict[first_trans]: - pol_dict[first_trans][second_trans].append(components) + if first_trans in pol_dict: + if second_trans in pol_dict[first_trans]: + pol_dict[first_trans][second_trans].append(components) + else: + pol_dict[first_trans]= {second_trans: [components]} else: pol_dict[first_trans]= {second_trans: [components]} self.policy_record[curr_pol] = (pol_dict, accuracy) self.num_pols_tested += 1 - - return accuracy diff --git a/MetaAugment/autoaugment_learners/rand_augment_learner.py b/MetaAugment/autoaugment_learners/rand_augment_learner.py new file mode 100644 index 00000000..b6974bef --- /dev/null +++ b/MetaAugment/autoaugment_learners/rand_augment_learner.py @@ -0,0 +1,8 @@ +import torch +import numpy as np +from MetaAugment.autoaugment_learners.randomsearch_learner import randomsearch_learner + +class rand_augment_learner(randomsearch_learner): + + def __init__(self): + pass \ No newline at end of file -- GitLab