diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index c85e9fa95087839d259a18123c2da4d8bd77087f..04c8ffcfe871f3fdc5c570c522e92e81e52228aa 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -46,24 +46,6 @@ class Learner(nn.Module):
         self.p_bins = p_bins 
         self.m_bins = m_bins 
 
-        self.augmentation_space = [
-            # (function_name, do_we_need_to_specify_magnitude)
-            ("ShearX", True),
-            ("ShearY", True),
-            ("TranslateX", True),
-            ("TranslateY", True),
-            ("Rotate", True),
-            ("Brightness", True),
-            ("Color", True),
-            ("Contrast", True),
-            ("Sharpness", True),
-            ("Posterize", True),
-            ("Solarize", True),
-            ("AutoContrast", False),
-            ("Equalize", False),
-            ("Invert", False),
-        ]
-
         super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.relu1 = nn.ReLU()
@@ -95,24 +77,6 @@ class Learner(nn.Module):
 
         return y
 
-    def get_idx(self, x):
-        section = self.fun_num + self.p_bins + self.m_bins
-        y = self.forward(x)
-        full_policy = []
-        for pol in range(5 * 2):
-            int_pol = []
-            idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
-
-            trans, need_mag = self.augmentation_space[idx_ret]
-
-            p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-            mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
-            int_pol.append((trans, p_ret, mag))
-            if pol % 2 != 0:
-                full_policy.append(tuple(int_pol))
-
-        return full_policy
-
 
 class LeNet(nn.Module):
     def __init__(self):
@@ -204,7 +168,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 class Evolutionary_learner():
 
-    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14):
+    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None):
         self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
         self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
         self.num_generations = num_generations
@@ -215,6 +179,7 @@ class Evolutionary_learner():
         self.p_bins = p_bins 
         self.mag_bins = mag_bins
         self.fun_num = fun_num
+        self.augmentation_space = augmentation_space
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
@@ -222,6 +187,9 @@ class Evolutionary_learner():
     
 
     def generate_policy(self, sp_num, ps, mags):
+        """
+        
+        """
         policies = []
         for subpol in range(sp_num):
             sub = []
@@ -235,7 +203,33 @@ class Evolutionary_learner():
         return policies
 
 
+    def get_full_policy(self, x):
+        """
+        Generates the full policy (5 x 2 subpolicies)
+        """
+        section = self.meta_rl_agent.fun_num + self.meta_rl_agent.p_bins + self.meta_rl_agent.m_bins
+        y = self.meta_rl_agent.forward(x)
+        full_policy = []
+        for pol in range(5):
+            int_pol = []
+            for _ in range(2):
+                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
+
+                trans, need_mag = self.augmentation_space[idx_ret]
+
+                p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
+                int_pol.append((trans, p_ret, mag))
+
+            full_policy.append(tuple(int_pol))
+
+        return full_policy
+
+
     def run_instance(self, return_weights = False):
+        """
+        Runs the GA instance and returns the model weights as a dictionary
+        """
         self.ga_instance.run()
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
@@ -245,6 +239,9 @@ class Evolutionary_learner():
 
 
     def new_model(self):
+        """
+        Simple function to create a copy of the secondary model (used for classification)
+        """
         copy_model = copy.deepcopy(self.sec_model)
         return copy_model
 
@@ -259,7 +256,7 @@ class Evolutionary_learner():
                                                             weights_vector=solution)
             self.meta_rl_agent.load_state_dict(model_weights_dict)
             for idx, (test_x, label_x) in enumerate(train_loader):
-                full_policy = self.meta_rl_agent.get_idx(test_x)
+                full_policy = self.meta_rl_agent.get_full_policy(test_x)
             cop_mod = self.new_model()
             fit_val = train_model(full_policy, cop_mod)
             cop_mod = 0