diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 3fb9da5f2c6782d98224784f97cef7a6a05fa5a0..4b7cd86ab04508bf77f69a92e5615f9faaeebc03 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -59,7 +59,7 @@ class aa_learner:
         self.history = []
 
 
-    def translate_operation_tensor(self, operation_tensor):
+    def translate_operation_tensor(self, operation_tensor, argmax=False):
         '''
         takes in a tensor representing an operation and returns an actual operation which
         is in the form of:
@@ -69,12 +69,16 @@ class aa_learner:
 
         Args:
             operation_tensor (tensor): 
-                                - If discrete_p_m is True, we expect to take in a tensor with
+                                We expect this tensor to already have been softmaxed.
+                                Furthermore,
+                                - If self.discrete_p_m is True, we expect to take in a tensor with
                                 dimension (self.fun_num + self.p_bins + self.m_bins)
-                                - If discrete_p_m is False, we expect to take in a tensor with
+                                - If self.discrete_p_m is False, we expect to take in a tensor with
                                 dimension (self.fun_num + 1 + 1)
-            continuous_p_m (boolean): whether the operation_tensor has continuous representations
-                                    of probability and magnitude
+
+            argmax (boolean): 
+                            Whether we are taking the argmax of the softmaxed tensors. 
+                            If this is False, we treat the softmaxed outputs as multinomial pdf's.
         '''
         # if probability and magnitude are represented as discrete variables
         if self.discrete_p_m:
@@ -82,9 +86,23 @@ class aa_learner:
             prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins]
             mag_t = operation_tensor[-self.m_bins : ]
 
-            fun = torch.argmax(fun_t)
-            prob = torch.argmax(prob_t) # 0 <= p <= 10
-            mag = torch.argmax(mag_t) # 0 <= m <= 9
+            # make sure they are of right size
+            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
+            assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
+            assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
+
+            if argmax==True:
+                fun = torch.argmax(fun_t)
+                prob = torch.argmax(prob_t) # 0 <= p <= 10
+                mag = torch.argmax(mag_t) # 0 <= m <= 9
+            elif argmax==False:
+                # we need these to add up to 1 to be valid pdf's of multinomials
+                assert torch.sum(fun_t)==1
+                assert torch.sum(prob_t)==1
+                assert torch.sum(mag_t)==1
+                fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
+                prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
+                mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
 
             function = augmentation_space[fun][0]
             prob = prob/10
@@ -96,17 +114,31 @@ class aa_learner:
             p = operation_tensor[-2].item() # 0 < p < 1
             m = operation_tensor[-1].item() # 0 < m < 9
 
-            fun = torch.argmax(fun_t)
-
+            # make sure the shape is correct
+            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
+            
+            if argmax==True:
+                fun = torch.argmax(fun_t)
+            elif argmax==False:
+                assert torch.sum(fun_t)==1
+                fun = torch.multinomial(fun_t, 1)
+            
             function = augmentation_space[fun][0]
             prob = round(p, 1) # round to nearest first decimal digit
             mag = round(m) # round to nearest integer
+            # If argmax is False, we treat operation_tensor as a concatenation of three
+            # multinomial pdf's.
 
+        assert 0 <= prob <= 1
+        assert 0 <= mag <= self.m_bins-1
+        
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun][0] == True: # if the image function has a magnitude
+        if augmentation_space[fun][1] == True: # if the image function has a magnitude
             return (function, prob, mag)
         else:
             return (function, prob, None)
+            
+
 
 
     def generate_new_policy(self):
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index da76cf0b16bdd829e19fbbd649b37d5577df586a..3657224f7bb538c0c436c3f12ff451ce1ccc2c1b 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -63,11 +63,11 @@ class randomsearch_learner(aa_learner):
         random_mag = np.random.randint(0, self.m_bins)
         
         fun_t= torch.zeros(self.fun_num)
-        fun_t[random_fun] = 1
+        fun_t[random_fun] = 1.0
         prob_t = torch.zeros(self.p_bins)
-        prob_t[random_prob] = 1
+        prob_t[random_prob] = 1.0
         mag_t = torch.zeros(self.m_bins)
-        mag_t[random_mag] = 1
+        mag_t[random_mag] = 1.0
 
         return torch.cat([fun_t, prob_t, mag_t])
 
@@ -152,10 +152,10 @@ if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
-                                transform=None)
-    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
-                                transform=torchvision.transforms.ToTensor())
+    train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
+                                    train=True, download=True, transform=None)
+    test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', 
+                            train=False, download=True, transform=torchvision.transforms.ToTensor())
     child_network = cn.lenet