diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 60d463c273c59f80753a65926a1ea6134e0c5ddb..6e7874e94a82c5d77740e219261d14ab8f33b4de 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -9,6 +9,8 @@ from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 import torchvision.transforms as transforms
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+
 
 # We will use this augmentation_space temporarily. Later on we will need to 
 # make sure we are able to add other image functions if the users want.
@@ -50,6 +52,8 @@ class aa_learner:
         self.p_bins = p_bins
         self.m_bins = m_bins
 
+        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
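+        # (discrete: a one-hot function vector + a binned probability vector + a binned
+        #  magnitude vector; continuous: a one-hot function vector + a scalar probability
+        #  + a scalar magnitude)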
+
         # whether we represent the probability and magnitude as discrete variables
         self.discrete_p_m = discrete_p_m
 
@@ -57,7 +61,7 @@ class aa_learner:
         self.history = []
 
 
-    def translate_operation_tensor(self, operation_tensor, argmax=False):
+    def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
         '''
         takes in a tensor representing an operation and returns an actual operation which
         is in the form of:
@@ -74,70 +78,100 @@ class aa_learner:
                                 - If self.discrete_p_m is False, we expect to take in a tensor with
                                 dimension (self.fun_num + 1 + 1)
 
+            return_log_prob (boolean):
+                                When this is True, we also return the log probability of the
+                                indices (of fun, prob, mag) that were chosen (either randomly
+                                or deterministically, depending on argmax). This is used, for
+                                example, in the gru_learner to calculate the log probability
+                                of the sampled actions, which is accumulated and then
+                                differentiated.
+
             argmax (boolean): 
                             Whether we are taking the argmax of the softmaxed tensors. 
                             If this is False, we treat the softmaxed outputs as multinomial pdf's.
+
+        Returns:
+            operation (tuple):
+                                An operation in the form (function_name, prob, mag), which is
+                                the format that can be directly put into an AutoAugment object.
+            log_prob (torch.Tensor):
+                                The log probability of the chosen (fun, prob, mag) indices.
+                                Only returned when return_log_prob is True.
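+
+        Example (a rough sketch, assuming the default discrete setting with fun_num=14,
+        p_bins=11, m_bins=10, where `learner` is any aa_learner with discrete_p_m=True;
+        the exact function name returned depends on the ordering of `augmentation_space`):
+
+            >>> op_t = torch.zeros(14 + 11 + 10)
+            >>> op_t[3] = 1.0             # one-hot choice of the 4th image function
+            >>> op_t[14 + 5] = 1.0        # probability bin 5  ->  prob = 5/10 = 0.5
+            >>> op_t[14 + 11 + 8] = 1.0   # magnitude bin 8    ->  mag = 8
+            >>> learner.translate_operation_tensor(op_t, argmax=True)
+            ('<augmentation_space[3] name>', 0.5, 8)   # or (..., 0.5, None) if that
+                                                       # function takes no magnitude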
         '''
+        if (not self.discrete_p_m) and return_log_prob:
+            raise ValueError("return_log_prob=True is only supported when the agent's "
+                             "self.discrete_p_m is True")
+
+        # make sure shape is correct
+        assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
+
         # if probability and magnitude are represented as discrete variables
         if self.discrete_p_m:
-            fun_t = operation_tensor[ : self.fun_num]
-            prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins]
-            mag_t = operation_tensor[-self.m_bins : ]
+            fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])
 
             # make sure they are of right size
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
             assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
             assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
 
+
             if argmax==True:
-                fun = torch.argmax(fun_t)
-                prob = torch.argmax(prob_t) # 0 <= p <= 10
-                mag = torch.argmax(mag_t) # 0 <= m <= 9
+                fun_idx = torch.argmax(fun_t).item()
+                prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
+                mag = torch.argmax(mag_t).item() # 0 <= m <= 9
             elif argmax==False:
                 # we need these to add up to 1 to be valid pdf's of multinomials
-                assert torch.sum(fun_t)==1
-                assert torch.sum(prob_t)==1
-                assert torch.sum(mag_t)==1
-                fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
-                prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
-                mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
+                assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
+                assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
+                assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
+
+                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
+                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
+                mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
+
+            function = augmentation_space[fun_idx][0]
+            prob = prob_idx/10
 
-            function = augmentation_space[fun][0]
-            prob = prob/10
+            indices = (fun_idx, prob_idx, mag)
+
+            # log probability is the sum of the log of the softmax values of the indices 
+            # (of fun_t, prob_t, mag_t) that we have chosen
+            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])
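+            # e.g. if the softmaxed values at the chosen indices happen to be 0.5, 0.2
+            # and 0.1, then log_prob = log(0.5) + log(0.2) + log(0.1), which is about -4.6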
 
 
         # if probability and magnitude are represented as continuous variables
         else:
-            fun_t = operation_tensor[:self.fun_num]
-            p = operation_tensor[-2].item() # 0 < p < 1
-            m = operation_tensor[-1].item() # 0 < m < 9
+            fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
+            prob = prob.item() # 0 <= prob <= 1
+            mag = mag.item()   # 0 <= mag <= 9
 
             # make sure the shape is correct
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
             
             if argmax==True:
-                fun = torch.argmax(fun_t)
+                fun_idx = torch.argmax(fun_t).item()
             elif argmax==False:
-                assert torch.sum(fun_t)==1
-                fun = torch.multinomial(fun_t, 1)
+                assert torch.sum(fun_t).isclose(torch.ones(1))
+                fun_idx = torch.multinomial(fun_t, 1).item()
+            prob = round(prob, 1) # round to nearest first decimal digit
+            mag = round(mag) # round to nearest integer
             
-            function = augmentation_space[fun][0]
-            prob = round(p, 1) # round to nearest first decimal digit
-            mag = round(m) # round to nearest integer
-            # If argmax is False, we treat operation_tensor as a concatenation of three
-            # multinomial pdf's.
+        function = augmentation_space[fun_idx][0]
 
         assert 0 <= prob <= 1
         assert 0 <= mag <= self.m_bins-1
         
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun][1] == True: # if the image function has a magnitude
-            return (function, prob, mag)
+        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
+            operation = (function, prob, mag)
         else:
-            return (function, prob, None)
-            
-
-
+            operation = (function, prob, None)
+        
+        if return_log_prob:
+            return operation, log_prob
+        else:
+            return operation
+        
 
     def generate_new_policy(self):
         '''
@@ -175,7 +209,8 @@ class aa_learner:
             self.history.append((policy, reward))
     
 
-    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, toy_flag):
+    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, 
+                                toy_flag, logging=False):
         '''
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
@@ -197,7 +232,7 @@ class aa_learner:
         # create Dataloader objects out of the Dataset objects
         train_loader, test_loader = create_toy(train_dataset,
                                                 test_dataset,
-                                                batch_size=32,
+                                                batch_size=64,
                                                 n_samples=0.01,
                                                 seed=100)
 
@@ -205,9 +240,12 @@ class aa_learner:
         accuracy = train_child_network(child_network, 
                                     train_loader, 
                                     test_loader, 
-                                    sgd = optim.SGD(child_network.parameters(), lr=1e-1),
+                                    sgd = optim.SGD(child_network.parameters(), lr=3e-1),
+                                    # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                     cost = nn.CrossEntropyLoss(),
-                                    max_epochs = 100, 
-                                    early_stop_num = 15, 
-                                    logging = False)
+                                    max_epochs = 3000000, 
+                                    early_stop_num = 120, 
+                                    logging = logging)
+        
+        # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
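+        # e.g. a caller that wants the training curve could do (a sketch):
+        #     acc, acc_log = self.test_autoaugment_policy(policy, child_network,
+        #                                                 train_dataset, test_dataset,
+        #                                                 toy_flag, logging=True)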
         return accuracy
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index 709fdabc5e3878727c6be77ce4279113cbc1afbc..377064a2c107c573bf8ae9a89630b22d8ee51d6c 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -5,6 +5,7 @@ from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 from MetaAugment.controller_networks.rnn_controller import RNNModel
 
 from pprint import pprint
+import pickle
 
 
 
@@ -36,19 +37,23 @@ class gru_learner(aa_learner):
     # and
     # http://arxiv.org/abs/1611.01578
 
-    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, alpha=0.2):
         '''
         Args:
             sp_num: number of subpolicies per policy
             fun_num: number of image functions in our search space
             p_bins: number of bins we divide the interval [0,1] for probabilities
             m_bins: number of bins we divide the magnitude space
+
+            alpha: Exploration parameter. The lower this value, the more exploration.
         '''
-        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
+        # gru_learner always uses discrete p and m, since translate_operation_tensor
+        # with return_log_prob=True (used below) requires discrete_p_m to be True
+        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
+        self.alpha = alpha
 
-        # input_size of the RNNModel can be chosen arbitrarily as we don't put any inputs in it.
-        self.controller = RNNModel(mode='GRU', input_size=1, hidden_size=40, num_layers=1,
-                         bias=True, output_size=fun_num+p_bins+m_bins)
+        self.rnn_output_size = fun_num+p_bins+m_bins
+        self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size, 
+                                    num_layers=1, bias=True)
+        self.softmax = torch.nn.Softmax(dim=0)
 
 
     def generate_new_policy(self):
@@ -65,10 +70,44 @@ class gru_learner(aa_learner):
             (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             ]
         '''
-        new_policy = self.controller(input=torch.rand(1))
-
-
-    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
+        log_prob = 0
+
+        # the initial input fed into the controller at the first timestep. Its value
+        # is arbitrary (we use a zero vector here) since the controller is not
+        # conditioned on any external input.
+        initial_input = torch.zeros(self.rnn_output_size, requires_grad=False)
+
+        # 2*self.sp_num because we need 2 operations for every subpolicy
+        vectors = self.controller(input=initial_input, time_steps=2*self.sp_num)
+
+        # softmax the function vector, probability vector, and magnitude vector
+        # of each timestep
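+        # Multiplying the logits by alpha (< 1) before the softmax flattens the
+        # distributions, which encourages exploration. Rough illustration (numbers are
+        # ours, not from any paper): with alpha=0.2, logits [1.0, 3.0] become [0.2, 0.6],
+        # whose softmax is ~[0.40, 0.60] instead of the much sharper ~[0.12, 0.88] of
+        # the unscaled logits.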
+        softmaxed_vectors = []
+        for vector in vectors:
+            fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
+            fun_t = self.softmax(fun_t * self.alpha)
+            prob_t = self.softmax(prob_t * self.alpha)
+            mag_t = self.softmax(mag_t * self.alpha)
+            softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
+            softmaxed_vectors.append(softmaxed_vector)
+            
+        new_policy = []
+
+        for subpolicy_idx in range(self.sp_num):
+            # the vector corresponding to the first operation of this subpolicy
+            op1 = softmaxed_vectors[2*subpolicy_idx]
+            # the vector corresponding to the second operation of this subpolicy
+            op2 = softmaxed_vectors[2*subpolicy_idx+1]
+
+            # translate both vectors
+            op1, log_prob1 = self.translate_operation_tensor(op1, return_log_prob=True)
+            op2, log_prob2 = self.translate_operation_tensor(op2, return_log_prob=True)
+            
+            new_policy.append((op1,op2))
+            log_prob += (log_prob1+log_prob2)
+        
+        return new_policy, log_prob
+
+
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
         '''
         Does the loop which is seen in Figure 1 in the AutoAugment paper.
         In other words, repeat:
@@ -76,16 +115,52 @@ class gru_learner(aa_learner):
             2. <see how good that policy is>
             3. <save how good the policy is in a list/dictionary>
         '''
-        # test out 15 random policies
-        for _ in range(15):
-            policy = self.generate_new_policy()
+        # optimizer for training the GRU controller
+        cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
+
+        # 'm' (the minibatch size) is taken from the argument above
+        b = 0.88 # b is a running exponential mean of the rewards, used as a baseline
+                 # for training stability (see section 3.2 of https://arxiv.org/abs/1611.01578)
+
+        for _ in range(1000):
+            cont_optim.zero_grad()
+
+            # obj (the objective) is $\sum_{k=1}^m (reward_k - b) \sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$,
+            # i.e. the REINFORCE objective with a baseline (section 3.2 of the paper above)
+            obj = 0
+
+            # sum up the rewards within a minibatch in order to update the running mean, 'b'
+            mb_rewards_sum = 0
 
-            pprint(policy)
-            child_network = child_network_architecture()
-            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
-                                                test_dataset, toy_flag)
+            for k in range(m):
+                # log_prob is $\sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$, used in the
+                # REINFORCE objective above
+                policy, log_prob = self.generate_new_policy()
+
+                pprint(policy)
+                child_network = child_network_architecture()
+                reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                    test_dataset, toy_flag)
+                mb_rewards_sum += reward
+
+                # log
+                self.history.append((policy, reward))
+
+                # gradient accumulation
+                obj += (reward-b)*log_prob
+            
+            # update running mean of rewards
+            b = 0.7*b + 0.3*(mb_rewards_sum/m)
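+            # e.g. with b=0.88 and a minibatch mean reward of 0.90,
+            # the new b is 0.7*0.88 + 0.3*0.90 = 0.886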
+
+            (-obj).backward() # We put a minus because we want to maximize the objective, not 
+                              # minimize it.
+            cont_optim.step()
+
+            # save the history as a pickle after every iteration
+            with open('gru_logs.pkl', 'wb') as file:
+                pickle.dump(self.history, file)
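+            # (to inspect the log later, a sketch:
+            #      with open('gru_logs.pkl', 'rb') as f:
+            #          history = pickle.load(f)
+            #  where each entry of 'history' is a (policy, reward) tuple)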
+            
 
-            self.history.append((policy, reward))
 
 
 if __name__=='__main__':
@@ -93,6 +168,10 @@ if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
+    import torchvision.datasets as datasets
+    import torchvision
+    torch.manual_seed(0)
+
     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, 
                                 transform=None)
     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
@@ -101,7 +180,6 @@ if __name__=='__main__':
 
     
     learner = gru_learner(discrete_p_m=False)
-    print(learner.generate_new_policy())
-    breakpoint()
+    newpol = learner.generate_new_policy()
     learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(learner.history)
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 02798927072fa7cd3d1a39959f229f1a57d41f00..a5e971c13ef4ef490ed7c5b413949ff00b6e7c00 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -5,6 +5,8 @@ import MetaAugment.child_networks as cn
 from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+import pickle
 
 
 
@@ -84,7 +86,7 @@ class randomsearch_learner(aa_learner):
         fun_p_m[random_fun] = 1
 
         fun_p_m[-2] = np.random.uniform() # 0<prob<1
-        fun_p_m[-1] = np.random.uniform() * (self.m_bins-1) # 0<mag<9
+        fun_p_m[-1] = np.random.uniform() * (self.m_bins-0.0000001) - 0.4999999 # -0.5<mag<9.5
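+        # sampling in (-0.5, m_bins-0.5) means that after translate_operation_tensor
+        # rounds the magnitude to the nearest integer, each value in {0, ..., m_bins-1}
+        # is (roughly) equally likely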
         
         return fun_p_m
 
@@ -129,7 +131,7 @@ class randomsearch_learner(aa_learner):
             3. <save how good the policy is in a list/dictionary>
         '''
         # test out 15 random policies
-        for _ in range(15):
+        for _ in range(1500):
             policy = self.generate_new_policy()
 
             pprint(policy)
@@ -139,19 +141,54 @@ class randomsearch_learner(aa_learner):
 
             self.history.append((policy, reward))
 
+            # save the history every 10 epochs as a pickle
+            if _%10==1:
+                with open('randomsearch_logs.pkl', 'wb') as file:
+                    pickle.dump(self.history, file)
+    
 
-if __name__=='__main__':
+    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=50):
+        '''
+        I made this to plot a couple of accuracy graphs to help manually tune my gradient 
+        optimizer hyperparameters.
+        '''
+        acc_lists = []
 
+        # dummy loop: test out n random policies and log their training curves
+        for _ in range(n):
+            policy = self.generate_new_policy()
+
+            pprint(policy)
+            child_network = child_network_architecture()
+            reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                test_dataset, toy_flag, logging=True)
+
+            self.history.append((policy, reward))
+            acc_lists.append(acc_list)
+
+        for acc_list in acc_lists:
+            plt.plot(acc_list)
+        plt.title(f'Accuracy curves of {n} random policies (checking for any sign of '
+                  'catastrophic failure during training)')
+        # save before show(), since show() may clear the current figure
+        plt.savefig('random_policies')
+        plt.show()
+
+
+if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
+    import torchvision.datasets as datasets
+    import torchvision
+    
     train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
                                     train=True, download=True, transform=None)
     test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', 
                             train=False, download=True, transform=torchvision.transforms.ToTensor())
     child_network = cn.lenet
 
-    
-    rs_learner = randomsearch_learner(discrete_p_m=False)
+    rs_learner = randomsearch_learner(discrete_p_m=True)
     rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    # rs_learner.demo_plot(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(rs_learner.history)
\ No newline at end of file
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/rnn_controller.py
index 1e228fc183f12504173bbfc202a11d3ffeb054ed..12680eae88cbda7f93949f30ffd619ec65f46069 100644
--- a/MetaAugment/controller_networks/rnn_controller.py
+++ b/MetaAugment/controller_networks/rnn_controller.py
@@ -61,7 +61,7 @@ class GRUCell(nn.Module):
 
     def forward(self, input, hx=None):
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
+            hx = input.new_zeros(self.hidden_size, requires_grad=False)
 
         z, r = torch.chunk(self.x2h(input) + self.h2h(hx), 2, -1)
         z = torch.sigmoid(z)
@@ -73,11 +73,11 @@ class GRUCell(nn.Module):
 
 
 class RNNModel(nn.Module):
-    def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
+    def __init__(self, mode, output_size, num_layers, bias):
         super(RNNModel, self).__init__()
         self.mode = mode
-        self.input_size = input_size
-        self.hidden_size = hidden_size
+        self.input_size = output_size
+        self.hidden_size = output_size
         self.num_layers = num_layers
         self.bias = bias
         self.output_size = output_size
@@ -113,17 +113,27 @@ class RNNModel(nn.Module):
         self.fc = nn.Linear(self.hidden_size, self.output_size)
 
         
-    def forward(self, input, hx=None):
+    def forward(self, input, time_steps=10, hx=None):
+        # 'input' is the vector fed into the controller at the first timestep. Its
+        # value can be arbitrary (the current callers pass a zero vector), since the
+        # controller is not conditioned on any external input.
+        assert input.shape == (self.output_size, )
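+        # Usage sketch (names are ours; with the gru_learner defaults,
+        # output_size = fun_num + p_bins + m_bins = 35):
+        #     controller = RNNModel(mode='GRU', output_size=35, num_layers=1, bias=True)
+        #     vectors = controller(input=torch.zeros(35), time_steps=10)
+        #     # as used in gru_learner, 'vectors' behaves as a sequence of 10
+        #     # tensors, each of length output_size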
 
         outs = []
         h0 = [None] * self.num_layers if hx is None else list(hx)
     
-        X = list(input.permute(1, 0, 2))
-        for j, l in enumerate(self.rnn_cell_list):
-            hx = h0[j]
-            for i in range(input.shape[1]):
-                hx = l(X[i], hx)
-                X[i] = hx if self.mode != 'LSTM' else hx[0]
+
+        X = [None] * time_steps
+        X[0] = input # first input is 'input'
+        for layer_idx, layer_cell in enumerate(self.rnn_cell_list):
+            hx = h0[layer_idx]
+            for i in range(time_steps):
+                hx = layer_cell(X[i], hx)
+                
+                # we feed in this timestep's output into the next timestep's input
+                # except if we are at the last timestep
+                if i != time_steps-1:
+                    X[i+1] = hx if self.mode != 'LSTM' else hx[0] # for LSTM, hx is an (h, c) tuple
+                
         outs = X
     
 
@@ -191,7 +201,8 @@ class BidirRecurrentModel(nn.Module):
         
         
     def forward(self, input, hx=None):
-        
+        raise NotImplementedError('right now this forward function is written for '
+                                  'classification. You should modify it for our purpose, '
+                                  'like RNNModel.forward was.')
         outs = []
         outs_rev = []
         
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 5b0e04e47202272e25df81dabf84a39ed7050a1a..b39b4a21658c63d03e4dd6b1d251d4587546e633 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-torch.manual_seed(0)
 import torch.nn as nn
 import torch.optim as optim
 import torchvision
diff --git a/randomsearch_logs.pkl b/randomsearch_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..b475be1198d4b25e8aa6f715e1f31d6945c210ff
Binary files /dev/null and b/randomsearch_logs.pkl differ