From 5488b9220bf5d95105937bbc3ee41c56f0b0e18a Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Sat, 19 Feb 2022 17:59:45 +0000 Subject: [PATCH] Change Demo policies in autoaugment.py. Now you can see policies2 does way better than policies 1! --- .../autoaugment_learners/autoaugment.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py index 977946bc..f9ba4900 100644 --- a/MetaAugment/autoaugment_learners/autoaugment.py +++ b/MetaAugment/autoaugment_learners/autoaugment.py @@ -409,6 +409,12 @@ if __name__=='__main__': from MetaAugment.main import * import MetaAugment.child_networks as cn import torchvision.transforms as transforms + + # If you get rid of this nextimport, the whole thing doesn't work... By the way this import also + # exists on the top of this document. + # I think this is because "import torchvision.transforms as transforms" overrides the import at + # the top of this file and does some funny stuff... Anyways we need to call this import again to get + # rid of the bug. from torchvision.transforms import functional as F, InterpolationMode batch_size = 32 @@ -419,16 +425,18 @@ if __name__=='__main__': (("Invert", 0.8, None), ("Contrast", 0.2, 6)), (("Rotate", 0.7, 2), ("Invert", 0.8, None)), (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("ShearY", 0.5, 8), ("Invert", 0.7, None)), (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)) ] + # The one that i hand crafted. You'll see that this one usually reaches a much + # higher poerformance policies2 = [ - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)) + (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)), + (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)), + (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)), + (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)), + (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)) ] def test_autoaugment_policy(policies): @@ -453,6 +461,8 @@ if __name__=='__main__': sgd = optim.SGD(child_network.parameters(), lr=1e-1) best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100) + + train_dataset test_autoaugment_policy(policies1) -- GitLab