diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index 1ed7b470596259dc4a4bfa3f6cee006b68174e66..005676f234792310fa0a11a6a170cd1822fb6e2a 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -96,11 +96,11 @@ def train_model(transform_idx, p): # create toy dataset from above uploaded data train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) - train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) + # train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) + # test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) - print("Size of training dataset:\t", len(reduced_train_dataset)) - print("Size of testing dataset:\t", len(reduced_test_dataset), "\n") + # print("Size of training dataset:\t", len(reduced_train_dataset)) + # print("Size of testing dataset:\t", len(reduced_test_dataset), "\n") child_network = child_networks.lenet() sgd = optim.SGD(child_network.parameters(), lr=1e-1) diff --git a/MetaAugment/__pycache__/main.cpython-38.pyc b/MetaAugment/__pycache__/main.cpython-38.pyc index 41db5a07126aab1fac85821e7ede3feb4b8c8846..847abde12f173e9cdfacf2d0252f00f7a12f2e4a 100644 Binary files a/MetaAugment/__pycache__/main.cpython-38.pyc and b/MetaAugment/__pycache__/main.cpython-38.pyc differ diff --git a/MetaAugment/main.py b/MetaAugment/main.py index 1f2de939f8eee1364cde23bd4d5c0b0daa3ff99e..f599a88a82b9fb60e37812f854c61c8ac1e7ef8d 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -77,28 +77,24 @@ def train_child_network(child_network, train_loader, test_loader, sgd, return best_acc -# This is sort of how our AA_Learner class should look like: -class AA_Learner: - def __init__(self, controller): - self.controller = controller - def learn(self, train_dataset, test_dataset, child_network, toy_flag): - ''' - Deos what is seen in Figure 1 in the AutoAugment paper. +if __name__=='__main__': + import MetaAugment.child_networks as cn - 'res' stands for resolution of the discretisation of the search space. It could be - a tuple, with first entry regarding probability, second regarding magnitude - ''' - good_policy_found = False + batch_size = 32 + n_samples = 0.005 - while not good_policy_found: - policy = self.controller.pop_policy() + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, + transform=torchvision.transforms.ToTensor()) + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, + transform=torchvision.transforms.ToTensor()) - train_loader, test_loader = create_toy(train_dataset, test_dataset, - batch_size=32, n_samples=0.005) + # create toy dataset from above uploaded data + train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) - reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch) + child_network = cn.lenet() + sgd = optim.SGD(child_network.parameters(), lr=1e-1) + cost = nn.CrossEntropyLoss() + epoch = 20 - self.controller.update(reward, policy) - - return good_policy \ No newline at end of file + best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100) \ No newline at end of file