Skip to content
Snippets Groups Projects
Commit 4a934128 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Using: 'train_dataset.transform = my_transform'

parent 879f574a
No related branches found
No related tags found
No related merge requests found
......@@ -423,9 +423,8 @@ if __name__=='__main__':
# rid of the bug.
from torchvision.transforms import functional as F, InterpolationMode
batch_size = 32
n_samples = 0.005
cost = nn.CrossEntropyLoss()
subpolicies1 = [
(("Invert", 0.8, None), ("Contrast", 0.2, 6)),
......@@ -445,32 +444,42 @@ if __name__=='__main__':
(("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
]
def test_autoaugment_policy(subpolicies):
aa_transform = AutoAugment()
aa_transform.subpolicies = subpolicies
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
transform=None)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
transform=torchvision.transforms.ToTensor())
def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
aa_transform = AutoAugment()
aa_transform.subpolicies = subpolicies1
train_transform = transforms.Compose([
aa_transform,
transforms.ToTensor()
])
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
transform=train_transform)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
transform=torchvision.transforms.ToTensor())
train_dataset.transform = train_transform
# create toy dataset from above uploaded data
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
child_network = cn.lenet()
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
sgd, cost, max_epochs=100, logging=True)
best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
return best_acc, acc_log
_, acc_log1 = test_autoaugment_policy(subpolicies1)
_, acc_log2 = test_autoaugment_policy(subpolicies2)
_, acc_log1 = test_autoaugment_policy(subpolicies1, train_dataset, test_dataset)
_, acc_log2 = test_autoaugment_policy(subpolicies2, train_dataset, test_dataset)
plt.plot(acc_log1, label='subpolicies1')
plt.plot(acc_log2, label='subpolicies2')
plt.xlabel('epochs')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment