Skip to content
Snippets Groups Projects
Commit 5488b922 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Change Demo policies in autoaugment.py. Now you can see policies2 does way better than policies 1!

parent c2395ee1
No related branches found
No related tags found
No related merge requests found
......@@ -409,6 +409,12 @@ if __name__=='__main__':
from MetaAugment.main import *
import MetaAugment.child_networks as cn
import torchvision.transforms as transforms
# If you get rid of this nextimport, the whole thing doesn't work... By the way this import also
# exists on the top of this document.
# I think this is because "import torchvision.transforms as transforms" overrides the import at
# the top of this file and does some funny stuff... Anyways we need to call this import again to get
# rid of the bug.
from torchvision.transforms import functional as F, InterpolationMode
batch_size = 32
......@@ -419,16 +425,18 @@ if __name__=='__main__':
(("Invert", 0.8, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("Invert", 0.8, None)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("ShearY", 0.5, 8), ("Invert", 0.7, None)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None))
]
# The one that i hand crafted. You'll see that this one usually reaches a much
# higher poerformance
policies2 = [
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9))
(("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
(("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
(("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
(("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
(("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
]
def test_autoaugment_policy(policies):
......@@ -453,6 +461,8 @@ if __name__=='__main__':
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
train_dataset
test_autoaugment_policy(policies1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment