From 75083fc3c3053db7e06a73d0c5963a974cbaef37 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Wed, 27 Apr 2022 15:15:23 +0100
Subject: [PATCH] pol_dict bug

---
 MetaAugment/autoaugment_learners/aa_learner.py           | 9 +++++----
 MetaAugment/autoaugment_learners/rand_augment_learner.py | 8 ++++++++
 2 files changed, 13 insertions(+), 4 deletions(-)
 create mode 100644 MetaAugment/autoaugment_learners/rand_augment_learner.py

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index a562a3dc..00d38aab 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -394,15 +394,16 @@ class aa_learner:
             first_trans, first_prob, first_mag = subpol[0]
             second_trans, second_prob, second_mag = subpol[1]
             components = (first_prob, first_mag, second_prob, second_mag)
-            if second_trans in pol_dict[first_trans]:
-                pol_dict[first_trans][second_trans].append(components)
+            if first_trans in pol_dict:
+                if second_trans in pol_dict[first_trans]:
+                    pol_dict[first_trans][second_trans].append(components)
+                else:
+                    pol_dict[first_trans]= {second_trans: [components]}
             else:
                 pol_dict[first_trans]= {second_trans: [components]}
         self.policy_record[curr_pol] = (pol_dict, accuracy)
 
         self.num_pols_tested += 1
-        
-
         return accuracy
     
 
diff --git a/MetaAugment/autoaugment_learners/rand_augment_learner.py b/MetaAugment/autoaugment_learners/rand_augment_learner.py
new file mode 100644
index 00000000..b6974bef
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/rand_augment_learner.py
@@ -0,0 +1,8 @@
+import torch 
+import numpy as np
+from MetaAugment.autoaugment_learners.randomsearch_learner import randomsearch_learner
+
+class rand_augment_learner(randomsearch_learner):
+
+    def __init__(self):
+        pass
\ No newline at end of file
-- 
GitLab