diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 60d463c273c59f80753a65926a1ea6134e0c5ddb..8d1e643057732918cd694fae975bed2323418c47 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -50,6 +50,9 @@ class aa_learner:
         self.p_bins = p_bins
         self.m_bins = m_bins
 
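+        # length of an operation tensor: one-hot fun/p/m bins if discrete, else fun bins plus two scalars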
+        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
+
-        # should we repre
+        # should we represent the probability and magnitude as discrete variables?
         self.discrete_p_m = discrete_p_m
 
@@ -78,11 +80,12 @@ class aa_learner:
                             Whether we are taking the argmax of the softmaxed tensors. 
                             If this is False, we treat the softmaxed outputs as multinomial pdf's.
         '''
+        # make sure shape is correct
+        assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
+
         # if probability and magnitude are represented as discrete variables
         if self.discrete_p_m:
-            fun_t = operation_tensor[ : self.fun_num]
-            prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins]
-            mag_t = operation_tensor[-self.m_bins : ]
+            fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])
 
             # make sure they are of right size
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
@@ -95,9 +98,9 @@ class aa_learner:
                 mag = torch.argmax(mag_t) # 0 <= m <= 9
             elif argmax==False:
                 # we need these to add up to 1 to be valid pdf's of multinomials
-                assert torch.sum(fun_t)==1
-                assert torch.sum(prob_t)==1
-                assert torch.sum(mag_t)==1
+                assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
+                assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
+                assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
                 fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
                 prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
                 mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
@@ -108,7 +111,7 @@ class aa_learner:
 
         # if probability and magnitude are represented as continuous variables
         else:
-            fun_t = operation_tensor[:self.fun_num]
-            p = operation_tensor[-2].item() # 0 < p < 1
-            m = operation_tensor[-1].item() # 0 < m < 9
+            fun_t, p, m = operation_tensor.split([self.fun_num, 1, 1])
+            p = p.item() # 0 < p < 1
+            m = m.item() # 0 < m < 9
 
@@ -118,7 +121,7 @@ class aa_learner:
             if argmax==True:
                 fun = torch.argmax(fun_t)
             elif argmax==False:
-                assert torch.sum(fun_t)==1
+                assert torch.sum(fun_t).isclose(torch.ones(1))
                 fun = torch.multinomial(fun_t, 1)
             
             function = augmentation_space[fun][0]
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index 709fdabc5e3878727c6be77ce4279113cbc1afbc..f003f1f13b06de1c8f7b26de98f52bfffc17b635 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -36,7 +36,7 @@ class gru_learner(aa_learner):
     # and
     # http://arxiv.org/abs/1611.01578
 
-    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True):
         '''
         Args:
-            spdim: number of subpolicies per policy
+            sp_num: number of subpolicies per policy
@@ -44,11 +44,15 @@
             p_bins: number of bins we divide the interval [0,1] for probabilities
             m_bins: number of bins we divide the magnitude space
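+            discrete_p_m: ignored here; gru_learner always uses the discrete representation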
         '''
-        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
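+        # gru_learner only supports discrete p and m, so we override the flag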
+        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
 
-        # input_size of the RNNModel can be chosen arbitrarily as we don't put any inputs in it.
-        self.controller = RNNModel(mode='GRU', input_size=1, hidden_size=40, num_layers=1,
-                         bias=True, output_size=fun_num+p_bins+m_bins)
+        self.rnn_output_size = fun_num+p_bins+m_bins
+        self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size, 
+                                    num_layers=1, bias=True)
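+        # dim=0 because we softmax 1-D operation tensors, not batches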
+        self.softmax = torch.nn.Softmax(dim=0)
 
 
     def generate_new_policy(self):
@@ -65,7 +66,40 @@
             (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             ]
         '''
-        new_policy = self.controller(input=torch.rand(1))
+        # we need an input for the controller's first timestep; a random vector will do
+        random_input = torch.rand(self.rnn_output_size, requires_grad=False)
+
+        # 2*self.sp_num because we need 2 operations for every subpolicy
+        vectors = self.controller(input=random_input, time_steps=2*self.sp_num)
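+        # each element of 'vectors' is an unnormalised operation tensor of
+        # length fun_num + p_bins + m_bins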
+
+        # softmax the function vector, probability vector, and magnitude vector
+        # of each timestep
+        softmaxed_vectors = []
+        for vector in vectors:
+            fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
+            fun_t = self.softmax(fun_t)
+            prob_t = self.softmax(prob_t)
+            mag_t = self.softmax(mag_t)
+            softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
+            softmaxed_vectors.append(softmaxed_vector)
+            
+        new_policy = []
+
+        for subpolicy_idx in range(self.sp_num):
+            # the vector corresponding to the first operation of this subpolicy
+            op1 = softmaxed_vectors[2*subpolicy_idx]
+            # the vector corresponding to the second operation of this subpolicy
+            op2 = softmaxed_vectors[2*subpolicy_idx+1]
+
+            # translate both vectors
+            op1 = self.translate_operation_tensor(op1)
+            op2 = self.translate_operation_tensor(op2)
+            
+            new_policy.append((op1,op2))
+        
+        return new_policy
 
 
     def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
@@ -93,6 +128,10 @@ if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
+    import torchvision.datasets as datasets
+    import torchvision
+    torch.manual_seed(0)
+
     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, 
                                 transform=None)
     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
@@ -101,7 +140,6 @@ if __name__=='__main__':
 
     
-    learner = gru_learner(discrete_p_m=False)
+    learner = gru_learner()
-    print(learner.generate_new_policy())
-    breakpoint()
+    newpol = learner.generate_new_policy()
     learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(learner.history)
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 02798927072fa7cd3d1a39959f229f1a57d41f00..e82f6aba18b14941f3066632f0343dd1df49f285 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -145,6 +145,9 @@ if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
+    import torchvision.datasets as datasets
+    import torchvision
+    
     train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
                                     train=True, download=True, transform=None)
     test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', 
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/rnn_controller.py
index 1e228fc183f12504173bbfc202a11d3ffeb054ed..12680eae88cbda7f93949f30ffd619ec65f46069 100644
--- a/MetaAugment/controller_networks/rnn_controller.py
+++ b/MetaAugment/controller_networks/rnn_controller.py
@@ -61,7 +61,8 @@ class GRUCell(nn.Module):
 
     def forward(self, input, hx=None):
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
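+            # inputs are unbatched 1-D vectors, so the hidden state is 1-D too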
+            hx = input.new_zeros(self.hidden_size, requires_grad=False)
 
         z, r = torch.chunk(self.x2h(input) + self.h2h(hx), 2, -1)
         z = torch.sigmoid(z)
@@ -73,11 +73,13 @@
 
 
 class RNNModel(nn.Module):
-    def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
+    def __init__(self, mode, output_size, num_layers, bias):
         super(RNNModel, self).__init__()
         self.mode = mode
-        self.input_size = input_size
-        self.hidden_size = hidden_size
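+        # each timestep's output is fed back in as the next timestep's input,
+        # so the input, hidden, and output sizes must all be equal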
+        self.input_size = output_size
+        self.hidden_size = output_size
         self.num_layers = num_layers
         self.bias = bias
         self.output_size = output_size
@@ -113,17 +113,27 @@ class RNNModel(nn.Module):
         self.fc = nn.Linear(self.hidden_size, self.output_size)
 
         
-    def forward(self, input, hx=None):
+    def forward(self, input, time_steps=10, hx=None):
+        # 'input' is the x fed into the first timestep; later timesteps consume
+        # the model's own outputs, so a random vector is fine here
+        assert input.shape == (self.output_size, )
 
         outs = []
         h0 = [None] * self.num_layers if hx is None else list(hx)
     
-        X = list(input.permute(1, 0, 2))
-        for j, l in enumerate(self.rnn_cell_list):
-            hx = h0[j]
-            for i in range(input.shape[1]):
-                hx = l(X[i], hx)
-                X[i] = hx if self.mode != 'LSTM' else hx[0]
+
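+        # NOTE: this feedback scheme assumes num_layers == 1; stacked layers
+        # would each need their own buffer of outputs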
+        X = [None] * time_steps
+        X[0] = input # the first timestep's input is 'input' itself
+        for layer_idx, layer_cell in enumerate(self.rnn_cell_list):
+            hx = h0[layer_idx]
+            for i in range(time_steps):
+                hx = layer_cell(X[i], hx)
+                out = hx if self.mode != 'LSTM' else hx[0]
+                outs.append(out)
+
+                # feed this timestep's output in as the next timestep's input,
+                # except at the last timestep
+                if i != time_steps-1:
+                    X[i+1] = out
-        outs = X
     
 
@@ -191,7 +201,8 @@ class BidirRecurrentModel(nn.Module):
         
         
     def forward(self, input, hx=None):
-        
+        raise NotImplementedError('right now this forward function is written for classification. \
+                                You should modify it for our purpose, like the RNNModel was.')
         outs = []
         outs_rev = []