diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index df325504b77ea57327dcb82054fef873a74b6ccd..1ed7b470596259dc4a4bfa3f6cee006b68174e66 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -90,8 +90,8 @@ def train_model(transform_idx, p):
     batch_size = 32
     n_samples = 0.005
 
-    train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=False, transform=transform_train)
-    test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
     # create toy dataset from above uploaded data
     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
@@ -142,8 +142,8 @@ def callback_generation(ga_instance):
 # ORGANISING DATA
 
 # transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation']
-train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
-test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
+train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
+test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
 n_samples = 0.02
 # shuffle and take the first n_samples fraction of the training dataset
 shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
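
Note: the create_toy helper called above (and again in the main.py hunk further down) is not part of this diff. For orientation only, here is a minimal sketch of what it is assumed to do, inferred purely from the call sites: take a random n_samples fraction of each split and wrap the subsets in DataLoaders. The repository's actual implementation may differ.

import torch
from torch.utils.data import DataLoader, Subset

def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # Assumed behaviour: keep a random `n_samples` fraction of each split
    # and return DataLoaders over those subsets.
    n_train = int(n_samples * len(train_dataset))
    n_test = int(n_samples * len(test_dataset))

    train_subset = Subset(train_dataset, torch.randperm(len(train_dataset))[:n_train].tolist())
    test_subset = Subset(test_dataset, torch.randperm(len(test_dataset))[:n_test].tolist())

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=batch_size)
    return train_loader, test_loader
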
diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
index ac88adec8cb81cd9e91a93b6a2d38ee87faa4f74..d3bbda39b16470afb555c1b859813f83ae82b6ab 100644
--- a/MetaAugment/UCB1_JC.ipynb
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -144,8 +144,8 @@
         "            torchvision.transforms.ToTensor()])\n",
         "\n",
         "        # open data and apply these transformations\n",
-        "        train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "        test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
+        "        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)\n",
+        "        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
         "\n",
         "        # create toy dataset from above uploaded data\n",
         "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
diff --git a/MetaAugment/__pycache__/main.cpython-38.pyc b/MetaAugment/__pycache__/main.cpython-38.pyc
index 213ba2380b9f4137ea09f7c7fbea31a8aca33345..73b244741637ddc13240976d7760041aa3ac49ea 100644
Binary files a/MetaAugment/__pycache__/main.cpython-38.pyc and b/MetaAugment/__pycache__/main.cpython-38.pyc differ
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index c68f4b1a4509d27cf8b2a328d4352b6bde896b6c..1f2de939f8eee1364cde23bd4d5c0b0daa3ff99e 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -82,7 +82,7 @@ class AA_Learner:
     def __init__(self, controller):
         self.controller = controller
 
-    def learn(self, dataset, child_network, toy_flag):
+    def learn(self, train_dataset, test_dataset, child_network, toy_flag):
         '''
         Does what is seen in Figure 1 in the AutoAugment paper.
 
@@ -94,9 +94,10 @@ class AA_Learner:
         while not good_policy_found:
             policy = self.controller.pop_policy()
 
-            train_loader, test_loader = prepare_dataset(dataset, policy, toy_flag)
+            train_loader, test_loader = create_toy(train_dataset, test_dataset,
+                                                    batch_size=32, n_samples=0.005)
 
-            reward = train_model(child_network, train_loader, test_loader, sgd, cost, epoch)
+            reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch)
 
             self.controller.update(reward, policy)
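
Note: the rewritten loop assumes a train_child_network function that trains the child network and returns a scalar reward. A minimal sketch under that assumption follows; the name and argument order come from the call above, while returning test accuracy as the reward (and sgd being a torch optimizer, cost a loss function, epoch an int) are assumptions, not the repository's implementation.

import torch

def train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch):
    # Assumed behaviour: train for `epoch` epochs, then return test accuracy as the reward.
    for _ in range(epoch):
        child_network.train()
        for x, y in train_loader:
            sgd.zero_grad()
            loss = cost(child_network(x), y)
            loss.backward()
            sgd.step()

    # Evaluate on the held-out loader and report accuracy.
    child_network.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            preds = child_network(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total
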