diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 0eb38d59c1e3156dbc7a620ee1958fb7d1d032bb..e4460cbfdca799022cd2f1d1ff950cd780355fa1 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -46,7 +46,6 @@ class aa_learner:
     def __init__(self, 
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -57,6 +56,7 @@ class aa_learner:
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                exclude_method = [],
                 ):
         """
         Args:
@@ -84,11 +84,9 @@ class aa_learner:
         """
         # related to defining the search space
         self.sp_num = sp_num
-        self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
         self.discrete_p_m = discrete_p_m
-        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
 
         # related to training of the child_network
         self.batch_size = batch_size
@@ -101,6 +99,9 @@ class aa_learner:
 
         # TODO: We should probably use a different way to store results than self.history
         self.history = []
+        self.augmentation_space = [x for x in augmentation_space if x not in exclude_method]
+        self.fun_num = len(augmentation_space)
+        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2
 
 
     def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 8e10c74547f7230a0eeecf11356804413721f7c1..5a8ecbcf6f0b8c6212a8c034a70d61476f4870f6 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -238,6 +238,8 @@ class AutoAugment(torch.nn.Module):
             if probs[i] <= p:
                 op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
+                print("magnitude_id: ", magnitude_id)
+                print("magnitudes[magnitude_id]: ", magnitudes[magnitude_id])
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
                     magnitude *= -1.0
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 18ecf751e614585c7db86902eb3cce927dd696f5..8e1d5bc198548c3e24bb3d2bd5ac2d1f39650923 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -6,34 +6,32 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
+from MetaAugment.controller_networks.evo_controller import Evo_learner
+
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
-from .aa_learner import aa_learner, augmentation_space
 
 
 class evo_learner(aa_learner):
 
     def __init__(self, 
                 sp_num=1,
-                num_solutions = 10, 
-                num_parents_mating = 5,
+                num_solutions = 5, 
+                num_parents_mating = 3,
                 learning_rate = 1e-1, 
                 max_epochs=float('inf'),
                 early_stop_num=20,
-                train_loader = None, 
-                child_network = None, 
                 p_bins = 1, 
                 m_bins = 1, 
                 discrete_p_m=False,
                 batch_size=8,
                 toy_flag=False,
                 toy_size=0.1,
-                fun_num = 14,
                 exclude_method=[],
+                controller = None
                 ):
 
         super().__init__(sp_num, 
-            fun_num, 
             p_bins, 
             m_bins, 
             discrete_p_m=discrete_p_m, 
@@ -42,27 +40,24 @@ class evo_learner(aa_learner):
             toy_size=toy_size, 
             learning_rate=learning_rate,
             max_epochs=max_epochs,
-            early_stop_num=early_stop_num,)
+            early_stop_num=early_stop_num,
+            exclude_method=exclude_method)
 
         self.num_solutions = num_solutions
-        self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.controller = controller
+        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
         self.p_bins = p_bins 
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
-        self.fun_num = fun_num
-        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
-
+        self.policy_dict = {}
+        self.policy_result = []
 
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
 
-
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -79,8 +74,8 @@ class evo_learner(aa_learner):
             Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
             two transformations, with a probability and magnitude float for each
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
+        section = self.fun_num + self.p_bins + self.m_bins
+        y = self.controller.forward(x)
         full_policy = []
         for pol in range(self.sub_num_pol):
             int_pol = []
@@ -89,8 +84,22 @@ class evo_learner(aa_learner):
 
                 trans, need_mag = self.augmentation_space[idx_ret]
 
-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                if self.p_bins == 1:
+                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                else:
+                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item()) * 0.1
+
+
+                if need_mag:
+                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
+                    if self.m_bins == 1:
+                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+                    else:
+                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())
+                    mag = int(mag)
+                else:
+                    mag = None
                 int_pol.append((trans, p_ret, mag))
 
             full_policy.append(tuple(int_pol))
@@ -117,18 +126,18 @@ class evo_learner(aa_learner):
             Subpolicy consisting of two tuples of policies, each with a string associated 
             to a transformation, a float for a probability, and a float for a magnittude
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        section = self.fun_num + self.p_bins + self.m_bins
 
-        y = self.auto_aug_agent.forward(x)
+        y = self.controller.forward(x)
 
-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) 
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        y_1 = torch.softmax(y[:,:self.fun_num], dim = 1) 
+        y[:,:self.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+        y[:,section:section+self.fun_num] = y_2
         concat = torch.cat((y_1, y_2), dim = 1)
 
         cov_mat = torch.cov(concat.T)
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
         shape_store = cov_mat.shape
 
         counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -154,26 +163,29 @@ class evo_learner(aa_learner):
     
         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
                 if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
+
                 counter += 1
 
-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
+        prob1 = round(prob1/counter, 1) if counter != 0 else 0
+        prob2 = round(prob2/counter, 1) if counter != 0 else 0
         if mag1 is not None:
-            mag1 = mag1/counter 
+            mag1 = int(mag1/counter)
         if mag2 is not None:
-            mag2 = mag2/counter    
+            mag2 = int(mag2/counter)  
 
         
-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
 
 
-    def learn(self, iterations = 15, return_weights = False):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
 
@@ -195,24 +207,52 @@ class evo_learner(aa_learner):
             Solution_idx -> Int
         """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations)]
-        self.history_avg = [0 for i in range(iterations)]
+        self.history_best = [0 for i in range(iterations+1)]
+        print("itations: ", iterations)
+
+        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0
 
-        self.set_up_instance()
+        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
-        self.history_avg = self.history_avg / self.num_solutions
+        self.history_avg = [x / self.num_solutions for x in self.history_avg]
+        print("-----------------------------------------------------------------------------------------------------")
 
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx
 
 
-    def set_up_instance(self, train_dataset, test_dataset):
+    def in_pol_dict(self, new_policy):
+        new_policy = new_policy[0]
+        trans1, trans2 = new_policy[0][0], new_policy[1][0]
+        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+        if trans1 in self.policy_dict:
+            if trans2 in self.policy_dict[trans1]:
+                for test_pol in self.policy_dict[trans1][trans2]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans1][trans2].append(new_set)
+                return False 
+            else:
+                self.policy_dict[trans1][trans2] = [new_set]
+        if trans2 in self.policy_dict:
+            if trans1 in self.policy_dict[trans2]:
+                for test_pol in self.policy_dict[trans2][trans1]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans2][trans1].append(new_set)
+                return False 
+            else:
+                self.policy_dict[trans2][trans1] = [new_set]
+
+
+    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
         """
         Initialises GA instance, as well as fitness and on_generation functions
         
@@ -233,24 +273,36 @@ class evo_learner(aa_learner):
             fit_val -> float            
             """
 
-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                             weights_vector=solution)
 
-            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.controller.load_state_dict(model_weights_dict)
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                if self.sp_num == 1:
-                    full_policy = self.get_single_policy_cov(test_x)
-                else:                    
-                    full_policy = self.get_full_policy(test_x)
+                # if self.sp_num == 1:
+                full_policy = self.get_single_policy_cov(test_x)
+
+
+                # else:                      
+                # full_policy = self.get_full_policy(test_x)
+                while self.in_pol_dict(full_policy):
+                    full_policy = self.get_single_policy_cov(test_x)[0]
 
-# Checkpoint -> save learner as a pickle 
 
-            fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-                        + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+            fit_val = self.test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) #) /
+                      #  + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+
+            self.policy_result.append([full_policy, fit_val])
+
+            if len(self.policy_result) > self.sp_num:
+                self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+                self.policy_result = self.policy_result[:self.sp_num]
+                print("Appended policy: ", self.policy_result)
+
 
             if fit_val > self.history_best[self.gen_count]:
+                print("Best policy: ", full_policy)
                 self.history_best[self.gen_count] = fit_val 
                 self.best_model = model_weights_dict
             
@@ -284,6 +336,3 @@ class evo_learner(aa_learner):
             fitness_func=fitness_func,
             on_generation = on_generation)
 
-
-
-