diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index df325504b77ea57327dcb82054fef873a74b6ccd..1ed7b470596259dc4a4bfa3f6cee006b68174e66 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -90,8 +90,8 @@ def train_model(transform_idx, p): batch_size = 32 n_samples = 0.005 - train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=False, transform=transform_train) - test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train) + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) # create toy dataset from above uploaded data train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) @@ -142,8 +142,8 @@ def callback_generation(ga_instance): # ORGANISING DATA # transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation'] -train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=torchvision.transforms.ToTensor()) -test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor()) +train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor()) +test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor()) n_samples = 0.02 # shuffle and take first n_samples %age of training dataset shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist()) diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb index ac88adec8cb81cd9e91a93b6a2d38ee87faa4f74..d3bbda39b16470afb555c1b859813f83ae82b6ab 100644 --- a/MetaAugment/UCB1_JC.ipynb +++ b/MetaAugment/UCB1_JC.ipynb @@ -144,8 +144,8 @@ " torchvision.transforms.ToTensor()])\n", "\n", " # open data and apply these transformations\n", - " train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n", + " train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n", "\n", " # create toy dataset from above uploaded data\n", " train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n", diff --git a/MetaAugment/__pycache__/main.cpython-38.pyc b/MetaAugment/__pycache__/main.cpython-38.pyc index 213ba2380b9f4137ea09f7c7fbea31a8aca33345..73b244741637ddc13240976d7760041aa3ac49ea 100644 Binary files a/MetaAugment/__pycache__/main.cpython-38.pyc and b/MetaAugment/__pycache__/main.cpython-38.pyc differ diff --git a/MetaAugment/main.py b/MetaAugment/main.py index c68f4b1a4509d27cf8b2a328d4352b6bde896b6c..1f2de939f8eee1364cde23bd4d5c0b0daa3ff99e 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -82,7 +82,7 @@ class AA_Learner: def __init__(self, controller): self.controller = controller - def learn(self, dataset, child_network, toy_flag): + def learn(self, train_dataset, test_dataset, child_network, toy_flag): ''' Deos what is seen in Figure 1 in the AutoAugment paper. @@ -94,9 +94,10 @@ class AA_Learner: while not good_policy_found: policy = self.controller.pop_policy() - train_loader, test_loader = prepare_dataset(dataset, policy, toy_flag) + train_loader, test_loader = create_toy(train_dataset, test_dataset, + batch_size=32, n_samples=0.005) - reward = train_model(child_network, train_loader, test_loader, sgd, cost, epoch) + reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch) self.controller.update(reward, policy)