diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 3a6b3e4c5d11f1cddaac0e1d18b372ef45f14584..6e7874e94a82c5d77740e219261d14ab8f33b4de 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -9,6 +9,8 @@ from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 import torchvision.transforms as transforms
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+
 
 # We will use this augmentation_space temporarily. Later on we will need to 
 # make sure we are able to add other image functions if the users want.
@@ -59,7 +61,7 @@ class aa_learner:
         self.history = []
 
 
-    def translate_operation_tensor(self, operation_tensor, argmax=False):
+    def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
         '''
         takes in a tensor representing an operation and returns an actual operation which
         is in the form of:
@@ -76,10 +78,28 @@ class aa_learner:
                                 - If self.discrete_p_m is False, we expect to take in a tensor with
                                 dimension (self.fun_num + 1 + 1)
 
+            return_log_prob (boolean): 
+                                When this is True, we also return the log-probability of the
+                                (fun, prob, mag) indices that were chosen (either randomly or
+                                deterministically, depending on argmax). This is used, for
+                                example, in the gru_learner to compute the log-probability of
+                                the chosen actions, which is later differentiated for the policy gradient.
+
             argmax (boolean): 
                             Whether we are taking the argmax of the softmaxed tensors. 
                             If this is False, we treat the softmaxed outputs as multinomial pdf's.
+
+        Returns:
+            operation (tuple):
+                                An operation in the form (function_name, prob, mag), i.e. the
+                                format that can be directly put into an AutoAugment policy.
+            log_prob (torch.Tensor):
+                                The log-probability of the chosen indices (only returned when return_log_prob=True).
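+
+        Example (illustrative sketch; assumes 'learner' is an aa_learner constructed with the
+        defaults fun_num=14, p_bins=11, m_bins=10 and discrete_p_m=True, so op_tensor_length==35):
+
+            raw = torch.randn(35)   # hypothetical un-softmaxed controller output
+            f, p, m = raw.split([14, 11, 10])
+            softmaxed = torch.cat([f.softmax(0), p.softmax(0), m.softmax(0)])
+            operation, log_prob = learner.translate_operation_tensor(softmaxed, return_log_prob=True)
+            # operation might look like ('ShearY', 0.5, 3); log_prob is a 0-dim torch.Tensor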
         '''
+        if (not self.discrete_p_m) and return_log_prob:
+            raise ValueError("return_log_prob=True is only supported when the agent's "
+                             "self.discrete_p_m is True")
+
         # make sure shape is correct
         assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
 
@@ -92,53 +112,66 @@ class aa_learner:
             assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
             assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
 
+
             if argmax==True:
-                fun = torch.argmax(fun_t).item()
-                prob = torch.argmax(prob_t).item() # 0 <= p <= 10
+                fun_idx = torch.argmax(fun_t).item()
+                prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
                 mag = torch.argmax(mag_t).item() # 0 <= m <= 9
             elif argmax==False:
                 # we need these to add up to 1 to be valid pdf's of multinomials
                 assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                 assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                 assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
-                fun = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
-                prob = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
+
+                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
+                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
                 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
-            function = augmentation_space[fun][0]
-            prob = prob/10
+            function = augmentation_space[fun_idx][0]
+            prob = prob_idx/10
+
+            indices = (fun_idx, prob_idx, mag)
+
+            # the log-probability of this operation is the sum of the logs of the softmaxed
+            # values at the chosen indices of fun_t, prob_t and mag_t
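+            # e.g. if fun_t[fun_idx]=0.2, prob_t[prob_idx]=0.5 and mag_t[mag]=0.1, then
+            # log_prob = log(0.2) + log(0.5) + log(0.1) ≈ -4.61 (a 0-dim tensor)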
+            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])
 
 
         # if probability and magnitude are represented as continuous variables
         else:
             fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
+            prob = prob.item()
             # 0 =< prob =< 1
+            mag = mag.item()
             # 0 =< mag =< 9
 
             # make sure the shape is correct
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
             
             if argmax==True:
-                fun = torch.argmax(fun_t)
+                fun_idx = torch.argmax(fun_t).item()
             elif argmax==False:
                 assert torch.sum(fun_t).isclose(torch.ones(1))
-                fun = torch.multinomial(fun_t, 1)
+                fun_idx = torch.multinomial(fun_t, 1).item()
+            prob = round(prob, 1) # round to one decimal place
+            mag = round(mag) # round to the nearest integer
             
-        function = augmentation_space[fun][0]
-        prob = round(prob, 1) # round to nearest first decimal digit
-        mag = round(mag) # round to nearest integer
+        function = augmentation_space[fun_idx][0]
 
         assert 0 <= prob <= 1
         assert 0 <= mag <= self.m_bins-1
         
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun][1] == True: # if the image function has a magnitude
-            return (function, prob, mag)
+        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
+            operation = (function, prob, mag)
         else:
-            return (function, prob, None)
-            
-
-
+            operation = (function, prob, None)
+        
+        if return_log_prob:
+            return operation, log_prob
+        else:
+            return operation
+        
 
     def generate_new_policy(self):
         '''
@@ -176,7 +209,8 @@ class aa_learner:
             self.history.append((policy, reward))
     
 
-    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, toy_flag):
+    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, 
+                                toy_flag, logging=False):
         '''
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
@@ -198,7 +232,7 @@ class aa_learner:
         # create Dataloader objects out of the Dataset objects
         train_loader, test_loader = create_toy(train_dataset,
                                                 test_dataset,
-                                                batch_size=32,
+                                                batch_size=64,
                                                 n_samples=0.01,
                                                 seed=100)
 
@@ -206,9 +240,12 @@ class aa_learner:
         accuracy = train_child_network(child_network, 
                                     train_loader, 
                                     test_loader, 
-                                    sgd = optim.SGD(child_network.parameters(), lr=1e-1),
+                                    sgd = optim.SGD(child_network.parameters(), lr=3e-1),
+                                    # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                     cost = nn.CrossEntropyLoss(),
-                                    max_epochs = 100, 
-                                    early_stop_num = 15, 
-                                    logging = False)
+                                    max_epochs = 3000000, 
+                                    early_stop_num = 120, 
+                                    logging = logging)
+        
+        # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index f003f1f13b06de1c8f7b26de98f52bfffc17b635..377064a2c107c573bf8ae9a89630b22d8ee51d6c 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -5,6 +5,7 @@ from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 from MetaAugment.controller_networks.rnn_controller import RNNModel
 
 from pprint import pprint
+import pickle
 
 
 
@@ -36,15 +37,18 @@ class gru_learner(aa_learner):
     # and
     # http://arxiv.org/abs/1611.01578
 
-    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True):
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, alpha=0.2):
         '''
         Args:
             spdim: number of subpolicies per policy
             fun_num: number of image functions in our search space
             p_bins: number of bins we divide the interval [0,1] for probabilities
             m_bins: number of bins we divide the magnitude space
+
+            alpha: Exploration parameter. The controller logits are multiplied by alpha before the
+                   softmax, so the lower this value, the flatter the distribution and the more exploration.
         '''
         super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
+        self.alpha = alpha
 
         self.rnn_output_size = fun_num+p_bins+m_bins
         self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size, 
@@ -66,8 +70,10 @@ class gru_learner(aa_learner):
             (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             ]
         '''
+        log_prob = 0
+
-        # we need a random input to put in
+        # a fixed (all-zero) input to feed into the controller
-        random_input = torch.rand(self.rnn_output_size, requires_grad=False)
+        random_input = torch.zeros(self.rnn_output_size, requires_grad=False)
 
         # 2*self.sp_num because we need 2 operations for every subpolicy
         vectors = self.controller(input=random_input, time_steps=2*self.sp_num)
@@ -76,15 +82,13 @@ class gru_learner(aa_learner):
         # of each timestep
         softmaxed_vectors = []
         for vector in vectors:
-            print(vector)
             fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
-            fun_t = self.softmax(fun_t)
-            prob_t = self.softmax(prob_t)
-            mag_t = self.softmax(mag_t)
+            fun_t = self.softmax(fun_t * self.alpha)
+            prob_t = self.softmax(prob_t * self.alpha)
+            mag_t = self.softmax(mag_t * self.alpha)
             softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
             softmaxed_vectors.append(softmaxed_vector)
             
-        print(softmaxed_vectors)
         new_policy = []
 
         for subpolicy_idx in range(self.sp_num):
@@ -94,16 +98,16 @@ class gru_learner(aa_learner):
             op2 = softmaxed_vectors[2*subpolicy_idx+1]
 
             # translate both vectors
-            op1 = self.translate_operation_tensor(op1)
-            op2 = self.translate_operation_tensor(op2)
+            op1, log_prob1 = self.translate_operation_tensor(op1, return_log_prob=True)
+            op2, log_prob2 = self.translate_operation_tensor(op2, return_log_prob=True)
             
-            print('new subpol:', (op1, op2))
             new_policy.append((op1,op2))
+            log_prob += (log_prob1+log_prob2)
         
-        return new_policy
+        return new_policy, log_prob
 
 
-    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
         '''
         Does the loop which is seen in Figure 1 in the AutoAugment paper.
         In other words, repeat:
@@ -111,16 +115,52 @@ class gru_learner(aa_learner):
             2. <see how good that policy is>
             3. <save how good the policy is in a list/dictionary>
         '''
-        # test out 15 random policies
-        for _ in range(15):
-            policy = self.generate_new_policy()
+        # optimizer for training the GRU controller
+        cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
+
+        # m is the minibatch size: the number of policies sampled per controller update
+        b = 0.88 # b is a running exponential mean of the rewards, used as a baseline for training
+                 # stability (see Section 3.2 of https://arxiv.org/abs/1611.01578)
+
+        for _ in range(1000):
+            cont_optim.zero_grad()
+
+            # obj (objective) is $\sum_{k=1}^m (reward_k - b) \sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$,
+            # the REINFORCE objective with a baseline (Section 3.2 of https://arxiv.org/abs/1611.01578)
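+            # Gradients flow only through log_prob (the score-function/REINFORCE estimator);
+            # each reward is a non-differentiable scalar obtained by training a child network.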
+            obj = 0
+
+            # sum up the rewards within a minibatch in order to update the running mean, 'b'
+            mb_rewards_sum = 0
+
+            for k in range(m):
+                # log_prob is $\sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$, used in the objective above
+                policy, log_prob = self.generate_new_policy()
 
-            pprint(policy)
-            child_network = child_network_architecture()
-            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
-                                                test_dataset, toy_flag)
+                pprint(policy)
+                child_network = child_network_architecture()
+                reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                    test_dataset, toy_flag)
+                mb_rewards_sum += reward
+
+                # log
+                self.history.append((policy, reward))
+
+                # gradient accumulation
+                obj += (reward-b)*log_prob
+            
+            # update running mean of rewards
+            b = 0.7*b + 0.3*(mb_rewards_sum/m)
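+            # (an exponential moving average of the minibatch-mean reward; the 0.7/0.3 decay is a design choice)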
+
+            (-obj).backward() # We put a minus because we want to maximize the objective, not 
+                              # minimize it.
+            cont_optim.step()
+
+            # save the history after every controller update as a pickle
+            if _%1==0:
+                with open('gru_logs.pkl', 'wb') as file:
+                    pickle.dump(self.history, file)
+            
 
-            self.history.append((policy, reward))
 
 
 if __name__=='__main__':
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index e82f6aba18b14941f3066632f0343dd1df49f285..a5e971c13ef4ef490ed7c5b413949ff00b6e7c00 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -5,6 +5,8 @@ import MetaAugment.child_networks as cn
 from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+import pickle
 
 
 
@@ -84,7 +86,7 @@ class randomsearch_learner(aa_learner):
         fun_p_m[random_fun] = 1
 
         fun_p_m[-2] = np.random.uniform() # 0<prob<1
-        fun_p_m[-1] = np.random.uniform() * (self.m_bins-1) # 0<mag<9
+        # roughly uniform in (-0.5, m_bins-0.5), so rounding gives each integer magnitude equal probability
+        fun_p_m[-1] = np.random.uniform() * (self.m_bins-0.0000001) - 0.4999999
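+        # e.g. with fun_num=14 and m_bins=10, fun_p_m is a length-16 vector: a one-hot over the
+        # 14 image functions, followed by a continuous prob in [0,1) and a mag in (-0.5, 9.5)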
         
         return fun_p_m
 
@@ -129,7 +131,7 @@ class randomsearch_learner(aa_learner):
             3. <save how good the policy is in a list/dictionary>
         '''
-        # test out 15 random policies
+        # test out 1500 random policies
-        for _ in range(15):
+        for _ in range(1500):
             policy = self.generate_new_policy()
 
             pprint(policy)
@@ -139,9 +141,41 @@ class randomsearch_learner(aa_learner):
 
             self.history.append((policy, reward))
 
+            # save the history every 10 policies as a pickle
+            if _%10==1:
+                with open('randomsearch_logs.pkl', 'wb') as file:
+                    pickle.dump(self.history, file)
+    
 
-if __name__=='__main__':
+    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=50):
+        '''
+        Plots the child network's accuracy curve for each of n random policies. This is here
+        to help manually tune the child network optimizer's hyperparameters.
+        '''
+        acc_lists = []
+
+        # run n random policies and record the accuracy curve of each child network
+        for _ in range(n):
+            policy = self.generate_new_policy()
+
+            pprint(policy)
+            child_network = child_network_architecture()
+            reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                test_dataset, toy_flag, logging=True)
+
+            self.history.append((policy, reward))
+            acc_lists.append(acc_list)
 
+        for acc_list in acc_lists:
+            plt.plot(acc_list)
+        plt.title(f'Accuracy curves of {n} random policies '
+                  '(checking for signs of catastrophic failure during training)')
+        plt.savefig('random_policies')
+        plt.show()
+
+
+if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
@@ -154,7 +188,7 @@ if __name__=='__main__':
                             train=False, download=True, transform=torchvision.transforms.ToTensor())
     child_network = cn.lenet
 
-    
-    rs_learner = randomsearch_learner(discrete_p_m=False)
+    rs_learner = randomsearch_learner(discrete_p_m=True)
     rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    # rs_learner.demo_plot(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(rs_learner.history)
\ No newline at end of file
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 5b0e04e47202272e25df81dabf84a39ed7050a1a..b39b4a21658c63d03e4dd6b1d251d4587546e633 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-torch.manual_seed(0)
 import torch.nn as nn
 import torch.optim as optim
 import torchvision
diff --git a/randomsearch_logs.pkl b/randomsearch_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..b475be1198d4b25e8aa6f715e1f31d6945c210ff
Binary files /dev/null and b/randomsearch_logs.pkl differ