diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 65dcad7dd1c2a74550b7c82722db85d7dd6659d6..3b2459c3e6af2b32c7953935a78c29fed79fe96e 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -423,9 +423,8 @@ if __name__=='__main__':
     # rid of the bug.
     from torchvision.transforms import functional as F, InterpolationMode
 
-    batch_size = 32
-    n_samples = 0.005
-    cost = nn.CrossEntropyLoss()
+
+
 
     subpolicies1 = [
             (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
@@ -445,32 +444,42 @@ if __name__=='__main__':
             (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
             ]
 
-    def test_autoaugment_policy(subpolicies):
-        aa_transform = AutoAugment()
-        aa_transform.subpolicies = subpolicies
 
+
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
+                                transform=None)
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
+                                transform=torchvision.transforms.ToTensor())
+
+
+
+    def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
+
+        aa_transform = AutoAugment()
+        aa_transform.subpolicies = subpolicies1
         train_transform = transforms.Compose([
                                                 aa_transform,
                                                 transforms.ToTensor()
                                             ])
 
-
-        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
-                                    transform=train_transform)
-        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
-                                    transform=torchvision.transforms.ToTensor())
+        train_dataset.transform = train_transform
 
         # create toy dataset from above uploaded data
-        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
+        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
 
         child_network = cn.lenet()
         sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+        cost = nn.CrossEntropyLoss()
+
+        best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
+                                                    sgd, cost, max_epochs=100, logging=True)
 
-        best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
         return best_acc, acc_log
 
-    _, acc_log1 = test_autoaugment_policy(subpolicies1)
-    _, acc_log2 = test_autoaugment_policy(subpolicies2)
+
+    _, acc_log1 = test_autoaugment_policy(subpolicies1, train_dataset, test_dataset)
+    _, acc_log2 = test_autoaugment_policy(subpolicies2, train_dataset, test_dataset)
+
     plt.plot(acc_log1, label='subpolicies1')
     plt.plot(acc_log2, label='subpolicies2')
     plt.xlabel('epochs')