From 5488b9220bf5d95105937bbc3ee41c56f0b0e18a Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Sat, 19 Feb 2022 17:59:45 +0000
Subject: [PATCH] Change Demo policies in autoaugment.py. Now you can see
 policies2 does way better than policies 1!

---
 .../autoaugment_learners/autoaugment.py       | 22 ++++++++++++++-----
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 977946bc..f9ba4900 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -409,6 +409,12 @@ if __name__=='__main__':
     from MetaAugment.main import *
     import MetaAugment.child_networks as cn
     import torchvision.transforms as transforms
+
+    # If you get rid of this nextimport, the whole thing doesn't work... By the way this import also 
+    # exists on the top of this document.
+    # I think this is because "import torchvision.transforms as transforms" overrides the import at 
+    # the top of this file and does some funny stuff... Anyways we need to call this import again to get
+    # rid of the bug.
     from torchvision.transforms import functional as F, InterpolationMode
 
     batch_size = 32
@@ -419,16 +425,18 @@ if __name__=='__main__':
             (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
             (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
             (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
-            (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
+            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             (("AutoContrast", 0.5, None), ("Equalize", 0.9, None))
             ]
 
+    # The one that i hand crafted. You'll see that this one usually reaches a much
+    # higher poerformance
     policies2 = [
-            (("Color", 0.9, 9), ("Equalize", 0.6, None)),
-            (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
-            (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
-            (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
-            (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9))
+            (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
+            (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
+            (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
+            (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
+            (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
             ]
 
     def test_autoaugment_policy(policies):
@@ -453,6 +461,8 @@ if __name__=='__main__':
         sgd = optim.SGD(child_network.parameters(), lr=1e-1)
 
         best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
+
+        train_dataset
     
 
     test_autoaugment_policy(policies1)
-- 
GitLab