diff --git a/.gitignore b/.gitignore index 74ee9c993d456549a615caccad39b8077158c767..0b5dd1e1e07b97f6795a39919b831d8b50b54a3e 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +MetaAugment/__pycache__/main.cpython-38.pyc diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py index 3b2459c3e6af2b32c7953935a78c29fed79fe96e..e578f315aca01677cb27793995ad8988be5351ff 100644 --- a/MetaAugment/autoaugment_learners/autoaugment.py +++ b/MetaAugment/autoaugment_learners/autoaugment.py @@ -456,7 +456,7 @@ if __name__=='__main__': def test_autoaugment_policy(subpolicies, train_dataset, test_dataset): aa_transform = AutoAugment() - aa_transform.subpolicies = subpolicies1 + aa_transform.subpolicies = subpolicies train_transform = transforms.Compose([ aa_transform, transforms.ToTensor() diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index 77a495e5c81a50b2734155c016f9a61a1f432306..7a888f3d365a67b963632705ae457ba4ec1d71b8 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -6,7 +6,24 @@ from MetaAugment.autoaugment_learners.autoaugment import * import torchvision.transforms.autoaugment as torchaa from torchvision.transforms import functional as F, InterpolationMode - +policies1 = [ + (("Invert", 0.8, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("Invert", 0.8, None)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("Invert", 0.7, None)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)) + ] + +# The one that i hand crafted. You'll see that this one usually reaches a much +# higher poerformance +policies2 = [ + (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)), + (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)), + (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)), + (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)), + (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)) + ] + class randomsearch_learner: def __init__(self): pass @@ -32,59 +49,33 @@ class randomsearch_learner: return good_policy - def test_autoaugment_policy(policies): - aa_transform = AutoAugment() - aa_transform.policies = policies + def test_autoaugment_policy(policy): + aa_transform = AutoAugment() + aa_transform.policies = policy train_transform = transforms.Compose([ aa_transform, transforms.ToTensor() ]) - train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, - transform=train_transform) - test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, - transform=torchvision.transforms.ToTensor()) - - # create toy dataset from above uploaded data - train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) - - child_network = cn.lenet() - sgd = optim.SGD(child_network.parameters(), lr=1e-1) - - best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100) - - train_dataset - +if __name__=='__main__': + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, + transform=train_transform) + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, + transform=torchvision.transforms.ToTensor()) + train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.01) + child_network = cn.lenet() + sgd = optim.SGD(child_network.parameters(), lr=1e-1) -if __name__=='__main__': - - batch_size = 32 - n_samples = 0.005 cost = nn.CrossEntropyLoss() + best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, + sgd, cost, max_epochs=100, logging=True) - policies1 = [ - (("Invert", 0.8, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("Invert", 0.8, None)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("Invert", 0.7, None)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)) - ] - - # The one that i hand crafted. You'll see that this one usually reaches a much - # higher poerformance - policies2 = [ - (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)), - (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)), - (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)), - (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)), - (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)) - ] - learner = RandomSearch_Learner() + learner = randomsearch_learner()