diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py index 65dcad7dd1c2a74550b7c82722db85d7dd6659d6..3b2459c3e6af2b32c7953935a78c29fed79fe96e 100644 --- a/MetaAugment/autoaugment_learners/autoaugment.py +++ b/MetaAugment/autoaugment_learners/autoaugment.py @@ -423,9 +423,8 @@ if __name__=='__main__': # rid of the bug. from torchvision.transforms import functional as F, InterpolationMode - batch_size = 32 - n_samples = 0.005 - cost = nn.CrossEntropyLoss() + + subpolicies1 = [ (("Invert", 0.8, None), ("Contrast", 0.2, 6)), @@ -445,32 +444,42 @@ if __name__=='__main__': (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)) ] - def test_autoaugment_policy(subpolicies): - aa_transform = AutoAugment() - aa_transform.subpolicies = subpolicies + + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, + transform=None) + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, + transform=torchvision.transforms.ToTensor()) + + + + def test_autoaugment_policy(subpolicies, train_dataset, test_dataset): + + aa_transform = AutoAugment() + aa_transform.subpolicies = subpolicies1 train_transform = transforms.Compose([ aa_transform, transforms.ToTensor() ]) - - train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, - transform=train_transform) - test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, - transform=torchvision.transforms.ToTensor()) + train_dataset.transform = train_transform # create toy dataset from above uploaded data - train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) + train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1) child_network = cn.lenet() sgd = optim.SGD(child_network.parameters(), lr=1e-1) + cost = nn.CrossEntropyLoss() + + best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, + sgd, cost, max_epochs=100, logging=True) - best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100) return best_acc, acc_log - _, acc_log1 = test_autoaugment_policy(subpolicies1) - _, acc_log2 = test_autoaugment_policy(subpolicies2) + + _, acc_log1 = test_autoaugment_policy(subpolicies1, train_dataset, test_dataset) + _, acc_log2 = test_autoaugment_policy(subpolicies2, train_dataset, test_dataset) + plt.plot(acc_log1, label='subpolicies1') plt.plot(acc_log2, label='subpolicies2') plt.xlabel('epochs')