diff --git a/MetaAugment/autoaugment_learners/__init__.py b/MetaAugment/autoaugment_learners/__init__.py
index e844c5e0ae157061ceb54251b1fd3ecdbdea77c8..700f73591c6f0309de84d82cc3609de7c54a396e 100644
--- a/MetaAugment/autoaugment_learners/__init__.py
+++ b/MetaAugment/autoaugment_learners/__init__.py
@@ -1,4 +1,5 @@
 from .aa_learner import *
 from .randomsearch_learner import *
 from .gru_learner import *
-from .evo_learner import *
\ No newline at end of file
+from .evo_learner import *
+from .ucb_learner import *
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48d4f051410ce67e1593167c239284061e48953b..561222a5fac35d5348f99da6a9fba31657afe133 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -1,4 +1,3 @@
-from numpy import isin
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -7,31 +6,10 @@ from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 
 import torchvision.transforms as transforms
 
-from pprint import pprint
-import matplotlib.pyplot as plt
 import copy
 import types
 
 
-# We will use this augmentation_space temporarily. Later on we will need to 
-# make sure we are able to add other image functions if the users want.
-augmentation_space = [
-            # (function_name, do_we_need_to_specify_magnitude)
-            ("ShearX", True),
-            ("ShearY", True),
-            ("TranslateX", True),
-            ("TranslateY", True),
-            ("Rotate", True),
-            ("Brightness", True),
-            ("Color", True),
-            ("Contrast", True),
-            ("Sharpness", True),
-            ("Posterize", True),
-            ("Solarize", True),
-            ("AutoContrast", False),
-            ("Equalize", False),
-            ("Invert", False),
-        ]
 
 
 class aa_learner:
@@ -46,17 +24,16 @@ class aa_learner:
     def __init__(self, 
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
                 # hyperparameters for when training the child_network
                 batch_size=32,
-                toy_flag=False,
-                toy_size=0.1,
+                toy_size=1,
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                exclude_method=[],
                 ):
         """
         Args:
@@ -74,7 +51,6 @@ class aa_learner:
                             algorithm, etc.). Defaults to False
             
             batch_size (int, optional): child_network training parameter. Defaults to 32.
-            toy_flag (bool, optional): child_network training parameter. Defaults to False.
-            toy_size (int, optional): child_network training parameter. ratio of original
-                                dataset used in toy dataset. Defaults to 0.1.
-            learning_rate (float, optional): child_network training parameter. Defaults to 1e-2.
+            toy_size (float, optional): child_network training parameter. Ratio of the
+                                original dataset used in the toy dataset. Defaults to 1.
+            learning_rate (float, optional): child_network training parameter. Defaults to 1e-1.
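+            exclude_method (list, optional): names of augmentation functions to
+                                exclude from the search space (e.g. ["Invert"]).
+                                Defaults to [].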
@@ -84,15 +60,12 @@ class aa_learner:
         """
         # related to defining the search space
         self.sp_num = sp_num
-        self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
         self.discrete_p_m = discrete_p_m
-        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
 
         # related to training of the child_network
         self.batch_size = batch_size
-        self.toy_flag = toy_flag
         self.toy_size = toy_size
         self.learning_rate = learning_rate
 
@@ -102,6 +75,31 @@ class aa_learner:
         # TODO: We should probably use a different way to store results than self.history
         self.history = []
 
+        # This is the full augmentation space. We remove any image functions
+        # that the user has listed in the exclude_method argument
+        augmentation_space = [
+            # (function_name, do_we_need_to_specify_magnitude)
+            ("ShearX", True),
+            ("ShearY", True),
+            ("TranslateX", True),
+            ("TranslateY", True),
+            ("Rotate", True),
+            ("Brightness", True),
+            ("Color", True),
+            ("Contrast", True),
+            ("Sharpness", True),
+            ("Posterize", True),
+            ("Solarize", True),
+            ("AutoContrast", False),
+            ("Equalize", False),
+            ("Invert", False),
+        ]
+        self.exclude_method = exclude_method
+        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+
+        self.fun_num = len(self.augmentation_space)
+        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2
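+        # e.g. (a sketch): with exclude_method=["Invert", "Solarize"], those two
+        # entries are filtered out, so self.fun_num == 12 and, with discrete_p_m=True,
+        # p_bins=11, m_bins=10, op_tensor_length == 12 + 11 + 10 == 33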
+
 
     def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
         """
@@ -176,7 +174,7 @@ class aa_learner:
                 prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
                 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
-            function = augmentation_space[fun_idx][0]
+            function = self.augmentation_space[fun_idx][0]
             prob = prob_idx/(self.p_bins-1)
 
             indices = (fun_idx, prob_idx, mag)
@@ -205,13 +203,13 @@ class aa_learner:
             prob = round(prob, 1) # round to nearest first decimal digit
             mag = round(mag) # round to nearest integer
             
-        function = augmentation_space[fun_idx][0]
+        function = self.augmentation_space[fun_idx][0]
 
         assert 0 <= prob <= 1, prob
         assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins)
         
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
+        if self.augmentation_space[fun_idx][1]: # if the image function has a magnitude
             operation = (function, prob, mag)
         else:
             operation =  (function, prob, None)
@@ -298,7 +296,7 @@ class aa_learner:
                 reward = self.test_autoaugment_policy(policy,
                                         child_network_architecture,
                                         train_dataset,
-                                        test_dataset, toy_flag)
+                                        test_dataset)
 
                 self.history.append((policy, reward))
         """
@@ -309,7 +307,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
@@ -324,8 +323,6 @@ class aa_learner:
                                 of it.
             train_dataset (torchvision.dataset.vision.VisionDataset)
             test_dataset (torchvision.dataset.vision.VisionDataset)
-            toy_flag (boolean): Whether we want to obtain a toy version of 
-                            train_dataset and test_dataset and use those.
             logging (boolean): Whether we want to save logs
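+            print_every_epoch (boolean): Whether to print the child network's
+                            accuracy at every epoch during training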
         
         Returns:
@@ -359,17 +356,11 @@ class aa_learner:
         train_dataset.transform = train_transform
 
         # create Dataloader objects out of the Dataset objects
-        if self.toy_flag:
-            train_loader, test_loader = create_toy(train_dataset,
-                                                test_dataset,
-                                                batch_size=self.batch_size,
-                                                n_samples=self.toy_size,
-                                                seed=100)
-        else:
-            train_loader = torch.utils.data.DataLoader(train_dataset, 
-                                                batch_size=self.batch_size)
-            test_loader = torch.utils.data.DataLoader(test_dataset, 
-                                                batch_size=self.batch_size)
+        train_loader, test_loader = create_toy(train_dataset,
+                                            test_dataset,
+                                            batch_size=self.batch_size,
+                                            n_samples=self.toy_size,
+                                            seed=100)
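+        # note: when self.toy_size == 1, create_toy skips subsampling and returns
+        # DataLoaders over the full datasets (see create_toy in MetaAugment/main.py)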
         
         # train the child network with the dataloaders equipped with our specific policy
         accuracy = train_child_network(child_network, 
@@ -384,45 +375,44 @@ class aa_learner:
                                     max_epochs = self.max_epochs, 
                                     early_stop_num = self.early_stop_num, 
                                     logging = logging,
-                                    print_every_epoch=True)
+                                    print_every_epoch=print_every_epoch)
         
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
     
 
-    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, n=5):
-        """
-        I made this to plot a couple of accuracy graphs to help manually tune my gradient 
-        optimizer hyperparameters.
+    # def demo_plot(self, train_dataset, test_dataset, child_network_architecture, n=5):
+    #     """
+    #     I made this to plot a couple of accuracy graphs to help manually tune my gradient 
+    #     optimizer hyperparameters.
 
-        Saves a plot of `n` training accuracy graphs overlapped.
-        """
+    #     Saves a plot of `n` training accuracy graphs overlapped.
+    #     """
         
-        acc_lists = []
-
-        # This is dummy code
-        # test out `n` random policies
-        for _ in range(n):
-            policy = self.generate_new_policy()
-
-            pprint(policy)
-            reward, acc_list = self.test_autoaugment_policy(policy,
-                                                child_network_architecture,
-                                                train_dataset,
-                                                test_dataset,
-                                                toy_flag=self.toy_flag,
-                                                logging=True)
-
-            self.history.append((policy, reward))
-            acc_lists.append(acc_list)
-
-        for acc_list in acc_lists:
-            plt.plot(acc_list)
-        plt.title('I ran 5 random policies to see if there is any sign of \
-                    catastrophic failure during training. If there are \
-                    any lines which reach significantly lower (>10%) \
-                    accuracies, you might want to tune the hyperparameters')
-        plt.xlabel('epoch')
-        plt.ylabel('accuracy')
-        plt.show()
-        plt.savefig('training_graphs_without_policies')
\ No newline at end of file
+    #     acc_lists = []
+
+    #     # This is dummy code
+    #     # test out `n` random policies
+    #     for _ in range(n):
+    #         policy = self.generate_new_policy()
+
+    #         pprint(policy)
+    #         reward, acc_list = self.test_autoaugment_policy(policy,
+    #                                             child_network_architecture,
+    #                                             train_dataset,
+    #                                             test_dataset,
+    #                                             logging=True)
+
+    #         self.history.append((policy, reward))
+    #         acc_lists.append(acc_list)
+
+    #     for acc_list in acc_lists:
+    #         plt.plot(acc_list)
+    #     plt.title('I ran 5 random policies to see if there is any sign of \
+    #                 catastrophic failure during training. If there are \
+    #                 any lines which reach significantly lower (>10%) \
+    #                 accuracies, you might want to tune the hyperparameters')
+    #     plt.xlabel('epoch')
+    #     plt.ylabel('accuracy')
+    #     plt.show()
+    #     plt.savefig('training_graphs_without_policies')
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 1ff576b28a4d39f367550bb2fa15168e0a9b99c8..34cc2d44555423475914a1ba2528cfddb71aad57 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -1,4 +1,3 @@
-from cgi import test
 import torch
 torch.manual_seed(0)
 import torch.nn as nn
@@ -6,64 +5,59 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
 
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
-import MetaAugment.child_networks as cn
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 
-class evo_learner():
+class evo_learner(aa_learner):
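+    """
+    Evolutionary learner: evolves the weights of a controller network with a
+    genetic algorithm (pygad) to search for good augmentation policies.
+    """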
 
     def __init__(self, 
-                sp_num=1,
-                num_solutions = 10, 
-                num_parents_mating = 5,
-                learning_rate = 1e-1, 
+                # search space settings
+                sp_num=5,
+                p_bins=10, 
+                m_bins=10, 
+                discrete_p_m=False,
+                exclude_method=[],
+                # child network settings
+                learning_rate=1e-1, 
                 max_epochs=float('inf'),
                 early_stop_num=20,
-                train_loader = None, 
-                child_network = None, 
-                p_bins = 1, 
-                m_bins = 1, 
-                discrete_p_m=False,
                 batch_size=8,
-                toy_flag=False,
-                toy_size=0.1,
-                fun_num = 14,
-                exclude_method=[],
+                toy_size=1,
+                # evolutionary learner specific settings
+                num_solutions=5,
+                num_parents_mating=3,
+                controller=None
                 ):
 
-        super().__init__(sp_num, 
-            fun_num, 
-            p_bins, 
-            m_bins, 
-            discrete_p_m=discrete_p_m, 
-            batch_size=batch_size, 
-            toy_flag=toy_flag, 
-            toy_size=toy_size, 
-            learning_rate=learning_rate,
-            max_epochs=max_epochs,
-            early_stop_num=early_stop_num,)
+        super().__init__(
+                    sp_num=sp_num, 
+                    p_bins=p_bins, 
+                    m_bins=m_bins, 
+                    discrete_p_m=discrete_p_m, 
+                    batch_size=batch_size, 
+                    toy_size=toy_size, 
+                    learning_rate=learning_rate,
+                    max_epochs=max_epochs,
+                    early_stop_num=early_stop_num,
+                    exclude_method=exclude_method
+                    )
 
         self.num_solutions = num_solutions
-        self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.controller = controller
+        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
         self.p_bins = p_bins 
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
-        self.fun_num = fun_num
-        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
-
+        self.policy_dict = {}
+        self.policy_result = []
 
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
 
-
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -80,8 +74,8 @@ class evo_learner():
             Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
             two transformations, with a probability and magnitude float for each
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
+        section = self.fun_num + self.p_bins + self.m_bins
+        y = self.controller.forward(x)
         full_policy = []
         for pol in range(self.sub_num_pol):
             int_pol = []
@@ -90,8 +84,22 @@ class evo_learner():
 
                 trans, need_mag = self.augmentation_space[idx_ret]
 
-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                if self.p_bins == 1:
+                    # continuous probability output: clamp to [0, 1]
+                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+                else:
+                    # binned probability output: take the argmax bin and rescale to [0, 1]
+                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)).item() * 0.1
+
+                if need_mag:
+                    if self.m_bins == 1:
+                        # continuous magnitude output: clamp to [0, 9]
+                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+                    else:
+                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)).item()
+                    mag = int(mag)
+                else:
+                    mag = None
                 int_pol.append((trans, p_ret, mag))
 
             full_policy.append(tuple(int_pol))
@@ -118,18 +126,18 @@ class evo_learner():
             Subpolicy consisting of two tuples of policies, each with a string associated 
             to a transformation, a float for a probability, and a float for a magnittude
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        section = self.fun_num + self.p_bins + self.m_bins
 
-        y = self.auto_aug_agent.forward(x)
+        y = self.controller.forward(x)
 
-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) 
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        y_1 = torch.softmax(y[:,:self.fun_num], dim = 1) 
+        y[:,:self.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+        y[:,section:section+self.fun_num] = y_2
         concat = torch.cat((y_1, y_2), dim = 1)
 
         cov_mat = torch.cov(concat.T)
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
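+        # cov_mat now holds the cross-covariance between the softmaxed outputs for
+        # the first and second transformation; the code below picks the pair of
+        # transformations with the largest covariance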
         shape_store = cov_mat.shape
 
         counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -155,26 +163,29 @@ class evo_learner():
     
         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
                 if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    mag2 += 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item()
+
                 counter += 1
 
-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
+        prob1 = round(prob1/counter, 1) if counter != 0 else 0
+        prob2 = round(prob2/counter, 1) if counter != 0 else 0
         if mag1 is not None:
-            mag1 = mag1/counter 
+            mag1 = int(mag1/counter)
         if mag2 is not None:
-            mag2 = mag2/counter    
+            mag2 = int(mag2/counter)  
 
         
-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
 
 
-    def learn(self, iterations = 15, return_weights = False):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
 
@@ -196,24 +207,52 @@ class evo_learner():
             Solution_idx -> Int
         """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations)]
-        self.history_avg = [0 for i in range(iterations)]
+        self.history_best = [0 for i in range(iterations+1)]
+        print("itations: ", iterations)
+
+        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0
 
-        self.set_up_instance()
+        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
-        self.history_avg = self.history_avg / self.num_solutions
+        self.history_avg = [x / self.num_solutions for x in self.history_avg]
+        print("-----------------------------------------------------------------------------------------------------")
 
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx
 
 
-    def set_up_instance(self, train_dataset, test_dataset):
+    def in_pol_dict(self, new_policy):
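+        """Checks whether new_policy has already been tested, recording it if not.
+
+        Returns:
+            True if an equivalent policy was seen before, False otherwise
+        """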
+        new_policy = new_policy[0]
+        trans1, trans2 = new_policy[0][0], new_policy[1][0]
+        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+        # check both orderings of the transformation pair for a previous match
+        for first, second in ((trans1, trans2), (trans2, trans1)):
+            if first in self.policy_dict and second in self.policy_dict[first]:
+                if new_set in self.policy_dict[first][second]:
+                    return True
+        # not seen before: record it under (trans1, trans2) and report False
+        self.policy_dict.setdefault(trans1, {}).setdefault(trans2, []).append(new_set)
+        return False
+
+
+    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
         """
         Initialises GA instance, as well as fitness and on_generation functions
         
@@ -234,24 +273,36 @@ class evo_learner():
             fit_val -> float            
             """
 
-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                             weights_vector=solution)
 
-            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.controller.load_state_dict(model_weights_dict)
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                if self.sp_num == 1:
-                    full_policy = self.get_single_policy_cov(test_x)
-                else:                    
-                    full_policy = self.get_full_policy(test_x)
+                # only the single-subpolicy (covariance) path is used for now
+                full_policy = self.get_single_policy_cov(test_x)
 
-# Checkpoint -> save learner as a pickle 
 
-            fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-                        + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+                # regenerate until we find a policy that hasn't been tested before
+                while self.in_pol_dict(full_policy):
+                    full_policy = self.get_single_policy_cov(test_x)
+
+
+            fit_val = self.test_autoaugment_policy(full_policy,
+                                                   child_network_architecture,
+                                                   train_dataset,
+                                                   test_dataset)
+
+            self.policy_result.append([full_policy, fit_val])
+
+            if len(self.policy_result) > self.sp_num:
+                self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+                self.policy_result = self.policy_result[:self.sp_num]
+                print("Appended policy: ", self.policy_result)
+
 
             if fit_val > self.history_best[self.gen_count]:
+                print("Best policy: ", full_policy)
                 self.history_best[self.gen_count] = fit_val 
                 self.best_model = model_weights_dict
             
@@ -285,6 +336,3 @@ class evo_learner():
             fitness_func=fitness_func,
             on_generation = on_generation)
 
-
-
-
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index c06edec316eed6982272abc685d6e02735e92adf..5c15a4a41b086982aa543bda451c89bfa7eecba9 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -9,25 +9,6 @@ import pickle
 
 
 
-# We will use this augmentation_space temporarily. Later on we will need to 
-# make sure we are able to add other image functions if the users want.
-augmentation_space = [
-            # (function_name, do_we_need_to_specify_magnitude)
-            ("ShearX", True),
-            ("ShearY", True),
-            ("TranslateX", True),
-            ("TranslateY", True),
-            ("Rotate", True),
-            ("Brightness", True),
-            ("Color", True),
-            ("Contrast", True),
-            ("Sharpness", True),
-            ("Posterize", True),
-            ("Solarize", True),
-            ("AutoContrast", False),
-            ("Equalize", False),
-            ("Invert", False),
-        ]
 
 
 class gru_learner(aa_learner):
@@ -47,14 +28,13 @@ class gru_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
+                exclude_method=[],
                 # hyperparameters for when training the child_network
                 batch_size=8,
-                toy_flag=False,
-                toy_size=0.1,
+                toy_size=1,
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
@@ -78,17 +58,18 @@ class gru_learner(aa_learner):
             print('Warning: Incompatible discrete_p_m=True input into gru_learner. \
                 discrete_p_m=False will be used')
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
+        super().__init__(
+                sp_num=sp_num, 
+                p_bins=p_bins, 
+                m_bins=m_bins, 
                 discrete_p_m=True, 
                 batch_size=batch_size, 
-                toy_flag=toy_flag, 
                 toy_size=toy_size, 
                 learning_rate=learning_rate,
                 max_epochs=max_epochs,
-                early_stop_num=early_stop_num,)
+                early_stop_num=early_stop_num,
+                exclude_method=exclude_method,
+                )
 
         # GRU-specific attributes that aren't in general aa_learner's
         self.alpha = alpha
@@ -245,7 +226,6 @@ if __name__=='__main__':
 
     agent = gru_learner(
                         sp_num=7,
-                        toy_flag=True,
                         toy_size=0.01,
                         batch_size=32,
                         learning_rate=0.1,
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 6541cd3f54980254d0001c969bf2eb90d57b0ad2..2c35fb80ab15f7b2c51dfdcfbfcff942a6a70032 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -10,25 +10,7 @@ import pickle
 
 
 
-# We will use this augmentation_space temporarily. Later on we will need to 
-# make sure we are able to add other image functions if the users want.
-augmentation_space = [
-            # (function_name, do_we_need_to_specify_magnitude)
-            ("ShearX", True),
-            ("ShearY", True),
-            ("TranslateX", True),
-            ("TranslateY", True),
-            ("Rotate", True),
-            ("Brightness", True),
-            ("Color", True),
-            ("Contrast", True),
-            ("Sharpness", True),
-            ("Posterize", True),
-            ("Solarize", True),
-            ("AutoContrast", False),
-            ("Equalize", False),
-            ("Invert", False),
-        ]
 
 class randomsearch_learner(aa_learner):
     """
@@ -38,30 +20,30 @@ class randomsearch_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
+                exclude_method=[],
                 # hyperparameters for when training the child_network
                 batch_size=8,
-                toy_flag=False,
-                toy_size=0.1,
+                toy_size=1,
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=30,
                 ):
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
-                discrete_p_m=discrete_p_m,
-                batch_size=batch_size,
-                toy_flag=toy_flag,
-                toy_size=toy_size,
-                learning_rate=learning_rate,
-                max_epochs=max_epochs,
-                early_stop_num=early_stop_num,)
+        super().__init__(
+                    sp_num=sp_num, 
+                    p_bins=p_bins, 
+                    m_bins=m_bins, 
+                    discrete_p_m=discrete_p_m,
+                    batch_size=batch_size,
+                    toy_size=toy_size,
+                    learning_rate=learning_rate,
+                    max_epochs=max_epochs,
+                    early_stop_num=early_stop_num,
+                    exclude_method=exclude_method
+                    )
         
 
     def generate_new_discrete_operation(self):
@@ -187,7 +169,6 @@ if __name__=='__main__':
 
     agent = randomsearch_learner(
                                 sp_num=7,
-                                toy_flag=True,
                                 toy_size=0.01,
                                 batch_size=4,
                                 learning_rate=0.05,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 41b8977156e9148965b0ffa6c00fe4d0a4a2595d..fdf735bea1916897ee5e28a27bd67c10b5751581 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -1,224 +1,158 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
 import numpy as np
-from sklearn.covariance import log_likelihood
-import torch
-torch.manual_seed(0)
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-import pickle
-
-from matplotlib import pyplot as plt
-from numpy import save, load
+
 from tqdm import trange
 
 from ..child_networks import *
-from ..main import create_toy, train_child_network
-
-
-# In[6]:
-
-
-"""Randomly generate 10 policies"""
-"""Each policy has 5 sub-policies"""
-"""For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes"""
-
-def generate_policies(num_policies, num_sub_policies):
-    
-    policies = np.zeros([num_policies,num_sub_policies,6])
+from .randomsearch_learner import randomsearch_learner
+
+
+class ucb_learner(randomsearch_learner):
+    """
+    Samples random policies from the search space specified by the AutoAugment
+    paper and uses the UCB1 bandit algorithm to choose which policy to test next.
+    """
+    def __init__(self,
+                # parameters that define the search space
+                sp_num=5,
+                p_bins=11,
+                m_bins=10,
+                discrete_p_m=True,
+                exclude_method=[],
+                # hyperparameters for when training the child_network
+                batch_size=8,
+                toy_size=1,
+                learning_rate=1e-1,
+                max_epochs=float('inf'),
+                early_stop_num=30,
+                # ucb_learner specific hyperparameter
+                num_policies=100
+                ):
+        
+        super().__init__(
+                        sp_num=sp_num, 
+                        p_bins=p_bins, 
+                        m_bins=m_bins, 
+                        discrete_p_m=discrete_p_m,
+                        batch_size=batch_size,
+                        toy_size=toy_size,
+                        learning_rate=learning_rate,
+                        max_epochs=max_epochs,
+                        early_stop_num=early_stop_num,
+                        exclude_method=exclude_method,
+                        )
+        
 
-    # Policies array will be 10x5x6
-    for policy in range(num_policies):
-        for sub_policy in range(num_sub_policies):
-            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)
-            policies[policy, sub_policy, 0] = np.random.randint(0,3)
-            policies[policy, sub_policy, 1] = np.random.randint(0,3)
-            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:
-                policies[policy, sub_policy, 1] = np.random.randint(0,3)
-
-            # pick probabilities
-            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10
-            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10
-
-            # pick magnitudes
-            for transformation in range(2):
-                if policies[policy, sub_policy, transformation] <= 1:
-                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5
-                elif policies[policy, sub_policy, transformation] == 2:
-                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10
-
-    return policies
-
-
-# In[7]:
-
-
-"""Pick policy and sub-policy"""
-"""Each row of data should have a different sub-policy but for now, this will do"""
-
-def sample_sub_policy(policies, policy, num_sub_policies):
-    sub_policy = np.random.randint(0,num_sub_policies)
-
-    degrees = 0
-    shear = 0
-    scale = 1
-
-    # check for rotations
-    if policies[policy, sub_policy][0] == 0:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            degrees = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 0:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            degrees = policies[policy, sub_policy][5]
-
-    # check for shears
-    if policies[policy, sub_policy][0] == 1:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            shear = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 1:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            shear = policies[policy, sub_policy][5]
-
-    # check for scales
-    if policies[policy, sub_policy][0] == 2:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            scale = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 2:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            scale = policies[policy, sub_policy][5]
-
-    return degrees, shear, scale
-
-
-# In[8]:
-
-
-"""Sample policy, open and apply above transformations"""
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
-
-    # get number of policies and sub-policies
-    num_policies = len(policies)
-    num_sub_policies = len(policies[0])
-
-    #Initialize vector weights, counts and regret
-    q_values = [0]*num_policies
-    cnts = [0]*num_policies
-    q_plus_cnt = [0]*num_policies
-    total_count = 0
-
-    best_q_values = []
-
-    for policy in trange(iterations):
-
-        # get the action to try (either initially in order or using best q_plus_cnt value)
-        if policy >= num_policies:
-            this_policy = np.argmax(q_plus_cnt)
-        else:
-            this_policy = policy
-
-        # get info of transformation for this sub-policy
-        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)
-
-        # create transformations using above info
-        transform = torchvision.transforms.Compose(
-            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
-            torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing
-            torchvision.transforms.ToTensor()])
-
-        # open data and apply these transformations
-        if ds == "MNIST":
-            train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=transform)
-        elif ds == "KMNIST":
-            train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, download=True, transform=transform)
-        elif ds == "FashionMNIST":
-            train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, download=True, transform=transform)
-        elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, download=True, transform=transform)
-        elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, download=True, transform=transform)
-        elif ds == 'Other':
-            dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=transform)
-            len_train = int(0.8*len(dataset))
-            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
-
-        # check sizes of images
-        img_height = len(train_dataset[0][0][0])
-        img_width = len(train_dataset[0][0][0][0])
-        img_channels = len(train_dataset[0][0])
-
-
-        # check output labels
-        if ds == 'Other':
-            num_labels = len(dataset.class_to_idx)
-        elif ds == "CIFAR10" or ds == "CIFAR100":
-            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
-        else:
-            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
-
-        # create toy dataset from above uploaded data
-        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
-
-        # create model
-        if torch.cuda.is_available():
-            device='cuda'
-        else:
-            device='cpu'
         
-        if IsLeNet == "LeNet":
-            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == "EasyNet":
-            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == 'SimpleNet':
-            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        else:
-            model = pickle.load(open(f'datasets/childnetwork', "rb"))
 
-        sgd = optim.SGD(model.parameters(), lr=learning_rate)
-        cost = nn.CrossEntropyLoss()
+        # attributes used in the UCB1 algorithm
+        self.num_policies = num_policies
+
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
+
+        self.cnts = [0]*self.num_policies
+        self.q_plus_cnt = [0]*self.num_policies
+        self.total_count = 0
+
+
+    def make_more_policies(self, n):
+        """generates n more random policies and adds it to self.policies
+
+        Args:
+            n (int): how many more policies to we want to randomly generate
+                    and add to our list of policies
+        """
+
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the below need to be lengthened to store information for the 
+        # new policies
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
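+        # e.g. agent.make_more_policies(5) grows the pool by five random policies;
+        # their avg_accs entries are None, so subsequent learn() turns will test
+        # each of them once before resuming UCB1 selection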
+
+
+
+    def learn(self, 
+            train_dataset, 
+            test_dataset, 
+            child_network_architecture, 
+            iterations=15,
+            print_every_epoch=False):
+        """continue the UCB algorithm for `iterations` number of turns
+
+        """
+
+        for this_iter in trange(iterations):
+
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we 
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = acc
+            else:
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
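+                # incremental mean update: new_avg = (old_avg*count + acc)/(count + 1)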
+    
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
 
-        best_acc = train_child_network(model, train_loader, test_loader, sgd,
-                         cost, max_epochs, early_stop_num, early_stop_flag,
-			 average_validation, logging=False, print_every_epoch=False)
+            # print progress for user
+            if (this_iter+1) % 5 == 0:
+                print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
+                                this_iter+1, 
+                                list(np.around(np.array(self.avg_accs),2)), 
+                                max(list(np.around(np.array(self.avg_accs),2)))
+                                )
+                    )
 
-        # update q_values
-        if policy < num_policies:
-            q_values[this_policy] += best_acc
-        else:
-            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)
+            # update counts
+            self.cnts[this_policy_idx] += 1
+            self.total_count += 1
 
-        best_q_value = max(q_values)
-        best_q_values.append(best_q_value)
+            # update q_plus_cnt values every turn after the initial sweep through
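+            # (UCB1: exploitation term avg_accs[i] plus an exploration bonus
+            #  sqrt(2*ln(total_count)/cnts[i]), so rarely-tested policies are
+            #  eventually revisited)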
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
 
-        if (policy+1) % 5 == 0:
-            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
 
-        # update counts
-        cnts[this_policy] += 1
-        total_count += 1
 
-        # update q_plus_cnt values every turn after the initial sweep through
-        if policy >= num_policies - 1:
-            for i in range(num_policies):
-                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
 
-        # yield q_values, best_q_values
-    return q_values, best_q_values
 
 
-# # In[9]:
 
 if __name__=='__main__':
     batch_size = 32       # size of batch the inner NN is trained with
@@ -230,18 +164,6 @@ if __name__=='__main__':
     early_stop_flag = True        # implement early stopping or not
     average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over
     num_policies = 5      # fix number of policies
-    num_sub_policies = 5  # fix number of sub-policies in a policy
+    sp_num = 5  # fix number of sub-policies in a policy
     iterations = 100      # total iterations, should be more than the number of policies
-    IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
-
-    # generate random policies at start
-    policies = generate_policies(num_policies, num_sub_policies)
-
-    q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)
-
-    plt.plot(best_q_values)
-
-    best_q_values = np.array(best_q_values)
-    save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
-    #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
-
+    IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
\ No newline at end of file
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index af1f311d9b266da4572fbda550a2978342c0aad8..51c9070a8ea9dfd86bd4df76eaabee71b50c2fca 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -11,6 +11,12 @@ import torchvision.datasets as datasets
 
 
 def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
+    if n_samples == 1:
+        # n_samples == 1 means the full dataset is used, so no subsampling is
+        # needed; wrap the datasets in DataLoaders directly
+        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
+        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
+        return train_loader, test_loader
+
     # shuffle and take first n_samples %age of training dataset
     shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset))
     shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
diff --git a/backend_react/.flaskenv b/backend_react/.flaskenv
index 5aabce39e19b5f99fbfc93caae30a17a1933c54a..89ab8b9bcc57b84b7f1b783246f01f27a176b3b7 100644
--- a/backend_react/.flaskenv
+++ b/backend_react/.flaskenv
@@ -1,2 +1,3 @@
 FLASK_APP=react_app.py
-FLASK_ENV=development
\ No newline at end of file
+FLASK_ENV=development
+FLASK_DEBUG=1
\ No newline at end of file
diff --git a/backend_react/child_networks/CIFAR100_v2.txt b/backend_react/child_networks/CIFAR100_v2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9a62099b8c0bcedfd1a6a3c8b60e238dde0a40e
--- /dev/null
+++ b/backend_react/child_networks/CIFAR100_v2.txt
@@ -0,0 +1 @@
+0.19934545454545455,0.19519090909090908,0.19935454545454545,0.19381818181818183,0.18769999999999998,0.19858181818181822,0.19459090909090906,0.18030000000000002,0.17654545454545453,0.2042909090909091
\ No newline at end of file
diff --git a/backend_react/child_networks/place_holder b/backend_react/child_networks/place_holder
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend_react/policy.txt b/backend_react/policy.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ba1f82e0664d8cca0732c7db753ac93c122986de
--- /dev/null
+++ b/backend_react/policy.txt
@@ -0,0 +1 @@
+0.7018545454545454,0.6530636363636364,0.6565090909090909,0.7029727272727273,0.6615000000000001,0.6610181818181818,0.6333545454545454,0.6617909090909091,0.6584636363636364,0.6933909090909091
\ No newline at end of file
diff --git a/backend_react/react_app.py b/backend_react/react_app.py
index 21f5e8a2a9ae99d4b058931f2e912aa116d7ee0b..96da69473fac69057f23cb7e4c1d8df01866e67d 100644
--- a/backend_react/react_app.py
+++ b/backend_react/react_app.py
@@ -1,40 +1,21 @@
 from dataclasses import dataclass
-from flask import Flask, request, current_app, render_template
+from flask import Flask, request, current_app, send_file
 # from flask_cors import CORS
-import subprocess
 import os
 import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
 from numpy import save, load
-from tqdm import trange
-torch.manual_seed(0)
+import temp_util.wapp_util as wapp_util
+import time
 
 import os
 import sys
 sys.path.insert(0, os.path.abspath('..'))
+torch.manual_seed(0)
 
-# # import agents and its functions
-from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
-from MetaAugment.autoaugment_learners import evo_learner
-import MetaAugment.controller_networks as cn
-import MetaAugment.autoaugment_learners as aal
 print('@@@ import successful')
 
-# import agents and its functions
-# from ..MetaAugment import UCB1_JC_py as UCB1_JC
-# from ..MetaAugment import Evo_learner as Evo
-# print('@@@ import successful')
-
 app = Flask(__name__)
 
 
@@ -46,199 +27,139 @@ def get_form_data():
     # form_data = request.files['ds_upload'] 
     # print('@@@ form_data', form_data) 
  
-    # form_data = request.form.get('test') 
-    # print('@@@ this is form data', request.get_data())
+    form_data = request.form
+    print('@@@ this is form data', form_data)
 
     # required input
-    # ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
-    # IsLeNet = form_data["select_network"]   # using LeNet or EasyNet or SimpleNet ->> default 
-    # auto_aug_learner = form_data["select_learner"] # augmentation methods to be excluded
-
-    # print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)
-    # # advanced input
-    # if 'batch_size' in form_data.keys(): 
-    #     batch_size = form_data['batch_size']       # size of batch the inner NN is trained with
-    # else: 
-    #     batch_size = 1 # this is for demonstration purposes
-    # if 'learning_rate' in form_data.keys(): 
-    #     learning_rate =  form_data['learning_rate']  # fix learning rate
-    # else: 
-    #     learning_rate = 10-1
-    # if 'toy_size' in form_data.keys(): 
-    #     toy_size = form_data['toy_size']      # total propeortion of training and test set we use
-    # else: 
-    #     toy_size = 1 # this is for demonstration purposes
-    # if 'iterations' in form_data.keys(): 
-    #     iterations = form_data['iterations']      # total iterations, should be more than the number of policies
-    # else: 
-    #     iterations = 10
-    # exclude_method = form_data['select_action']
-    # num_funcs = 14 - len(exclude_method)
-    # print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method, 'num_funcs', num_funcs)
+    ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
+    IsLeNet = form_data["select_network"]   # using LeNet or EasyNet or SimpleNet ->> default 
+    auto_aug_learner = form_data["select_learner"] # which autoaugment learner to use
+
+    print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_learner:', auto_aug_learner)
+    # advanced input
+    # all form values arrive as strings, so cast the numeric fields explicitly
+    if form_data['batch_size'] not in ('undefined', ''):
+        batch_size = int(form_data['batch_size'])       # size of batch the child network is trained with
+    else:
+        batch_size = 1 # this is for demonstration purposes
+    if form_data['learning_rate'] not in ('undefined', ''):
+        learning_rate = float(form_data['learning_rate'])  # fixed learning rate
+    else:
+        learning_rate = 1e-1
+    if form_data['toy_size'] not in ('undefined', ''):
+        toy_size = float(form_data['toy_size'])      # proportion of the training and test sets we use
+    else:
+        toy_size = 1 # this is for demonstration purposes
+    if form_data['iterations'] not in ('undefined', ''):
+        iterations = int(form_data['iterations'])      # total iterations; should exceed the number of policies
+    else:
+        iterations = 10
+    exclude_method = form_data['select_action']
+    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method:', exclude_method)
     
 
-    # # default values 
-    # max_epochs = 10      # max number of epochs that is run if early stopping is not hit
-    # early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
-    # num_policies = 5      # fix number of policies
-    # num_sub_policies = 5  # fix number of sub-policies in a policy
+    # default values 
+    max_epochs = 10      # max number of epochs that is run if early stopping is not hit
+    early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+    num_policies = 5      # fix number of policies
+    num_sub_policies = 5  # fix number of sub-policies in a policy
     
     
-    # # if user upload datasets and networks, save them in the database
-    # if ds == 'Other':
-    #     ds_folder = request.files['ds_upload'] 
-    #     print('!!!ds_folder', ds_folder)
-    #     ds_name_zip = ds_folder.filename
-    #     ds_name = ds_name_zip.split('.')[0]
-    #     ds_folder.save('./datasets/'+ ds_name_zip)
-    #     with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref:
-    #         zip_ref.extractall('./datasets/upload_dataset/')
-    #     if not current_app.debug:
-    #         os.remove(f'./datasets/{ds_name_zip}')
-    # else: 
-    #     ds_name = None
-
-    # # test if uploaded dataset meets the criteria 
-    # for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
-    #     for dirname in dirnames:
-    #         if dirname[0:6] != 'class_':
-    #             return None # neet to change render to a 'failed dataset webpage'
-
-    # # save the user uploaded network
-    # if IsLeNet == 'Other':
-    #     childnetwork = request.files['network_upload']
-    #     childnetwork.save('./child_networks/'+childnetwork.filename)
-    #     network_name = childnetwork.filename
+    # if user upload datasets and networks, save them in the database
+    if ds == 'Other':
+        ds_folder = request.files['ds_upload'] 
+        print('!!!ds_folder', ds_folder)
+        ds_name_zip = ds_folder.filename
+        ds_name = ds_name_zip.split('.')[0]
+        ds_folder.save('./datasets/'+ ds_name_zip)
+        with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref:
+            zip_ref.extractall('./datasets/upload_dataset/')
+        if not current_app.debug:
+            os.remove(f'./datasets/{ds_name_zip}')
+    else: 
+        ds_name_zip = None
+        ds_name = None
+
+    # check that the uploaded dataset meets the criteria
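+    # expected layout: ./datasets/upload_dataset/<ds_name>/class_<label>/<images>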
+    for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
+        for dirname in dirnames:
+            if dirname[0:6] != 'class_':
+                return None # need to change this render to a 'failed dataset' webpage
+
+    # save the user uploaded network
+    if IsLeNet == 'Other':
+        childnetwork = request.files['network_upload']
+        childnetwork.save('./child_networks/'+childnetwork.filename)
+        network_name = childnetwork.filename
+    else: 
+        network_name = None
 
     
-    # # generate random policies at start
-    # current_app.config['AAL'] = auto_aug_learner
-    # current_app.config['NP'] = num_policies
-    # current_app.config['NSP'] = num_sub_policies
-    # current_app.config['BS'] = batch_size
-    # current_app.config['LR'] = learning_rate
-    # current_app.config['TS'] = toy_size
-    # current_app.config['ME'] = max_epochs
-    # current_app.config['ESN'] = early_stop_num
-    # current_app.config['IT'] = iterations
-    # current_app.config['ISLENET'] = IsLeNet
-    # current_app.config['DSN'] = ds_name
-    # current_app.config['ds'] = ds
+    print("@@@ user input has all stored in the app")
 
-    
-    # print("@@@ user input has all stored in the app")
+    data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
+            'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, 
+            'toy_size':toy_size, 'iterations':iterations, 'exclude_method': exclude_method, }
 
-    # data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
-    #         'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, 
-    #         'toy_size':toy_size, 'iterations':iterations, }
+    current_app.config['data'] = data
     
-    # print('@@@ all data sent', data)
-    return {'data': 'show training data'}
+    print('@@@ all data sent', current_app.config['data'])
+
+    # optional and untested: the learner could be started directly from here, e.g.
+    # wapp_util.parse_users_learner_spec(
+    #                         num_policies=num_policies,
+    #                         num_sub_policies=num_sub_policies,
+    #                         early_stop_num=early_stop_num,
+    #                         max_epochs=max_epochs,
+    #                         **data,
+    #                         )
+
+    return {'data': 'all stored'}
+
 
+# ========================================================================
 @app.route('/confirm', methods=['POST', 'GET'])
 def confirm():
-    print('inside confirm')
-
-    # aa learner
-    auto_aug_learner = current_app.config.get('AAL')
-
-    # search space & problem setting
-    ds = current_app.config.get('ds')
-    ds_name = current_app.config.get('DSN')
-    exclude_method = current_app.config.get('exc_meth')
-    num_policies = current_app.config.get('NP')
-    num_sub_policies = current_app.config.get('NSP')
-    num_funcs = current_app.config.get('NUMFUN')
-    toy_size = current_app.config.get('TS')
-
-    # child network
-    IsLeNet = current_app.config.get('ISLENET')
-
-    # child network training hyperparameters
-    batch_size = current_app.config.get('BS')
-    early_stop_num = current_app.config.get('ESN')
-    iterations = current_app.config.get('IT')
-    learning_rate = current_app.config.get('LR')
-    max_epochs = current_app.config.get('ME')
-
-    data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
-            'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, 
-            'toy_size':toy_size, 'iterations':iterations, }
-    return {'batch_size': '12'}
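+    # return the stored user configuration as JSON so the React Confirm page can display it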
+    print('inside confirm page')
+    data = current_app.config['data']
+    return data
+
 
 # ========================================================================
 @app.route('/training', methods=['POST', 'GET'])
 def training():
 
-    # aa learner
-    auto_aug_learner = current_app.config.get('AAL')
-
-    # search space & problem setting
-    ds = current_app.config.get('ds')
-    ds_name = current_app.config.get('DSN')
-    exclude_method = current_app.config.get('exc_meth')
-    num_funcs = current_app.config.get('NUMFUN')
-    num_policies = current_app.config.get('NP')
-    num_sub_policies = current_app.config.get('NSP')
-    toy_size = current_app.config.get('TS')
-    
-    # child network
-    IsLeNet = current_app.config.get('ISLENET')
-
-    # child network training hyperparameters
-    batch_size = current_app.config.get('BS')
-    early_stop_num = current_app.config.get('ESN')
-    iterations = current_app.config.get('IT')
-    learning_rate = current_app.config.get('LR')
-    max_epochs = current_app.config.get('ME')
-
-
-    if auto_aug_learner == 'UCB':
-        policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
-        q_values, best_q_values = UCB1_JC.run_UCB1(
-                                                policies,
-                                                batch_size, 
-                                                learning_rate, 
-                                                ds, 
-                                                toy_size, 
-                                                max_epochs, 
-                                                early_stop_num, 
-                                                iterations, 
-                                                IsLeNet, 
-                                                ds_name
-                                                )     
-        best_q_values = np.array(best_q_values)
-
-    elif auto_aug_learner == 'Evolutionary Learner':
-
-        network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        child_network = aal.evo.LeNet()
-        learner = aal.evo.evo_learner(
-                                    network=network, 
-                                    fun_num=num_funcs, 
-                                    p_bins=1, 
-                                    mag_bins=1, 
-                                    sub_num_pol=1, 
-                                    ds = ds, 
-                                    ds_name=ds_name, 
-                                    exclude_method=exclude_method, 
-                                    child_network=child_network
-                                    )
-
-        learner.run_instance()
-    elif auto_aug_learner == 'Random Searcher':
-        pass 
-    elif auto_aug_learner == 'Genetic Learner':
-        pass
-
-    return {'status': 'training'}
+    # default values (unused by the simulated loop below; kept for when real training is wired in)
+    max_epochs = 10      # max number of epochs that is run if early stopping is not hit
+    early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+    num_policies = 5      # fix number of policies
+    num_sub_policies = 5  # fix number of sub-policies in a policy
+    data = current_app.config.get('data')
+
+    # simulated training so the frontend flow can be demonstrated
+    print('pretending to train')
+    time.sleep(3)
+    print('epoch: 1')
+    time.sleep(3)
+    print('epoch: 2')
+    time.sleep(3)
+    print('epoch: 3')
+    print('finished (simulated) training')
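+    # a real run would invoke the learner here instead, e.g. (untested sketch):
+    #   wapp_util.parse_users_learner_spec(num_policies=num_policies,
+    #                                      num_sub_policies=num_sub_policies,
+    #                                      early_stop_num=early_stop_num,
+    #                                      max_epochs=max_epochs, **data)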
 
+    return {'status': 'Training is done!'}
 
 
 # ========================================================================
-@app.route('/results')
+@app.route('/result')
 def show_result():
-    return {'status': 'results'}
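+    # serve the learned policy file as a download (assumes the learner has written it to ./policy.txt)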
+    file_path = "./policy.txt"
+    f = open(file_path, "r")
+    return send_file(file_path, as_attachment=True)
+
 
 @app.route('/api')
 def index():
@@ -246,4 +167,4 @@ def index():
 
 
 if __name__ == '__main__':
-    app.run(debug=True)
\ No newline at end of file
+    app.run(debug=False, use_reloader=False)
\ No newline at end of file
diff --git a/benchmark/scripts/04_22_ci_gru.py b/benchmark/scripts/04_22_ci_gru.py
index 194a5f235d18ee6bc7fb1c3c2ef56adff4104cfa..155b7c92aa1953a57c66c377b80d6379da66cbb5 100644
--- a/benchmark/scripts/04_22_ci_gru.py
+++ b/benchmark/scripts/04_22_ci_gru.py
@@ -12,8 +12,6 @@ from .util_04_22 import *
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
-        'toy_flag' : False,
-#         'toy_flag' : True,
 #         'toy_size' : 0.001,
         'batch_size' : 32,
         'max_epochs' : 100,
diff --git a/benchmark/scripts/04_22_ci_rs.py b/benchmark/scripts/04_22_ci_rs.py
index c6dd5f4df9b9fb1cae36d1bb6abd783b604efbae..e1279b1e18038884ca75c5502effe7ff2c7966b1 100644
--- a/benchmark/scripts/04_22_ci_rs.py
+++ b/benchmark/scripts/04_22_ci_rs.py
@@ -12,8 +12,6 @@ from .util_04_22 import *
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
-        'toy_flag' : False,
-#         'toy_flag' : True,
 #         'toy_size' : 0.001,
         'batch_size' : 32,
         'max_epochs' : 100,
diff --git a/benchmark/scripts/04_22_fm_gru.py b/benchmark/scripts/04_22_fm_gru.py
index 799e439ef22f51cef57b42e807905648800a4710..807d0177ce45845c986a72750dfadd647283ab26 100644
--- a/benchmark/scripts/04_22_fm_gru.py
+++ b/benchmark/scripts/04_22_fm_gru.py
@@ -12,8 +12,6 @@ from .util_04_22 import *
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
-        'toy_flag' : False,
-#         'toy_flag' : True,
 #         'toy_size' : 0.001,
         'batch_size' : 32,
         'max_epochs' : 100,
diff --git a/benchmark/scripts/04_22_fm_rs.py b/benchmark/scripts/04_22_fm_rs.py
index 6b983284e789873af8ef85f1f07b9ddecd880186..dfe7195831460119a0ddd65225f4b9b4ccd04d63 100644
--- a/benchmark/scripts/04_22_fm_rs.py
+++ b/benchmark/scripts/04_22_fm_rs.py
@@ -12,8 +12,6 @@ from .util_04_22 import *
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
-        'toy_flag' : False,
-#         'toy_flag' : True,
 #         'toy_size' : 0.001,
         'batch_size' : 32,
         'max_epochs' : 100,
diff --git a/docs/source/usage/tutorial_for_team.rst b/docs/source/usage/tutorial_for_team.rst
index d4cebf46a184c71d8335544fd63bf1ce1275fdf9..1c81cd7cc6133d627c19f6a58c1db8f0971b75d5 100644
--- a/docs/source/usage/tutorial_for_team.rst
+++ b/docs/source/usage/tutorial_for_team.rst
@@ -57,7 +57,6 @@ can use any other learner in place of random search learner as well)
     # aa_agent = aal.ac_learner()
     aa_agent = aal.randomsearch_learner(
                                     sp_num=7,
-                                    toy_flag=True,
                                     toy_size=0.01,
                                     batch_size=4,
                                     learning_rate=0.05,
diff --git a/flask_mvp/app.py b/flask_mvp/app.py
index 5e39517f6ae17dc93910e02f960e7aac0074dd7d..8f71616620872dc04fd66be39753bb45a74ae2e4 100644
--- a/flask_mvp/app.py
+++ b/flask_mvp/app.py
@@ -3,7 +3,8 @@
 #     app.run(host='0.0.0.0',port=port)
 
 from numpy import broadcast
-from auto_augmentation import home, progress,result, training
+from auto_augmentation import home, progress, result
+from flask_mvp.auto_augmentation import training
 from flask_socketio import SocketIO,  send
 
 from flask import Flask, flash, request, redirect, url_for
diff --git a/flask_mvp/auto_augmentation/__init__.py b/flask_mvp/auto_augmentation/__init__.py
index 0899be3d1b979ffc3f6e5a123cdb848470b29feb..72634111728ad96c69734e682ef37bae7c112a75 100644
--- a/flask_mvp/auto_augmentation/__init__.py
+++ b/flask_mvp/auto_augmentation/__init__.py
@@ -2,7 +2,8 @@ import os
 
 from flask import Flask, render_template, request, flash
 
-from auto_augmentation import home, progress,result, training
+from auto_augmentation import home, progress, result
+from flask_mvp.auto_augmentation import training
 
 def create_app(test_config=None):
     # create and configure the app
diff --git a/flask_mvp/auto_augmentation/progress.py b/flask_mvp/auto_augmentation/progress.py
index 4c3e96b28ca42e47eada9913c5008bafb90f5ddb..a0645a0105fb70d48088d9a752c53d341c2650e4 100644
--- a/flask_mvp/auto_augmentation/progress.py
+++ b/flask_mvp/auto_augmentation/progress.py
@@ -1,32 +1,12 @@
 from flask import Blueprint, request, render_template, flash, send_file, current_app, g, session
-import subprocess
 import os
 import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
-from numpy import save, load
-from tqdm import trange
 torch.manual_seed(0)
-# import agents and its functions
 
-from MetaAugment.autoaugment_learners import ucb_learner
-# hi
-from MetaAugment import Evo_learner as Evo
-
-import MetaAugment.autoaugment_learners as aal
-from MetaAugment.main import create_toy
-import MetaAugment.child_networks as cn
-import pickle
 
+import temp_util.wapp_util as wapp_util
 
 bp = Blueprint("progress", __name__)
 
@@ -92,100 +72,20 @@ def response():
         
 
 
-        if auto_aug_learner == 'UCB':
-            policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
-            q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
-        elif auto_aug_learner == 'Evolutionary Learner':
-            learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
-            learner.run_instance()
-        elif auto_aug_learner == 'Random Searcher':
-            # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
-            # This system makes more sense for the user who is not using the webapp and is instead
-            # using the library within their code
-            download = True
-            if ds == "MNIST":
-                train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
-                test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "KMNIST":
-                train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
-                test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "FashionMNIST":
-                train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
-                test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "CIFAR10":
-                train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
-                test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "CIFAR100":
-                train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
-                test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == 'Other':
-                dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
-                len_train = int(0.8*len(dataset))
-                train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
-
-            # check sizes of images
-            img_height = len(train_dataset[0][0][0])
-            img_width = len(train_dataset[0][0][0][0])
-            img_channels = len(train_dataset[0][0])
-            # check output labels
-            if ds == 'Other':
-                num_labels = len(dataset.class_to_idx)
-            elif ds == "CIFAR10" or ds == "CIFAR100":
-                num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
-            else:
-                num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
-            # create toy dataset from above uploaded data
-            train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
-            # create model
-            if IsLeNet == "LeNet":
-                model = cn.LeNet(img_height, img_width, num_labels, img_channels)
-            elif IsLeNet == "EasyNet":
-                model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
-            elif IsLeNet == 'SimpleNet':
-                model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
-            else:
-                model = pickle.load(open(f'datasets/childnetwork', "rb"))
-
-            # use an aa_learner. in this case, a rs learner
-            agent = aal.randomsearch_learner(batch_size=batch_size,
-                                            toy_flag=True,
-                                            learning_rate=learning_rate,
-                                            toy_size=toy_size,
-                                            max_epochs=max_epochs,
-                                            early_stop_num=early_stop_num,
-                                            )
-            agent.learn(train_dataset,
-                        test_dataset,
-                        child_network_architecture=model,
-                        iterations=iterations)
-        elif auto_aug_learner == 'Genetic Learner':
-            pass
-
-        plt.figure()
-        plt.plot(q_values)
-
-
-        # if auto_aug_learner == 'UCB':
-        #     policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
-        #     q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)     
-        #     # plt.figure()
-        #     # plt.plot(q_values)
-        #     best_q_values = np.array(best_q_values)
-
-        # elif auto_aug_learner == 'Evolutionary Learner':
-        #     network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        #     child_network = Evo.LeNet()
-        #     learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
-        #     learner.run_instance()
-        # elif auto_aug_learner == 'Random Searcher':
-        #     pass 
-        # elif auto_aug_learner == 'Genetic Learner':
-        #     pass
+        # keyword arguments keep this call in sync with the
+        # parse_users_learner_spec signature (too many parameters to pass
+        # positionally without mistakes); the helper runs the learner itself
+        # and returns nothing, so there is no assignment
+        wapp_util.parse_users_learner_spec(
+                                ds=ds,
+                                ds_name=ds_name,
+                                IsLeNet=IsLeNet,
+                                auto_aug_learner=auto_aug_learner,
+                                exclude_method=exclude_method,
+                                num_funcs=num_funcs,
+                                num_policies=num_policies,
+                                num_sub_policies=num_sub_policies,
+                                toy_size=toy_size,
+                                batch_size=batch_size,
+                                early_stop_num=early_stop_num,
+                                iterations=iterations,
+                                learning_rate=learning_rate,
+                                max_epochs=max_epochs)
 
 
     current_app.config['AAL'] = auto_aug_learner
diff --git a/flask_mvp/auto_augmentation/training.py b/flask_mvp/auto_augmentation/training.py
index 5e695b58a2994efb1bdc89bb363b3eddf643d9dc..be5c7254ca211e14096f077509d56e0d41a8eceb 100644
--- a/flask_mvp/auto_augmentation/training.py
+++ b/flask_mvp/auto_augmentation/training.py
@@ -1,28 +1,11 @@
 from flask import Blueprint, request, render_template, flash, send_file, current_app
-import subprocess
 import os
-import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
-from numpy import save, load
-from tqdm import trange
 torch.manual_seed(0)
-# import agents and its functions
-
-import MetaAugment.autoaugment_learners as aal
-import MetaAugment.controller_networks as cont_n
-import MetaAugment.child_networks as cn
 
 
+import temp_util.wapp_util as wapp_util
 
 bp = Blueprint("training", __name__)
 
@@ -56,41 +39,22 @@ def response():
     max_epochs = current_app.config.get('ME')
 
 
-    if auto_aug_learner == 'UCB':
-        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
-        q_values, best_q_values = aal.ucb_learner.run_UCB1(
-                                                policies, 
-                                                batch_size, 
-                                                learning_rate, 
-                                                ds, 
-                                                toy_size, 
-                                                max_epochs, 
-                                                early_stop_num, 
-                                                iterations, 
-                                                IsLeNet, 
-                                                ds_name
-                                                )     
-        best_q_values = np.array(best_q_values)
-
-    elif auto_aug_learner == 'Evolutionary Learner':
-        network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        child_network = cn.LeNet()
-        learner = aal.evo_learner(
-                                network=network, 
-                                fun_num=num_funcs, 
-                                p_bins=1, 
-                                mag_bins=1, 
-                                sub_num_pol=1, 
-                                ds = ds, 
-                                ds_name=ds_name, 
-                                exclude_method=exclude_method, 
-                                child_network=child_network
-                                )
-        learner.run_instance()
-    elif auto_aug_learner == 'Random Searcher':
-        pass 
-    elif auto_aug_learner == 'Genetic Learner':
-        pass
+    # keyword arguments keep this call in sync with the
+    # parse_users_learner_spec signature
+    wapp_util.parse_users_learner_spec(
+            ds=ds,
+            ds_name=ds_name,
+            IsLeNet=IsLeNet,
+            auto_aug_learner=auto_aug_learner,
+            exclude_method=exclude_method,
+            num_funcs=num_funcs,
+            num_policies=num_policies,
+            num_sub_policies=num_sub_policies,
+            toy_size=toy_size,
+            batch_size=batch_size,
+            early_stop_num=early_stop_num,
+            iterations=iterations,
+            learning_rate=learning_rate,
+            max_epochs=max_epochs
+            )
 
     return render_template("progress.html", auto_aug_learner=auto_aug_learner)
 
diff --git a/package.json b/package.json
index f98529296d0ba29029442b36dc7eecf618f84ec5..ada7219e60444673f911e7644b5eb6df70b1bb8e 100644
--- a/package.json
+++ b/package.json
@@ -13,6 +13,7 @@
     "@testing-library/react": "^13.1.1",
     "@testing-library/user-event": "^13.5.0",
     "axios": "^0.26.1",
+    "js-file-download": "^0.4.12",
     "react": "^18.0.0",
     "react-dom": "^18.0.0",
     "react-hook-form": "^7.30.0",
diff --git a/src/App.js b/src/App.js
index 669aafc632d912cbf870ba3aabd229cab13dcdaa..15ae090797af34edd330230308dae2b6cc7acee0 100644
--- a/src/App.js
+++ b/src/App.js
@@ -37,7 +37,8 @@ function App() {
         <Routes>
           <Route exact path="/" element={<Home/>}/>
           <Route exact path="/confirm" element={<Confirm/>}/>
-          {/* <Route exact path="/Progress" element={<Training/>}/> */}
+          <Route exact path="/progress" element={<Progress/>}/>
+          <Route exact path="/result" element={<Result/>}/>
         </Routes>
       </BrowserRouter>
     </div>
diff --git a/src/pages/Confirm.js b/src/pages/Confirm.js
index e293133a61438cde948b463ad8a4d3e581873b27..98ba0433a02e0df5fe16a5c8468525d811e16257 100644
--- a/src/pages/Confirm.js
+++ b/src/pages/Confirm.js
@@ -1,25 +1,27 @@
 import React, { useState, useEffect } from "react";
-import { Grid, List, ListItem, Avatar, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button, TextField } from '@mui/material';
+import { Grid, ListItem, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button } from '@mui/material';
 import CheckCircleOutlineRoundedIcon from '@mui/icons-material/CheckCircleOutlineRounded';
 import TuneRoundedIcon from '@mui/icons-material/TuneRounded';
+import {useNavigate, Route} from "react-router-dom";
 
 export default function Confirm() {
-    const [batchSize, setBatchSize] = useState(0)
-//     // const [myData, setMyData] = useState([{}])
+    const [myData, setMyData] = useState([])
+    const [dataset, setDataset] = useState()
+    const [network, setNetwork] = useState()
+
   useEffect(() => {
     const res = fetch('/confirm').then(
       response => response.json()
-      ).then(data => setBatchSize(data.batch_size));
-
-    console.log("batchsize", batchSize)
-    // setBatchSize(res.batch_size)
-
-    // .then(data => {console.log('training', data); 
-    //     })
+      ).then(data => {
+        setMyData(data);
+        if (data.ds === 'Other') {setDataset(data.ds_name)} else {setDataset(data.ds)};
+        if (data.IsLeNet === 'Other') {setNetwork(data.network_name)} else {setNetwork(data.IsLeNet)};
+      });
   }, []);
 
-
-
+  let navigate = useNavigate();
+  const onSubmit = async () => {
+    navigate('/progress', {replace:true});
+  };
 
     return (
         <div className="App" style={{padding:"60px"}}>
@@ -39,7 +41,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <TuneRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Batch size" secondary={batchSize} />
+                                <ListItemText primary="Batch size" secondary={myData.batch_size} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item > 
@@ -47,7 +49,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Dataset selection" secondary="[Dataset]" />
+                                <ListItemText primary="Dataset selection" secondary={dataset} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item> 
@@ -55,7 +57,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <TuneRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Learning rate" secondary="[Learning rate]" />
+                                <ListItemText primary="Learning rate" secondary={myData.learning_rate} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item> 
@@ -63,7 +65,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Network selection" secondary="[Network selection]" />
+                                <ListItemText primary="Network selection" secondary={network} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item> 
@@ -71,7 +73,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <TuneRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Dataset Proportion" secondary="[Dataset Proportion]" />
+                                <ListItemText primary="Dataset Proportion" secondary={myData.toy_size} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item> 
@@ -79,7 +81,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Auto-augment learner selection" secondary="[Auto-augment learner selection]" />
+                                <ListItemText primary="Auto-augment learner selection" secondary={myData.auto_aug_learner} />
                             </ListItem>
                         </Grid>
                         <Grid xs={12} sm={6} item> 
@@ -87,7 +89,7 @@ export default function Confirm() {
                                 <ListItemAvatar>
                                     <TuneRoundedIcon color="primary" fontSize='large'/>
                                 </ListItemAvatar>
-                                <ListItemText primary="Iterations" secondary="[Iterations]" />
+                                <ListItemText primary="Iterations" secondary={myData.iterations} />
                             </ListItem>
                         </Grid>
                         </Grid>
@@ -98,6 +100,7 @@ export default function Confirm() {
                             variant="contained"
                             color='success'
                             size='large'
+                            onClick={onSubmit}
                         >
                             Confirm
                         </Button>
diff --git a/src/pages/Home.js b/src/pages/Home.js
index a3057f585732b495bbd13972e39d25157eccbe67..d314782650c161458f1295804107cfd289b77fb8 100644
--- a/src/pages/Home.js
+++ b/src/pages/Home.js
@@ -6,16 +6,11 @@ import SendIcon from '@mui/icons-material/Send';
 import { CardActions, Collapse, IconButton } from "@mui/material";
 import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
 import { styled } from '@mui/material/styles';
-
-// import {
-//     BrowserRouter as Router,
-//     Switch,
-//     Route,
-//     Redirect,
-//   } from "react-router-dom";
-// import Confirm from './pages/Confirm'
 import {useNavigate, Route} from "react-router-dom";
 
+
 const ExpandMore = styled((props) => {
     const { expand, ...other } = props;
     return <IconButton {...other} />;
@@ -47,8 +42,16 @@ export default function Home() {
 
         formData.append("ds_upload", data.ds_upload[0]);
         formData.append("network_upload", data.network_upload[0]);
-        formData.append("test", 'see');
+        formData.append("batch_size", data.batch_size)
+        formData.append("toy_size", data.toy_size)
+        formData.append("iterations", data.iterations)
+        formData.append("learning_rate", data.learning_rate)
+        formData.append("select_action", data.select_action)
+        formData.append("select_dataset", data.select_dataset)
+        formData.append("select_learner", data.select_learner)
+        formData.append("select_network", data.select_network)
 
+        console.log('>>> this is the user input in formData')
         for (var key of formData.entries()) {
             console.log(key[0] + ', ' + key[1])}
         
@@ -57,7 +60,6 @@ export default function Home() {
         method: 'POST',
         body: formData
         }).then((response) => response.json());
-        console.log('check if it is here')
         
         navigate('/confirm', {replace:true});
         // 
@@ -83,21 +85,6 @@ export default function Home() {
     // console.log('errors', errors); 
     // console.log('handleSubmit', handleSubmit)
 
-
-    // handling learner selection
-    const handleLearnerSelect = (value) => {
-        const isPresent = selectLearner.indexOf(value);
-        if (isPresent !== -1) {
-        const remaining = selectLearner.filter((item) => item !== value);
-        setSelectLearner(remaining);
-        } else {
-        setSelectLearner((prevItems) => [...prevItems, value]);
-        }
-    };
-
-    useEffect(() => {
-        setValue('select_learner', selectLearner); 
-      }, [selectLearner]);
     
     // handling action selection
     const handleActionSelect = (value) => {
@@ -250,28 +237,25 @@ export default function Home() {
                                 <FormLabel id="select_learner" align="left" variant="h6">
                                     Please select the auto-augment learners you'd like to use (multiple learners can be selected)
                                 </FormLabel>
-                                <div>
-                                    {['UCB learner', 'Evolutionary learner', 'Random Searcher', 'GRU Learner'].map((option) => {
-                                    return (
-                                        <FormControlLabel
-                                        control={
-                                            <Controller
-                                            name='select_learner'
-                                            render={({}) => {
-                                                return (
-                                                <Checkbox
-                                                    checked={selectLearner.includes(option)}
-                                                    onChange={() => handleLearnerSelect(option)}/> );
-                                            }}
-                                            control={control}
-                                            rules={{ required: true }}
-                                            />}
-                                        label={option}
-                                        key={option}
-                                        />
-                                    );
-                                    })}
-                                </div>
+                                <Controller 
+                                        name='select_learner'
+                                        control={control}
+                                        rules={{ required: true }}
+                                        render={({field: { onChange, value }}) => (
+                                    <RadioGroup
+                                        row
+                                        aria-labelledby="select_learner"
+                                        name="select_learner"
+                                        value={value ?? ""} 
+                                        onChange={onChange}
+                                        >
+                                        <FormControlLabel value="UCB learner" control={<Radio />} label="UCB learner" />
+                                        <FormControlLabel value="Evolutionary learner" control={<Radio />} label="Evolutionary learner" />
+                                        <FormControlLabel value="Random Searcher" control={<Radio />} label="Random Searcher" />
+                                        <FormControlLabel value="GRU Learner" control={<Radio />} label="GRU Learner" /> 
+                                    </RadioGroup> )}
+                                />
                                 {errors.select_learner && errors.select_learner.type === "required" && 
                                     <Alert severity="error">
                                         <AlertTitle>This field is required</AlertTitle>
@@ -314,16 +298,16 @@ export default function Home() {
                             </Typography>
                             <Grid container spacing={1} style={{maxWidth:800, padding:"10px 10px"}}>
                                 <Grid xs={12} sm={6} item>
-                                    <TextField type="number" {...register("batch_size", {valueAsNumber: true})} name="batch_size" placeholder="Batch Size" label="Batch Size" variant="outlined" fullWidth />
+                                    <TextField type="number" InputProps={{ inputProps: { min: 0} }} {...register("batch_size")} name="batch_size" placeholder="Batch Size" label="Batch Size" variant="outlined" fullWidth />
                                 </Grid>
                                 <Grid xs={12} sm={6} item>
-                                    <TextField type="number" {...register("learning_rate", {valueAsNumber: true})} name="learning_rate" placeholder="Learning Rate" label="Learning Rate" variant="outlined" fullWidth />
+                                    <TextField type="number" inputProps={{step: "0.000000001",min: 0}} {...register("learning_rate")} name="learning_rate" placeholder="Learning Rate" label="Learning Rate" variant="outlined" fullWidth />
                                 </Grid>
                                 <Grid xs={12} sm={6} item>
-                                    <TextField type="number" {...register("iterations", {valueAsNumber: true})} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
+                                    <TextField type="number" InputProps={{ inputProps: { min: 0} }} {...register("iterations")} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
                                 </Grid>
                                 <Grid xs={12} sm={6} item>
-                                    <TextField type="number" {...register("toy_size", {valueAsNumber: true})} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
+                                    <TextField type="number" inputProps={{step: "0.01", min: 0}} {...register("toy_size")} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
                                 </Grid>
                                 <FormLabel variant="h8" align='centre'>
                                     * Dataset Proportion defines the percentage of original dataset our auto-augment learner will use to find the 
diff --git a/src/pages/Progress.js b/src/pages/Progress.js
index ec771deea82520c31c6dc3df20cf2b36806e20ec..2dee4fb594b7771a0d3e468065f50f0005bd376c 100644
--- a/src/pages/Progress.js
+++ b/src/pages/Progress.js
@@ -2,8 +2,25 @@ import React, { useState, useEffect } from "react";
 import { Grid, LinearProgress, Card, CardContent, Typography, Button, TextField } from '@mui/material';
 import CheckCircleOutlineRoundedIcon from '@mui/icons-material/CheckCircleOutlineRounded';
 import TuneRoundedIcon from '@mui/icons-material/TuneRounded';
+import {useNavigate, Route} from "react-router-dom";
+
 export default function Training() {
+    let navigate = useNavigate();
+
+    const [status, setStatus] = useState("Training");
+    useEffect(() => {
+        const res = fetch('/training'
+        ).then(response => response.json()
+        ).then(data => {setStatus(data.status); console.log(data.status)});
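+        // NOTE: the /training request blocks until the backend finishes its
+        // (simulated) run, so `status` only updates once training is done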
+
+        }, []);
+
+    const onSubmit = async () => {
+        navigate('/result', {replace:true});
+    }
 
     return (
         <div className="App" style={{padding:"60px"}}>
@@ -12,17 +29,38 @@ export default function Training() {
             </Typography>
             <Card style={{ maxWidth: 900, padding: "10px 5px", margin: "0 auto" }}>
                 <CardContent>
-                    <Grid style={{padding:"50px"}}>
-                    <Typography gutterBottom variant="subtitle1" align="center" >
-                        Our auto-augment agents are working hard to generate your data augmentation policy ...
+                    <Grid style={{padding:"30px"}}>
+                    <Typography gutterBottom variant="h6" align="center" >
+                        Our auto-augment learners are working hard to generate your data augmentation policy ...
                     </Typography>
-                    <Grid style={{padding:"60px"}}>
-                        <LinearProgress color="primary"/>
-                        <LinearProgress color="primary" />
-                        <LinearProgress color="primary" />
-                        <LinearProgress color="primary" />
                     </Grid>
+
+                    {status==="Training" &&
+                        <Grid style={{padding:"60px"}}>
+                            <LinearProgress color="primary"/>
+                            <LinearProgress color="primary" />
+                            <LinearProgress color="primary" />
+                            <LinearProgress color="primary" />
+                        </Grid>
+                    }
+
+                    <Grid style={{padding:"50px"}}>
+                    <Typography variant='h6'>
+                        Current status: {status}
+                    </Typography>
                     </Grid>
+
+                    {status==="Training is done!" &&
+                        <Button
+                                type="submit"
+                                variant="contained"
+                                color='primary'
+                                size='large'
+                                onClick={onSubmit}
+                            >
+                                Show Results
+                        </Button>
+                    }
                 </CardContent>
             </Card>
                 
diff --git a/src/pages/Result.js b/src/pages/Result.js
index e431e1762a4863c23750f783cdf4eaea0a273e1c..70dec30c7784651f839bdf6e1b56f04c72946f8c 100644
--- a/src/pages/Result.js
+++ b/src/pages/Result.js
@@ -1,9 +1,22 @@
 import React, { useState, useEffect } from "react";
 import { Grid, List, ListItem, Avatar, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button, CardMedia } from '@mui/material';
 import output from './pytest.png'
+import {useNavigate, Route} from "react-router-dom";
+import axios from 'axios'
+import fileDownload from 'js-file-download'
 
 export default function Result() {
 
+    const handleClick = () => {
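+        // GET the generated policy file from the Flask /result endpoint as a
+        // blob and hand it to js-file-download to trigger a browser download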
+        axios.get('/result', {
+            responseType: 'blob',
+          })
+        .then((res) => {
+          fileDownload(res.data, 'policy.txt');
+          console.log(res.data)
+        })
+      }
+
     return (
         <div className="App" style={{padding:"60px"}}>
             <Typography gutterBottom variant="h3" align="center" >
@@ -33,6 +46,7 @@ export default function Result() {
                             variant="contained"
                             color='primary'
                             size='large'
+                            onClick={handleClick}
                         >
                             Download
                     </Button>
diff --git a/temp_util/parse_ds_cn_arch.py b/temp_util/parse_ds_cn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d33c195ca53a6e05d9d4d144788cd2f691cfc9
--- /dev/null
+++ b/temp_util/parse_ds_cn_arch.py
@@ -0,0 +1,64 @@
+# absolute imports, matching wapp_util.py: a relative '..MetaAugment' import
+# would reach beyond the top-level temp_util package and fail at import time
+from MetaAugment.child_networks import *
+from MetaAugment.main import create_toy, train_child_network
+import torch
+import torchvision
+import torchvision.datasets as datasets
+import pickle
+
+def parse_ds_cn_arch(ds, ds_name, IsLeNet):
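+    """Turn the user's dataset / child-network choices into concrete objects.
+
+    Returns (train_dataset, test_dataset, child_architecture). Train datasets
+    are loaded with transform=None, leaving room for the aa_learner to attach
+    its own augmentation transform later.
+    """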
+    if ds == "MNIST":
+        train_dataset = datasets.MNIST(root='./datasets/mnist/train', 
+                        train=True, download=True, transform=None)
+        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, 
+                        download=True, transform=torchvision.transforms.ToTensor())
+    elif ds == "KMNIST":
+        train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', 
+                        train=True, download=True, transform=None)
+        test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, 
+                        download=True, transform=torchvision.transforms.ToTensor())
+    elif ds == "FashionMNIST":
+        train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', 
+                        train=True, download=True, transform=None)
+        test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, 
+                        download=True, transform=torchvision.transforms.ToTensor())
+    elif ds == "CIFAR10":
+        train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', 
+                        train=True, download=True, transform=None)
+        test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, 
+                        download=True, transform=torchvision.transforms.ToTensor())
+    elif ds == "CIFAR100":
+        train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', 
+                        train=True, download=True, transform=None)
+        test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, 
+                        download=True, transform=torchvision.transforms.ToTensor())
+    elif ds == 'Other':
+        dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=None)
+        len_train = int(0.8*len(dataset))
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+
+    # check sizes of images
+    img_height = len(train_dataset[0][0][0])
+    img_width = len(train_dataset[0][0][0][0])
+    img_channels = len(train_dataset[0][0])
+
+    # check output labels
+    if ds == 'Other':
+        num_labels = len(dataset.class_to_idx)
+    elif ds == "CIFAR10" or ds == "CIFAR100":
+        num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
+    else:
+        num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
+
+    # build the requested child network architecture
+    if IsLeNet == "LeNet":
+        child_architecture = LeNet(img_height, img_width, num_labels, img_channels)
+    elif IsLeNet == "EasyNet":
+        child_architecture = EasyNet(img_height, img_width, num_labels, img_channels)
+    elif IsLeNet == 'SimpleNet':
+        child_architecture = SimpleNet(img_height, img_width, num_labels, img_channels)
+    else:
+        child_architecture = pickle.load(open('datasets/childnetwork', "rb"))
+
+    return train_dataset, test_dataset, child_architecture 
\ No newline at end of file
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..cde572f7f1e0ba1590fb5685b7212f4d0b3b173a
--- /dev/null
+++ b/temp_util/wapp_util.py
@@ -0,0 +1,114 @@
+"""
+CONTAINS THE FUNCTIONS THAT THE WEBAPP CAN USE TO INTERACT WITH
+THE LIBRARY
+"""
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.datasets as datasets
+
+# # import agents and its functions
+import MetaAugment.autoaugment_learners as aal
+import MetaAugment.controller_networks as cont_n
+import MetaAugment.child_networks as cn
+from MetaAugment.main import create_toy
+
+import pickle
+from pprint import pprint
+from .parse_ds_cn_arch import parse_ds_cn_arch
+def parse_users_learner_spec(
+            # dataset and child-network choices (handled by parse_ds_cn_arch)
+            ds, 
+            ds_name, 
+            IsLeNet, 
+            # aalearner type
+            auto_aug_learner, 
+            # search space settings
+            exclude_method, 
+            num_funcs, 
+            num_policies, 
+            num_sub_policies, 
+            # child network settings
+            toy_size, 
+            batch_size, 
+            early_stop_num, 
+            iterations, 
+            learning_rate, 
+            max_epochs
+            ):
+    """
+    The website receives user inputs on what kind of aa_learner the user
+    wants. We take those hyperparameters, build the corresponding
+    aa_learner, and run it on the chosen dataset and child network.
+    """
+    train_dataset, test_dataset, child_archi = parse_ds_cn_arch(
+                                                    ds,
+                                                    ds_name,
+                                                    IsLeNet
+                                                    )
+    if auto_aug_learner == 'UCB':
+        learner = aal.ucb_learner(
+                        # parameters that define the search space
+                        sp_num=num_sub_policies,
+                        p_bins=11,
+                        m_bins=10,
+                        discrete_p_m=True,
+                        # hyperparameters for when training the child_network
+                        batch_size=batch_size,
+                        toy_size=toy_size,
+                        learning_rate=learning_rate,
+                        max_epochs=max_epochs,
+                        early_stop_num=early_stop_num,
+                        # ucb_learner specific hyperparameter
+                        num_policies=num_policies
+                        )
+        pprint(learner.policies)
+        
+        learner.learn(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            child_network_architecture=child_archi,
+            iterations=iterations   # was hard-coded to 5; use the user's setting
+            )
+    elif auto_aug_learner == 'Evolutionary Learner':
+        network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
+        child_network = cn.LeNet()
+        learner = aal.evo_learner(
+                                network=network, 
+                                fun_num=num_funcs, 
+                                p_bins=1, 
+                                mag_bins=1, 
+                                sub_num_pol=1, 
+                                ds = ds, 
+                                ds_name=ds_name, 
+                                exclude_method=exclude_method, 
+                                child_network=child_network
+                                )
+        learner.run_instance()
+    elif auto_aug_learner == 'Random Searcher':
+        agent = aal.randomsearch_learner(
+                                        sp_num=num_sub_policies,
+                                        batch_size=batch_size,
+                                        learning_rate=learning_rate,
+                                        toy_size=toy_size,
+                                        max_epochs=max_epochs,
+                                        early_stop_num=early_stop_num,
+                                        )
+        agent.learn(train_dataset,
+                    test_dataset,
+                    child_network_architecture=child_archi,
+                    iterations=iterations)
+    elif auto_aug_learner == 'GRU Learner':
+        agent = aal.gru_learner(
+                                sp_num=num_sub_policies,
+                                batch_size=batch_size,
+                                learning_rate=learning_rate,
+                                toy_size=toy_size,
+                                max_epochs=max_epochs,
+                                early_stop_num=early_stop_num,
+                                )
+        agent.learn(train_dataset,
+                    test_dataset,
+                    child_network_architecture=child_archi,
+                    iterations=iterations)
\ No newline at end of file
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index 3e2808702a04746e625acd5b463cfe01f56687bd..b1524988939e9adac0b45285921b8d058f087887 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -25,13 +25,12 @@ def test_translate_operation_tensor():
 
         softmax = torch.nn.Softmax(dim=0)
 
-        fun_num = random.randint(1, 14)
+        fun_num = 14
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
-
+
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=True
@@ -54,13 +53,12 @@ def test_translate_operation_tensor():
     for i in range(2000):
         
 
-        fun_num = random.randint(1, 14)
+        fun_num = 14
         p_bins = random.randint(1, 15)
         m_bins = random.randint(1, 15)
 
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=False
@@ -81,11 +79,9 @@ def test_translate_operation_tensor():
 def test_test_autoaugment_policy():
     agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
-                toy_flag=True,
                 toy_size=0.004,
                 max_epochs=20,
                 early_stop_num=10
diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py
index 6ad8204f9b8473482f00d5c5d6a9d1e391cf9e0b..cd52b0e95f710c8cddf8a9afdbe67a86acb8fb07 100644
--- a/test/MetaAugment/test_gru_learner.py
+++ b/test/MetaAugment/test_gru_learner.py
@@ -14,13 +14,11 @@ def test_generate_new_policy():
     """
     for _ in range(40):
         sp_num = random.randint(1,20)
-        fun_num = random.randint(1, 14)
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
 
         agent = aal.gru_learner(
             sp_num=sp_num,
-            fun_num=fun_num,
             p_bins=p_bins,
             m_bins=m_bins,
             cont_mb_size=2
@@ -44,7 +42,6 @@ def test_learn():
 
     agent = aal.gru_learner(
                         sp_num=7,
-                        toy_flag=True,
                         toy_size=0.001,
                         batch_size=32,
                         learning_rate=0.05,
diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py
index 5b67d98e1f2e40d56b3aac2445f041f1372bbe9f..61e9f9cd8c86be854dcd8d42cc9ceddde8ada3bc 100644
--- a/test/MetaAugment/test_randomsearch_learner.py
+++ b/test/MetaAugment/test_randomsearch_learner.py
@@ -16,13 +16,12 @@ def test_generate_new_policy():
     def my_test(discrete_p_m):
         for _ in range(40):
             sp_num = random.randint(1,20)
-            fun_num = random.randint(1, 14)
+
             p_bins = random.randint(2, 15)
             m_bins = random.randint(2, 15)
 
             agent = aal.randomsearch_learner(
                 sp_num=sp_num,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=discrete_p_m
@@ -52,7 +51,6 @@ def test_learn():
 
     agent = aal.randomsearch_learner(
                         sp_num=7,
-                        toy_flag=True,
                         toy_size=0.001,
                         batch_size=32,
                         learning_rate=0.05,
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f37f3e5506ef4a09dec2dce7ab44dc06bdd7f16
--- /dev/null
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -0,0 +1,63 @@
+import MetaAugment.autoaugment_learners as aal
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint
+
+def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
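+    # the architecture class itself is passed, not an instance; the learner
+    # presumably instantiates a fresh network for each policy evaluation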
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', 
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
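+    # the train split is loaded with transform=None: the learner presumably
+    # attaches its own augmentation transforms when evaluating each policy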
+
+
+    learner = aal.ucb_learner(
+                # parameters that define the search space
+                sp_num=5,
+                p_bins=11,
+                m_bins=10,
+                discrete_p_m=True,
+                # hyperparameters for when training the child_network
+                batch_size=8,
+                toy_size=0.001,
+                learning_rate=1e-1,
+                max_epochs=float('inf'),
+                early_stop_num=30,
+                # ucb_learner specific hyperparameter
+                num_policies=3
+    )
+    pprint(learner.policies)
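+    # on construction the learner should already hold num_policies random
+    # policies, with one running-average accuracy slot per policy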
+    assert len(learner.policies) == len(learner.avg_accs), \
+        (len(learner.policies), len(learner.avg_accs))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=5
+        )
+
+    # suppose we want to explore more of the search space:
+    # generate four additional policies
+    learner.make_more_policies(n=4)
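+    # the learner should now track 3 + 4 = 7 policies, the new ones untried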
+
+    # and evaluate how good those are as well
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=7
+        )
+
+if __name__ == "__main__":
+    test_ucb_learner()