diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48d4f051410ce67e1593167c239284061e48953b..e4460cbfdca799022cd2f1d1ff950cd780355fa1 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -46,7 +46,6 @@ class aa_learner:
     def __init__(self, 
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -57,6 +56,7 @@ class aa_learner:
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                exclude_method = [],
                 ):
         """
         Args:
@@ -84,11 +84,9 @@ class aa_learner:
         """
         # related to defining the search space
         self.sp_num = sp_num
-        self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
         self.discrete_p_m = discrete_p_m
-        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
 
         # related to training of the child_network
         self.batch_size = batch_size
@@ -101,6 +99,9 @@ class aa_learner:
 
         # TODO: We should probably use a different way to store results than self.history
         self.history = []
+        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+        self.fun_num = len(self.augmentation_space)
+        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2
 
 
     def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
@@ -309,7 +310,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
@@ -384,7 +386,7 @@ class aa_learner:
                                     max_epochs = self.max_epochs, 
                                     early_stop_num = self.early_stop_num, 
                                     logging = logging,
-                                    print_every_epoch=True)
+                                    print_every_epoch=print_every_epoch)
         
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
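Usage note (illustrative, not part of the patch): the `fun_num` constructor argument is gone; the number of operations is now derived from the module-level `augmentation_space` list after removing anything named in `exclude_method`. A minimal sketch, assuming the package imports as in the tests below and that "Invert" is one of the operation names in `augmentation_space`:

    import MetaAugment.autoaugment_learners as aal

    learner = aal.aa_learner(
            sp_num=5,
            p_bins=11,
            m_bins=10,
            discrete_p_m=True,
            exclude_method=["Invert"],  # hypothetical exclusion; must match a name in augmentation_space
            )
    # fun_num and op_tensor_length are now derived attributes rather than arguments
    print(learner.fun_num, learner.op_tensor_length)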
diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 8e10c74547f7230a0eeecf11356804413721f7c1..5a8ecbcf6f0b8c6212a8c034a70d61476f4870f6 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -238,6 +238,8 @@ class AutoAugment(torch.nn.Module):
             if probs[i] <= p:
                 op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
+                print("magnitude_id: ", magnitude_id)
+                print("magnitudes[magnitude_id]: ", magnitudes[magnitude_id])
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
                     magnitude *= -1.0
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 18ecf751e614585c7db86902eb3cce927dd696f5..e9a65865c46b786005c01b4ff9d19d418baaa988 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -6,34 +6,31 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
+
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
-from .aa_learner import aa_learner, augmentation_space
 
 
 class evo_learner(aa_learner):
 
     def __init__(self, 
                 sp_num=1,
-                num_solutions = 10, 
-                num_parents_mating = 5,
+                num_solutions = 5, 
+                num_parents_mating = 3,
                 learning_rate = 1e-1, 
                 max_epochs=float('inf'),
                 early_stop_num=20,
-                train_loader = None, 
-                child_network = None, 
                 p_bins = 1, 
                 m_bins = 1, 
                 discrete_p_m=False,
                 batch_size=8,
                 toy_flag=False,
                 toy_size=0.1,
-                fun_num = 14,
                 exclude_method=[],
+                controller=None,
                 ):
 
         super().__init__(sp_num, 
-            fun_num, 
             p_bins, 
             m_bins, 
             discrete_p_m=discrete_p_m, 
@@ -42,27 +39,24 @@ class evo_learner(aa_learner):
             toy_size=toy_size, 
             learning_rate=learning_rate,
             max_epochs=max_epochs,
-            early_stop_num=early_stop_num,)
+            early_stop_num=early_stop_num,
+            exclude_method=exclude_method)
 
         self.num_solutions = num_solutions
-        self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.controller = controller
+        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
         self.p_bins = p_bins 
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
-        self.fun_num = fun_num
-        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
-
+        self.policy_dict = {}
+        self.policy_result = []
 
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
 
-
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -79,8 +73,8 @@ class evo_learner(aa_learner):
             Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
             two transformations, with a probability and magnitude float for each
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
+        section = self.fun_num + self.p_bins + self.m_bins
+        y = self.controller.forward(x)
         full_policy = []
         for pol in range(self.sub_num_pol):
             int_pol = []
@@ -89,8 +83,22 @@ class evo_learner(aa_learner):
 
                 trans, need_mag = self.augmentation_space[idx_ret]
 
-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                if self.p_bins == 1:
+                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                else:
+                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)).item() * 0.1
+
+
+                if need_mag:
+                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
+                    if self.m_bins == 1:
+                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+                    else:
+                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)).item()
+                    mag = int(mag)
+                else:
+                    mag = None
                 int_pol.append((trans, p_ret, mag))
 
             full_policy.append(tuple(int_pol))
@@ -117,18 +125,18 @@ class evo_learner(aa_learner):
             Subpolicy consisting of two tuples of policies, each with a string associated 
             to a transformation, a float for a probability, and a float for a magnittude
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        section = self.fun_num + self.p_bins + self.m_bins
 
-        y = self.auto_aug_agent.forward(x)
+        y = self.controller.forward(x)
 
-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) 
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        y_1 = torch.softmax(y[:,:self.fun_num], dim = 1) 
+        y[:,:self.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+        y[:,section:section+self.fun_num] = y_2
         concat = torch.cat((y_1, y_2), dim = 1)
 
         cov_mat = torch.cov(concat.T)
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
         shape_store = cov_mat.shape
 
         counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -154,26 +162,29 @@ class evo_learner(aa_learner):
     
         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
                 if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    mag2 += 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item()
+
                 counter += 1
 
-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
+        prob1 = round(prob1/counter, 1) if counter != 0 else 0
+        prob2 = round(prob2/counter, 1) if counter != 0 else 0
         if mag1 is not None:
-            mag1 = mag1/counter 
+            mag1 = int(mag1/counter)
         if mag2 is not None:
-            mag2 = mag2/counter    
+            mag2 = int(mag2/counter)  
 
         
-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
 
 
-    def learn(self, iterations = 15, return_weights = False):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
 
@@ -195,24 +206,51 @@
             Solution_idx -> Int
         """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations)]
-        self.history_avg = [0 for i in range(iterations)]
+        self.history_best = [0 for i in range(iterations+1)]
+        print("itations: ", iterations)
+
+        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0
 
-        self.set_up_instance()
+        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
-        self.history_avg = self.history_avg / self.num_solutions
+        self.history_avg = [x / self.num_solutions for x in self.history_avg]
+        print("-----------------------------------------------------------------------------------------------------")
 
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx
 
 
-    def set_up_instance(self, train_dataset, test_dataset):
+    def in_pol_dict(self, new_policy):
+        """
+        Checks whether an equivalent subpolicy (same pair of transformations, with the
+        same probabilities and magnitudes) has already been generated. Returns True if
+        it has; otherwise the subpolicy is recorded in self.policy_dict and False is
+        returned, so the caller knows the policy is new.
+        """
+        new_policy = new_policy[0]
+        trans1, trans2 = new_policy[0][0], new_policy[1][0]
+        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+        # the transformation pair is treated as unordered, so look it up in both orders
+        for first, second in ((trans1, trans2), (trans2, trans1)):
+            if first in self.policy_dict and second in self.policy_dict[first]:
+                for test_pol in self.policy_dict[first][second]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[first][second].append(new_set)
+                return False
+        # transformation pair not seen before: record it and report the policy as new
+        self.policy_dict.setdefault(trans1, {})[trans2] = [new_set]
+        return False
+
+
+    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
         """
         Initialises GA instance, as well as fitness and on_generation functions
         
@@ -233,24 +272,34 @@
             fit_val -> float            
             """
 
-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                             weights_vector=solution)
 
-            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.controller.load_state_dict(model_weights_dict)
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                if self.sp_num == 1:
-                    full_policy = self.get_single_policy_cov(test_x)
-                else:                    
-                    full_policy = self.get_full_policy(test_x)
+                # note: get_full_policy is currently bypassed; only get_single_policy_cov is used here
+                full_policy = self.get_single_policy_cov(test_x)
+
+                # resample until we get a subpolicy that has not been tested before
+                while self.in_pol_dict(full_policy):
+                    full_policy = self.get_single_policy_cov(test_x)
 
-# Checkpoint -> save learner as a pickle 
 
-            fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-                        + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+            fit_val = self.test_autoaugment_policy(full_policy, child_network_architecture,
+                                                   train_dataset, test_dataset)
+
+            self.policy_result.append([full_policy, fit_val])
+
+            if len(self.policy_result) > self.sp_num:
+                self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+                self.policy_result = self.policy_result[:self.sp_num]
+                print("Appended policy: ", self.policy_result)
+
 
             if fit_val > self.history_best[self.gen_count]:
+                print("Best policy: ", full_policy)
                 self.history_best[self.gen_count] = fit_val 
                 self.best_model = model_weights_dict
             
@@ -284,6 +335,3 @@ class evo_learner(aa_learner):
             fitness_func=fitness_func,
             on_generation = on_generation)
 
-
-
-
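Usage note (illustrative, not part of the patch): `evo_learner` no longer builds its own controller or takes `train_loader`/`child_network`; the controller network is injected, and the datasets plus child network architecture are passed to `learn()`. A minimal sketch, assuming the `evo_controller` class dropped from the import above still exists with the keyword arguments shown on the deleted line, and using FashionMNIST and `cn.SimpleNet` as the tests below do:

    import torchvision
    import torchvision.datasets as datasets
    import MetaAugment.child_networks as cn
    from MetaAugment.autoaugment_learners.evo_learner import evo_learner
    from MetaAugment.controller_networks.evo_controller import evo_controller

    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                            train=True, download=True, transform=None)
    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
                            train=False, download=True,
                            transform=torchvision.transforms.ToTensor())

    # the controller output is expected to cover sp_num * (fun_num + p_bins + m_bins)
    # values, matching the indexing in get_full_policy / get_single_policy_cov
    controller = evo_controller(fun_num=14, p_bins=1, m_bins=1, sub_num_pol=1)
    learner = evo_learner(sp_num=1, num_solutions=5, num_parents_mating=3,
                          toy_flag=True, toy_size=0.01, controller=controller)
    learner.learn(train_dataset, test_dataset, cn.SimpleNet, iterations=1)
    # the best (policy, accuracy) pairs seen so far are kept in learner.policy_result
    print(learner.policy_result)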
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index c06edec316eed6982272abc685d6e02735e92adf..db8205d5f335f056f82b0e40557a73031ad72b1a 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -47,7 +47,6 @@ class gru_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -78,10 +77,10 @@ class gru_learner(aa_learner):
             print('Warning: Incompatible discrete_p_m=True input into gru_learner. \
                 discrete_p_m=False will be used')
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
+        super().__init__(
+                sp_num=sp_num, 
+                p_bins=p_bins, 
+                m_bins=m_bins, 
                 discrete_p_m=True, 
                 batch_size=batch_size, 
                 toy_flag=toy_flag, 
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 6541cd3f54980254d0001c969bf2eb90d57b0ad2..09f6626f8a42a35e5006c79188fef3d2947c6418 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -38,7 +38,6 @@ class randomsearch_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -51,10 +50,9 @@ class randomsearch_learner(aa_learner):
                 early_stop_num=30,
                 ):
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
+        super().__init__(sp_num=sp_num, 
+                p_bins=p_bins, 
+                m_bins=m_bins, 
                 discrete_p_m=discrete_p_m,
                 batch_size=batch_size,
                 toy_flag=toy_flag,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 1a4ddf3a0d7d218ac768645d64154f90cd07d134..dc82c2ee75d22dd503f46212dd7251c79bb271db 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -1,9 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -26,7 +20,6 @@ class ucb_learner(randomsearch_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -42,7 +35,6 @@ class ucb_learner(randomsearch_learner):
                 ):
         
         super().__init__(sp_num=sp_num, 
-                        fun_num=14,
                         p_bins=p_bins, 
                         m_bins=m_bins, 
                         discrete_p_m=discrete_p_m,
@@ -53,23 +45,24 @@ class ucb_learner(randomsearch_learner):
                         max_epochs=max_epochs,
                         early_stop_num=early_stop_num,)
         
-        self.num_policies = num_policies
 
-        # When this learner is initialized we generate `num_policies` number
-        # of random policies. 
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = []
-        self.make_more_policies()
+        
 
         # attributes used in the UCB1 algorithm
-        self.q_values = [0]*self.num_policies
-        self.best_q_values = []
+        self.num_policies = num_policies
+
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
+
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0
 
 
 
+
     def make_more_policies(self, n):
         """generates n more random policies and adds it to self.policies
 
@@ -78,50 +71,71 @@ class ucb_learner(randomsearch_learner):
                     and add to our list of policies
         """
 
-        self.policies.append([self.generate_new_policy() for _ in n])
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the below need to be lengthened to store information for the 
+        # new policies
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
+
 
 
     def learn(self, 
             train_dataset, 
             test_dataset, 
             child_network_architecture, 
-            iterations=15):
+            iterations=15,
+            print_every_epoch=False):
+        """continue the UCB algorithm for `iterations` number of turns
 
+        """
 
         for this_iter in trange(iterations):
 
-            # get the action to try (either initially in order or using best q_plus_cnt value)
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies:
-                this_policy_idx = np.argmax(self.q_plus_cnt)
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we 
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
                 this_policy = self.policies[this_policy_idx]
-            else:
-                this_policy = this_iter
-
-
-            best_acc = self.test_autoaugment_policy(
+                acc = self.test_autoaugment_policy(
                                 this_policy,
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False
+                                logging=False,
+                                print_every_epoch=print_every_epoch
                                 )
-
-            # update q_values
-            # TODO: change this if statemetn
-            if this_iter < self.num_policies:
-                self.q_values[this_policy_idx] += best_acc
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = acc
             else:
-                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
-
-            best_q_value = max(self.q_values)
-            self.best_q_values.append(best_q_value)
-
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
+    
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
+
+            # print progress for user
             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                 this_iter+1, 
-                                list(np.around(np.array(self.q_values),2)), 
-                                max(list(np.around(np.array(self.q_values),2)))
+                                [round(x, 2) if x is not None else None for x in self.avg_accs], 
+                                round(best_avg_acc, 2)
                                 )
                     )
 
@@ -130,10 +144,11 @@ class ucb_learner(randomsearch_learner):
             self.total_count += 1
 
             # update q_plus_cnt values every turn after the initial sweep through
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies - 1:
-                for i in range(self.num_policies):
-                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+            
+            print(self.cnts)
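For reference, the score computed in the loop above is the standard UCB1 rule: a policy's average accuracy plus an exploration bonus of sqrt(2*ln(total_count)/cnts[i]). A small self-contained sketch with made-up numbers (not part of the patch):

    import numpy as np

    avg_accs = [0.62, 0.58, 0.71]   # hypothetical average accuracies per policy
    cnts = [4, 2, 6]                # how often each policy has been tested
    total_count = sum(cnts)

    q_plus_cnt = [acc + np.sqrt(2 * np.log(total_count) / n)
                  for acc, n in zip(avg_accs, cnts)]
    # the rarely tried policy (index 1) wins despite its lower average accuracy
    next_policy_idx = int(np.argmax(q_plus_cnt))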
 
             
 
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
index 78be118ae9f3143d907cb8b0940bc6283a3e82ac..e48d1c31c44d2e5d6af548eeb15b957094abac17 100644
--- a/temp_util/wapp_util.py
+++ b/temp_util/wapp_util.py
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
 import pickle
 
 def parse_users_learner_spec(
+            # aalearner type
             auto_aug_learner, 
+            # search space settings
             ds, 
             ds_name, 
             exclude_method, 
             num_funcs, 
             num_policies, 
             num_sub_policies, 
+            # child network settings
             toy_size, 
             IsLeNet, 
             batch_size, 
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index 3e2808702a04746e625acd5b463cfe01f56687bd..29af4f6da149a9619bafe30ba03cabe6b77064a7 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -25,13 +25,12 @@ def test_translate_operation_tensor():
 
         softmax = torch.nn.Softmax(dim=0)
 
-        fun_num = random.randint(1, 14)
+        fun_num=14
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
-
+        
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=True
@@ -54,13 +53,12 @@ def test_translate_operation_tensor():
     for i in range(2000):
         
 
-        fun_num = random.randint(1, 14)
+        fun_num = 14
         p_bins = random.randint(1, 15)
         m_bins = random.randint(1, 15)
 
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=False
@@ -81,7 +79,6 @@ def test_translate_operation_tensor():
 def test_test_autoaugment_policy():
     agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py
index 6ad8204f9b8473482f00d5c5d6a9d1e391cf9e0b..b5c695cfdf2d988408d70d1379af4fbf7738ae15 100644
--- a/test/MetaAugment/test_gru_learner.py
+++ b/test/MetaAugment/test_gru_learner.py
@@ -14,13 +14,11 @@ def test_generate_new_policy():
     """
     for _ in range(40):
         sp_num = random.randint(1,20)
-        fun_num = random.randint(1, 14)
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
 
         agent = aal.gru_learner(
             sp_num=sp_num,
-            fun_num=fun_num,
             p_bins=p_bins,
             m_bins=m_bins,
             cont_mb_size=2
diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py
index 5b67d98e1f2e40d56b3aac2445f041f1372bbe9f..29cd812b1d428441d405556b53db3e65e2ab7bc6 100644
--- a/test/MetaAugment/test_randomsearch_learner.py
+++ b/test/MetaAugment/test_randomsearch_learner.py
@@ -16,13 +16,12 @@ def test_generate_new_policy():
     def my_test(discrete_p_m):
         for _ in range(40):
             sp_num = random.randint(1,20)
-            fun_num = random.randint(1, 14)
+            
             p_bins = random.randint(2, 15)
             m_bins = random.randint(2, 15)
 
             agent = aal.randomsearch_learner(
                 sp_num=sp_num,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=discrete_p_m
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index 564ac80dff999467f6bf91fbc4c55019e0b86d98..7c6635ffe467e9a6cd4beb3b596380c76446b750 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -1,7 +1,18 @@
 import MetaAugment.autoaugment_learners as aal
-
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint
 
 def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', 
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
+
+
     learner = aal.ucb_learner(
         # parameters that define the search space
                 sp_num=5,
@@ -10,15 +21,37 @@ def test_ucb_learner():
                 discrete_p_m=True,
                 # hyperparameters for when training the child_network
                 batch_size=8,
-                toy_flag=False,
-                toy_size=0.1,
+                toy_flag=True,
+                toy_size=0.001,
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=30,
                 # ucb_learner specific hyperparameter
-                num_policies=100
+                num_policies=3
     )
-    print(learner.policies)
+    pprint(learner.policies)
+    assert len(learner.policies)==len(learner.avg_accs), \
+                (len(learner.policies), (len(learner.avg_accs)))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=5
+        )
+    
+    # let's say we want to explore more policies:
+    # we generate more new policies
+    learner.make_more_policies(n=4)
+
+    # and let's explore how good those are as well
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=7
+        )
 
 if __name__=="__main__":
-    test_ucb_learner()
\ No newline at end of file
+    test_ucb_learner()