diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 4b7cd86ab04508bf77f69a92e5615f9faaeebc03..60d463c273c59f80753a65926a1ea6134e0c5ddb 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -1,14 +1,12 @@
-# The parent class for all other autoaugment learners``
+# The parent class for all other autoaugment learners
 
 import torch
-import numpy as np
-from MetaAugment.main import *
-import MetaAugment.child_networks as cn
-import torchvision.transforms as transforms
-from MetaAugment.autoaugment_learners.autoaugment import *
+import torch.nn as nn
+import torch.optim as optim
+from MetaAugment.main import train_child_network, create_toy
+from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 
-import torchvision.transforms.autoaugment as torchaa
-from torchvision.transforms import functional as F, InterpolationMode
+import torchvision.transforms as transforms
 
 from pprint import pprint
 
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index 58970e68668f1243486072c3a9be882a544e7e71..709fdabc5e3878727c6be77ce4279113cbc1afbc 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -1,13 +1,8 @@
 import torch
-import numpy as np
-import torchvision.transforms as transforms
-import torchvision.transforms.autoaugment as torchaa
-from torchvision.transforms import functional as F, InterpolationMode
 
-from MetaAugment.main import *
 import MetaAugment.child_networks as cn
-from MetaAugment.autoaugment_learners.autoaugment import *
-from MetaAugment.autoaugment_learners.aa_learner import *
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+from MetaAugment.controller_networks.rnn_controller import RNNModel
 
 from pprint import pprint
 
@@ -33,6 +28,7 @@ augmentation_space = [
             ("Invert", False),
         ]
 
+
 class gru_learner(aa_learner):
     # Uses a GRU controller which is updated via Proximal Polixy Optimization
     # It is the same model use in
@@ -50,8 +46,9 @@ class gru_learner(aa_learner):
         '''
         super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
 
-        # TODO: We should probably use a different way to store results than self.history
-        self.history = []
+        # input_size of the RNNModel can be chosen arbitrarily as we don't put any inputs in it.
+        self.controller = RNNModel(mode='GRU', input_size=1, hidden_size=40, num_layers=1,
+                         bias=True, output_size=fun_num+p_bins+m_bins)
 
 
     def generate_new_policy(self):
@@ -68,25 +65,7 @@ class gru_learner(aa_learner):
             (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             ]
         '''
-        new_policy = []
-        
-        for _ in range(self.sp_num): # generate sp_num subpolicies for each policy
-            ops = []
-            # generate 2 operations for each subpolicy
-            for i in range(2):
-                # if our agent uses discrete representations of probability and magnitude
-                if self.discrete_p_m:
-                    new_op = self.generate_new_discrete_operation()
-                else:
-                    new_op = self.generate_new_continuous_operation()
-                new_op = self.translate_operation_tensor(new_op)
-                ops.append(new_op)
-
-            new_subpolicy = tuple(ops)
-
-            new_policy.append(new_subpolicy)
-
-        return new_policy
+        new_policy = self.controller(input=torch.rand(1))
 
 
     def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
@@ -114,13 +93,15 @@ if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, 
                                 transform=None)
-    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
                                 transform=torchvision.transforms.ToTensor())
     child_network = cn.lenet
 
     
-    rs_learner = randomsearch_learner(discrete_p_m=False)
-    rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
-    pprint(rs_learner.history)
\ No newline at end of file
+    learner = gru_learner(discrete_p_m=False)
+    print(learner.generate_new_policy())
+    breakpoint()
+    learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    pprint(learner.history)
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 3657224f7bb538c0c436c3f12ff451ce1ccc2c1b..02798927072fa7cd3d1a39959f229f1a57d41f00 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -1,13 +1,8 @@
 import torch
 import numpy as np
-import torchvision.transforms as transforms
-import torchvision.transforms.autoaugment as torchaa
-from torchvision.transforms import functional as F, InterpolationMode
 
-from MetaAugment.main import *
 import MetaAugment.child_networks as cn
-from MetaAugment.autoaugment_learners.autoaugment import *
-from MetaAugment.autoaugment_learners.aa_learner import *
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 from pprint import pprint
 
@@ -43,9 +38,7 @@ class randomsearch_learner(aa_learner):
             m_bins: number of bins we divide the magnitude space
         '''
         super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
-
-        # TODO: We should probably use a different way to store results than self.history
-        self.history = []
+        
 
     def generate_new_discrete_operation(self):
         '''
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 0fd76bcf189a297f1d8decd88f39851f4ce3433c..5b0e04e47202272e25df81dabf84a39ed7050a1a 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -2,11 +2,9 @@ import numpy as np
 import torch
 torch.manual_seed(0)
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 import torchvision
 import torchvision.datasets as datasets
-import torchvision.transforms.autoaugment as autoaugment
 #import MetaAugment.AutoAugmentDemo.ops as ops # 
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py