diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 9199601a2a84d93dbefb0443ff8c2a85b1d065ef..e5866049cd393dfecd5c220bd209da9e658fca4f 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -240,11 +240,11 @@ class aa_learner: accuracy = train_child_network(child_network, train_loader, test_loader, - sgd = optim.SGD(child_network.parameters(), lr=1e-1), + sgd = optim.SGD(child_network.parameters(), lr=3e-1), # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2), cost = nn.CrossEntropyLoss(), max_epochs = 3000000, - early_stop_num = 10, + early_stop_num = 15, logging = logging, print_every_epoch=True) diff --git a/MetaAugment/autoaugment_learners/baseline.py b/MetaAugment/autoaugment_learners/baseline.py index 46249bb91c2e77605bbea71a8e690164c2771ea1..e33d4e1e33887bb81c6e7634697cc1a0e4840987 100644 --- a/MetaAugment/autoaugment_learners/baseline.py +++ b/MetaAugment/autoaugment_learners/baseline.py @@ -3,6 +3,7 @@ from pprint import pprint import torchvision.datasets as datasets import torchvision from MetaAugment.autoaugment_learners.aa_learner import aa_learner +import pickle train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=True, transform=None) @@ -15,7 +16,15 @@ aalearner = aa_learner(discrete_p_m=True) # this policy is same as identity function, because probabaility and magnitude are both zero null_policy = [(("Contrast", 0.0, 0.0), ("Contrast", 0.0, 0.0))] -aalearner.test_autoaugment_policy(null_policy, child_network(), train_dataset, test_dataset, + +with open('bad_lenet_baseline.txt', 'w') as file: + file.write('') + +for _ in range(100): + acc = aalearner.test_autoaugment_policy(null_policy, child_network(), train_dataset, test_dataset, toy_flag=True, logging=False) + with open('bad_lenet_baseline.txt', 'a') as file: + file.write(str(acc)) + file.write('\n') pprint(aalearner.history) \ No newline at end of file diff --git a/MetaAugment/child_networks/bad_lenet.py b/MetaAugment/child_networks/bad_lenet.py index 30f0fb3212ba241869895219a2febe0ae96a7c2e..c85d432f6834df29d7a695b3540fcd8475309691 100644 --- a/MetaAugment/child_networks/bad_lenet.py +++ b/MetaAugment/child_networks/bad_lenet.py @@ -1,34 +1,34 @@ import torch.nn as nn -class Bad_LeNet(nn.Module): - # 1. I reduced the channel sizes of the convolutional layers - # 2. I reduced the number of fully ocnnected layers from 3 to 2 - # - # no. of weights: 25*2 + 25*2*4 + 16*4*10 = 250+640 = 890 - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 2, 5) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(2) - self.conv2 = nn.Conv2d(2, 4, 5) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(2) - self.fc1 = nn.Linear(16*4, 10) - self.relu3 = nn.ReLU() +# class Bad_LeNet(nn.Module): +# # 1. I reduced the channel sizes of the convolutional layers +# # 2. I reduced the number of fully ocnnected layers from 3 to 2 +# # +# # no. of weights: 25*2 + 25*2*4 + 16*4*10 = 250+640 = 890 +# def __init__(self): +# super().__init__() +# self.conv1 = nn.Conv2d(1, 2, 5) +# self.relu1 = nn.ReLU() +# self.pool1 = nn.MaxPool2d(2) +# self.conv2 = nn.Conv2d(2, 4, 5) +# self.relu2 = nn.ReLU() +# self.pool2 = nn.MaxPool2d(2) +# self.fc1 = nn.Linear(16*4, 10) +# self.relu3 = nn.ReLU() - def forward(self, x): - y = self.conv1(x) - y = self.relu1(y) - y = self.pool1(y) - y = self.conv2(y) - y = self.relu2(y) - y = self.pool2(y) - y = y.view(y.shape[0], -1) - y = self.fc1(y) - y = self.relu3(y) - return y +# def forward(self, x): +# y = self.conv1(x) +# y = self.relu1(y) +# y = self.pool1(y) +# y = self.conv2(y) +# y = self.relu2(y) +# y = self.pool2(y) +# y = y.view(y.shape[0], -1) +# y = self.fc1(y) +# y = self.relu3(y) +# return y class Bad_LeNet(nn.Module): # 1. I reduced the channel sizes of the convolutional layers diff --git a/bad_lenet_baseline.txt b/bad_lenet_baseline.txt new file mode 100644 index 0000000000000000000000000000000000000000..0ee426570305c6b7ab6eee3db0d80cd7b3d7d604 --- /dev/null +++ b/bad_lenet_baseline.txt @@ -0,0 +1,2 @@ +0.4399999976158142 +0.550000011920929