Commit c16fe2bc authored by Sun Jin Kim

work on randomsearch_learner

parent cef7662c
@@ -152,3 +152,4 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+MetaAugment/__pycache__/main.cpython-38.pyc
@@ -456,7 +456,7 @@ if __name__=='__main__':
 def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
     aa_transform = AutoAugment()
-    aa_transform.subpolicies = subpolicies1
+    aa_transform.subpolicies = subpolicies
     train_transform = transforms.Compose([
         aa_transform,
         transforms.ToTensor()
@@ -6,7 +6,24 @@ from MetaAugment.autoaugment_learners.autoaugment import *
 import torchvision.transforms.autoaugment as torchaa
 from torchvision.transforms import functional as F, InterpolationMode
+policies1 = [
+    (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
+    (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
+    (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+    (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
+    (("AutoContrast", 0.5, None), ("Equalize", 0.9, None))
+]
+# The one that I hand-crafted. You'll see that this one usually reaches a much
+# higher performance.
+policies2 = [
+    (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
+    (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
+    (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
+    (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
+    (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
+]
 class randomsearch_learner:
     def __init__(self):
         pass
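randomsearch_learner is still an empty stub here. As a rough, hypothetical sketch (not part of this commit) of the sampling step such a learner would need, the snippet below draws random sub-policies in the same (op_name, probability, magnitude) tuple format used by policies1 and policies2 above; the helper name sample_random_policy, the operation list, and the set of magnitude-free ops are all assumptions.

import random

# Hypothetical sketch -- not in this commit. Tuples follow the format of
# policies1/policies2 above: (op_name, probability, magnitude_bin_or_None).
AUGMENTATION_OPS = [
    "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
    "Contrast", "Sharpness", "AutoContrast", "Equalize", "Invert",
]
# Ops whose magnitude bin is unused (third element is None, as in policies1).
NO_MAGNITUDE_OPS = {"AutoContrast", "Equalize", "Invert"}

def sample_random_policy(num_subpolicies=5):
    """Sample one policy: a list of (op1, op2) sub-policy pairs."""
    policy = []
    for _ in range(num_subpolicies):
        subpolicy = []
        for _ in range(2):
            op = random.choice(AUGMENTATION_OPS)
            prob = random.randint(0, 10) / 10            # 0.0, 0.1, ..., 1.0
            mag = None if op in NO_MAGNITUDE_OPS else random.randint(0, 9)
            subpolicy.append((op, prob, mag))
        policy.append(tuple(subpolicy))
    return policy

Keeping the sampled policy as plain tuples means it could be assigned directly to aa_transform.policies, exactly as the hand-written policies are below.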
@@ -32,59 +49,33 @@ class randomsearch_learner:
        return good_policy

def test_autoaugment_policy(policies):
    aa_transform = AutoAugment()
    aa_transform.policies = policies

def test_autoaugment_policy(policy):
    aa_transform = AutoAugment()
    aa_transform.policies = policy
    train_transform = transforms.Compose([
        aa_transform,
        transforms.ToTensor()
    ])
    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
                                   transform=train_transform)
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
                                  transform=torchvision.transforms.ToTensor())
    # create toy dataset from above uploaded data
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
    child_network = cn.lenet()
    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
    best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
    train_dataset
if __name__=='__main__':
    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
                                   transform=train_transform)
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
                                  transform=torchvision.transforms.ToTensor())
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.01)
    child_network = cn.lenet()
    sgd = optim.SGD(child_network.parameters(), lr=1e-1)

if __name__=='__main__':
    batch_size = 32
    n_samples = 0.005
    cost = nn.CrossEntropyLoss()
    best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
                                            sgd, cost, max_epochs=100, logging=True)
    policies1 = [
        (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
        (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
        (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
        (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
        (("AutoContrast", 0.5, None), ("Equalize", 0.9, None))
    ]
    # The one that I hand-crafted. You'll see that this one usually reaches a much
    # higher performance.
    policies2 = [
        (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
        (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
        (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
        (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
        (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
    ]
    learner = RandomSearch_Learner()
    learner = randomsearch_learner()
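To show how the refactored pieces might fit together, here is a hedged usage sketch (not from the commit). It assumes it lives in the same module (so nn, policies2 and test_autoaugment_policy are in scope), that test_autoaugment_policy returns the best test accuracy computed by train_child_network, and it reuses the hypothetical sample_random_policy helper sketched above.

if __name__ == '__main__':
    # test_autoaugment_policy reads these module-level names, as in the
    # commit's own __main__ block.
    batch_size = 32
    cost = nn.CrossEntropyLoss()

    # Baseline: the hand-crafted policy.
    baseline_acc = test_autoaugment_policy(policies2)

    # Random search: evaluate a small budget of random policies, keep the best.
    best_policy, best_acc = None, 0.0
    for _ in range(10):                           # search budget (arbitrary)
        candidate = sample_random_policy()        # hypothetical helper from the sketch above
        acc = test_autoaugment_policy(candidate)
        if acc > best_acc:
            best_policy, best_acc = candidate, acc

    print(f"hand-crafted acc: {baseline_acc:.3f}, best random acc: {best_acc:.3f}")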