diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index 792e81e1f85932408755840fbcbc09612137d39e..a13fd207b58df0e70bdd5543a9c98bd55831a2b6 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -94,75 +94,56 @@ class Learner(nn.Module):
         return y
 
 
+# class LeNet(nn.Module):
+#     def __init__(self):
+#         super().__init__()
+#         self.conv1 = nn.Conv2d(1, 6, 5)
+#         self.relu1 = nn.ReLU()
+#         self.pool1 = nn.MaxPool2d(2)
+#         self.conv2 = nn.Conv2d(6, 16, 5)
+#         self.relu2 = nn.ReLU()
+#         self.pool2 = nn.MaxPool2d(2)
+#         self.fc1 = nn.Linear(256, 120)
+#         self.relu3 = nn.ReLU()
+#         self.fc2 = nn.Linear(120, 84)
+#         self.relu4 = nn.ReLU()
+#         self.fc3 = nn.Linear(84, 10)
+#         self.relu5 = nn.ReLU()
+
+#     def forward(self, x):
+#         y = self.conv1(x)
+#         y = self.relu1(y)
+#         y = self.pool1(y)
+#         y = self.conv2(y)
+#         y = self.relu2(y)
+#         y = self.pool2(y)
+#         y = y.view(y.shape[0], -1)
+#         y = self.fc1(y)
+#         y = self.relu3(y)
+#         y = self.fc2(y)
+#         y = self.relu4(y)
+#         y = self.fc3(y)
+#         return y
+
+
 class LeNet(nn.Module):
     def __init__(self):
         super().__init__()
-        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.fc1 = nn.Linear(784, 2048)
         self.relu1 = nn.ReLU()
-        self.pool1 = nn.MaxPool2d(2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.fc2 = nn.Linear(2048, 10)
         self.relu2 = nn.ReLU()
-        self.pool2 = nn.MaxPool2d(2)
-        self.fc1 = nn.Linear(256, 120)
-        self.relu3 = nn.ReLU()
-        self.fc2 = nn.Linear(120, 84)
-        self.relu4 = nn.ReLU()
-        self.fc3 = nn.Linear(84, 10)
-        self.relu5 = nn.ReLU()
 
     def forward(self, x):
-        y = self.conv1(x)
+        x = x.reshape((-1, 784))
+        y = self.fc1(x)
         y = self.relu1(y)
-        y = self.pool1(y)
-        y = self.conv2(y)
-        y = self.relu2(y)
-        y = self.pool2(y)
-        y = y.view(y.shape[0], -1)
-        y = self.fc1(y)
-        y = self.relu3(y)
         y = self.fc2(y)
-        y = self.relu4(y)
-        y = self.fc3(y)
+        y = self.relu2(y)
         return y
 
 
 
-# code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
-# def train_model(full_policy, child_network):
-#     """
-#     Takes in the specific transformation index and probability 
-#     """
-
-#     # transformation = generate_policy(5, ps, mags)
-
-#     train_transform = transforms.Compose([
-#                                             full_policy,
-#                                             transforms.ToTensor()
-#                                         ])
-
-#     batch_size = 32
-#     n_samples = 0.005
-
-#     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
-#     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
-
-#     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
-
-
-#     sgd = optim.SGD(child_network.parameters(), lr=1e-1)
-#     cost = nn.CrossEntropyLoss()
-#     epoch = 20
-
-
-#     best_acc = train_child_network(child_network, train_loader, test_loader,
-#                                      sgd, cost, max_epochs=100, print_every_epoch=False)
-
-#     return best_acc
-
-
-
-
-
 # ORGANISING DATA
 
 # transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation']
@@ -183,6 +164,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 
 
+
 class Evolutionary_learner():
 
     def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
@@ -204,28 +186,23 @@ class Evolutionary_learner():
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
         self.set_up_instance()
-    
-
-    def generate_policy(self, sp_num, ps, mags):
-        """
-        
-        """
-        policies = []
-        for subpol in range(sp_num):
-            sub = []
-            for idx in range(2):
-                transformation = augmentation_space[(2*subpol) + idx]
-                p = ps[(2*subpol) + idx]
-                mag = mags[(2*subpol) + idx]
-                sub.append((transformation, p, mag))
-            policies.append(tuple(sub))
-        
-        return policies
 
 
     def get_full_policy(self, x):
         """
-        Generates the full policy (5 x 2 subpolicies)
+        Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
+        output size 5 * 2 * (self.fun_num + self.p_bins + self.mag_bins)
+
+        Parameters 
+        -----------
+        x -> PyTorch tensor
+            Input data for network 
+
+        Returns
+        ----------
+        full_policy -> [((String, float, float), (String, float, float)), ...]
+            Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
+            two transformations, with a probability and magnitude float for each
         """
         section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
         y = self.auto_aug_agent.forward(x)
@@ -244,11 +221,26 @@ class Evolutionary_learner():
             full_policy.append(tuple(int_pol))
 
         return full_policy
-# 
+
     
     def get_policy_cov(self, x, alpha = 0.5):
         """
-        Need p_bins = 1, num_sub_pol = 1, mag_bins = 1
+        Selects policy using population and covariance matrices. For this method 
+        we require p_bins = 1, num_sub_pol = 1, mag_bins = 1. 
+
+        Parameters
+        ------------
+        x -> PyTorch Tensor
+            Input data for the AutoAugment network 
+
+        alpha -> Float
+            Proportion for covariance and population matrices 
+
+        Returns
+        -----------
+        Subpolicy -> [(String, float, float), (String, float, float)]
+            Subpolicy consisting of two tuples of policies, each with a string associated 
+            to a transformation, a float for a probability, and a float for a magnitude
         """
         section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
 
@@ -284,7 +276,7 @@ class Evolutionary_learner():
             mag1 = None
         if not self.augmentation_space[max_idx[1]][1]:
             mag2 = None
-   
+    
         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
                 prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
@@ -313,6 +305,23 @@ class Evolutionary_learner():
     def run_instance(self, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
+
+        Parameters
+        ------------
+        return_weights -> Bool
+            Determines if the weight of the GA network should be returned 
+        
+        Returns
+        ------------
+        If return_weights:
+            Network weights -> Dictionary
+        
+        Else:
+            Solution -> Best GA instance solution
+
+            Solution fitness -> Float
+
+            Solution_idx -> Int
         """
         self.ga_instance.run()
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
@@ -331,12 +340,25 @@ class Evolutionary_learner():
 
 
     def set_up_instance(self):
+        """
+        Initialises GA instance, as well as fitness and on_generation functions
+        
+        """
 
         def fitness_func(solution, sol_idx):
             """
-            Defines fitness function (accuracy of the model)
+            Defines the fitness function for the parent selection
+
+            Parameters
+            --------------
+            solution -> GA solution instance (passed automatically)
+
+            sol_idx -> GA solution index (passed automatically)
+
+            Returns 
+            --------------
+            fit_val -> float            
             """
-            print("FITNESS HERE")
 
             model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
                                                             weights_vector=solution)
@@ -345,18 +367,23 @@ class Evolutionary_learner():
 
             for idx, (test_x, label_x) in enumerate(train_loader):
                 full_policy = self.get_policy_cov(test_x)
-            print("FULL POLICY: ", full_policy)
-
 
-            fit_val = (test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) #+ test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
-
-            print("DONE FITNESS")
+            fit_val = (test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]  # fitness = mean of two evaluations of the same policy
+                        + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
 
             return fit_val
 
         def on_generation(ga_instance):
             """
-            Just prints stuff while running
+            Prints information of generational fitness
+
+            Parameters 
+            -------------
+            ga_instance -> GA instance
+
+            Returns
+            -------------
+            None
             """
             print("Generation = {generation}".format(generation=ga_instance.generations_completed))
             print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
@@ -373,13 +400,6 @@ class Evolutionary_learner():
 
 
 
-
-
-
-
-
-
-
 # HEREHEREHERE0
 
 def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
@@ -407,8 +427,8 @@ def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
 
 
 def train_child_network(child_network, train_loader, test_loader, sgd,
-                         cost, max_epochs=2000, early_stop_num = 5, logging=False,
-                         print_every_epoch=True):
+                            cost, max_epochs=2000, early_stop_num = 5, logging=False,
+                            print_every_epoch=True):
     if torch.cuda.is_available():
         device = torch.device('cuda')
     else:
@@ -451,12 +471,10 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
 
                 predict_y = child_network(test_x.float()).detach()
                 predict_ys = torch.argmax(predict_y, axis=-1)
-
-                # label_np = test_label.numpy()
-
+    
                 _ = predict_ys == test_label
                 correct += torch.sum(_, axis=-1)
-                # correct += torch.sum(_.numpy(), axis=-1)
+
                 _sum += _.shape[0]
         
         # update best validation accuracy if it was higher, otherwise increase early stop count
@@ -511,19 +529,19 @@ __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWid
 
 
 def _apply_op(img: Tensor, op_name: str, magnitude: float,
-              interpolation: InterpolationMode, fill: Optional[List[float]]):
+                interpolation: InterpolationMode, fill: Optional[List[float]]):
     if op_name == "ShearX":
         img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
-                       interpolation=interpolation, fill=fill)
+                        interpolation=interpolation, fill=fill)
     elif op_name == "ShearY":
         img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
-                       interpolation=interpolation, fill=fill)
+                        interpolation=interpolation, fill=fill)
     elif op_name == "TranslateX":
         img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
-                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
     elif op_name == "TranslateY":
         img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
-                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
     elif op_name == "Rotate":
         img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
     elif op_name == "Brightness":
@@ -728,18 +746,6 @@ class AutoAugment(torch.nn.Module):
                 fill = [float(f) for f in fill]
 
         transform_id, probs, signs = self.get_params(len(self.subpolicies))
-        # print("transform_id, probs, signs : ", transform_id, probs, signs )
-
-        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies[transform_id]):
-        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies):
-        #     print("op_name, p, magnitude_id: ", op_name, p, magnitude_id)
-        #     if probs[i] <= p:
-        #         op_meta = self._augmentation_space(10, F.get_image_size(img))
-        #         magnitudes, signed = op_meta[op_name]
-        #         magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
-        #         if signed and signs[i] == 0:
-        #             magnitude *= -1.0
-        #         img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
 
         for i, (op_name, p, magnitude) in enumerate(self.subpolicies):
             img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
@@ -771,8 +777,8 @@ class RandAugment(torch.nn.Module):
         """
 
     def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
-                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                 fill: Optional[List[float]] = None) -> None:
+                    interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                    fill: Optional[List[float]] = None) -> None:
         super().__init__()
         self.num_ops = num_ops
         self.magnitude = magnitude
@@ -853,7 +859,7 @@ class TrivialAugmentWide(torch.nn.Module):
         """
 
     def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                 fill: Optional[List[float]] = None) -> None:
+                    fill: Optional[List[float]] = None) -> None:
         super().__init__()
         self.num_magnitude_bins = num_magnitude_bins
         self.interpolation = interpolation
@@ -938,4 +944,4 @@ print(f"Fitness value of the best solution = {solution_fitness}")
 print(f"Index of the best solution : {solution_idx}")
 # Fetch the parameters of the best solution.
 best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent,
-                                                      weights_vector=solution)
\ No newline at end of file
+                                                        weights_vector=solution)
diff --git a/MetaAugment/GA_results.png b/MetaAugment/GA_results.png
deleted file mode 100644
index 62449415b64500804927328ca677c4c023085436..0000000000000000000000000000000000000000
Binary files a/MetaAugment/GA_results.png and /dev/null differ
diff --git a/MetaAugment/METALEANER.py b/MetaAugment/METALEANER.py
deleted file mode 100644
index c94246d6898ccf2d316c1dae7644513bf113149e..0000000000000000000000000000000000000000
--- a/MetaAugment/METALEANER.py
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-# Neural network 
-# Input the dataset (same batch size, have to check if the input sizes are correc i.e. 28x28)
-# Output the hyperprameters --> weights of network, kernel size, number of layers, number of kernels
-# 
\ No newline at end of file
diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
index 6d872fd7b0ec12573b775a7dc16f7b0cb3b56254..aed9567093f27cbe3a5937da7bb7057b663cb605 100644
--- a/MetaAugment/UCB1_JC.ipynb
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -1,24 +1,12 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "UCB1.ipynb",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
-  },
   "cells": [
     {
       "cell_type": "code",
+      "execution_count": 1,
+      "metadata": {
+        "id": "U_ZJ2LqDiu_v"
+      },
+      "outputs": [],
       "source": [
         "import numpy as np\n",
         "import torch\n",
@@ -33,15 +21,15 @@
         "from matplotlib import pyplot as plt\n",
         "from numpy import save, load\n",
         "from tqdm import trange"
-      ],
-      "metadata": {
-        "id": "U_ZJ2LqDiu_v"
-      },
-      "execution_count": 1,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 2,
+      "metadata": {
+        "id": "4ksS_duLFADW"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
         "class LeNet(nn.Module):\n",
@@ -75,15 +63,15 @@
         "        y = self.fc3(y)\n",
         "        y = self.relu5(y)\n",
         "        return y"
-      ],
-      "metadata": {
-        "id": "4ksS_duLFADW"
-      },
-      "execution_count": 2,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 3,
+      "metadata": {
+        "id": "LckxnUXGfxjW"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
         "class EasyNet(nn.Module):\n",
@@ -101,15 +89,15 @@
         "        y = self.fc2(y)\n",
         "        y = self.relu2(y)\n",
         "        return y"
-      ],
-      "metadata": {
-        "id": "LckxnUXGfxjW"
-      },
-      "execution_count": 3,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {
+        "id": "enaD2xbw5hew"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
         "class SimpleNet(nn.Module):\n",
@@ -123,15 +111,15 @@
         "        y = self.fc1(y)\n",
         "        y = self.relu1(y)\n",
         "        return y"
-      ],
-      "metadata": {
-        "id": "enaD2xbw5hew"
-      },
-      "execution_count": 4,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {
+        "id": "xujQtvVWBgMH"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Make toy dataset\"\"\"\n",
         "\n",
@@ -154,15 +142,15 @@
         "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
         "\n",
         "    return train_loader, test_loader"
-      ],
-      "metadata": {
-        "id": "xujQtvVWBgMH"
-      },
-      "execution_count": 5,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 6,
+      "metadata": {
+        "id": "Iql-c88jGGWy"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Randomly generate 10 policies\"\"\"\n",
         "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
@@ -193,15 +181,15 @@
         "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
         "\n",
         "    return policies"
-      ],
-      "metadata": {
-        "id": "Iql-c88jGGWy"
-      },
-      "execution_count": 6,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
+      "execution_count": 7,
+      "metadata": {
+        "id": "QE2VWI8o731X"
+      },
+      "outputs": [],
       "source": [
         "\"\"\"Pick policy and sub-policy\"\"\"\n",
         "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n",
@@ -238,12 +226,7 @@
         "            scale = policies[policy, sub_policy][5]\n",
         "\n",
         "    return degrees, shear, scale"
-      ],
-      "metadata": {
-        "id": "QE2VWI8o731X"
-      },
-      "execution_count": 7,
-      "outputs": []
+      ]
     },
     {
       "cell_type": "code",
@@ -392,29 +375,7 @@
     },
     {
       "cell_type": "code",
-      "source": [
-        "batch_size = 32       # size of batch the inner NN is trained with\n",
-        "learning_rate = 1e-1  # fix learning rate\n",
-        "ds = \"MNIST\"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n",
-        "toy_size = 0.02       # total propeortion of training and test set we use\n",
-        "max_epochs = 100      # max number of epochs that is run if early stopping is not hit\n",
-        "early_stop_num = 10   # max number of worse validation scores before early stopping is triggered\n",
-        "num_policies = 5      # fix number of policies\n",
-        "num_sub_policies = 5  # fix number of sub-policies in a policy\n",
-        "iterations = 100      # total iterations, should be more than the number of policies\n",
-        "IsLeNet = \"SimpleNet\" # using LeNet or EasyNet or SimpleNet\n",
-        "\n",
-        "# generate random policies at start\n",
-        "policies = generate_policies(num_policies, num_sub_policies)\n",
-        "\n",
-        "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)\n",
-        "\n",
-        "plt.plot(best_q_values)\n",
-        "\n",
-        "best_q_values = np.array(best_q_values)\n",
-        "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)\n",
-        "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)"
-      ],
+      "execution_count": 9,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -423,168 +384,207 @@
         "id": "doHUtJ_tEiA6",
         "outputId": "3a7becf3-7b5d-4403-84d3-96e51bac8bf5"
       },
-      "execution_count": 9,
       "outputs": [
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 10%|█         | 10/100 [01:09<09:26,  6.29s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 10,\tQ-Values: [0.8, 0.71, 0.79, 0.86, 0.76], Best Policy: 0.86\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 20%|██        | 20/100 [02:18<09:03,  6.80s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 20,\tQ-Values: [0.77, 0.75, 0.81, 0.86, 0.78], Best Policy: 0.86\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 30%|███       | 30/100 [03:24<06:50,  5.87s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 30,\tQ-Values: [0.81, 0.71, 0.79, 0.8, 0.78], Best Policy: 0.81\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 40%|████      | 40/100 [04:34<06:14,  6.23s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 40,\tQ-Values: [0.8, 0.7, 0.76, 0.8, 0.78], Best Policy: 0.8\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 50%|█████     | 50/100 [05:49<06:04,  7.28s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 50,\tQ-Values: [0.79, 0.72, 0.76, 0.81, 0.74], Best Policy: 0.81\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 60%|██████    | 60/100 [06:55<04:32,  6.82s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 60,\tQ-Values: [0.79, 0.72, 0.77, 0.81, 0.76], Best Policy: 0.81\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 70%|███████   | 70/100 [08:29<04:16,  8.53s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 70,\tQ-Values: [0.78, 0.7, 0.78, 0.8, 0.76], Best Policy: 0.8\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 80%|████████  | 80/100 [09:38<02:05,  6.27s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 80,\tQ-Values: [0.79, 0.72, 0.78, 0.79, 0.77], Best Policy: 0.79\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             " 90%|█████████ | 90/100 [10:41<01:04,  6.47s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 90,\tQ-Values: [0.79, 0.71, 0.78, 0.79, 0.77], Best Policy: 0.79\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             "100%|██████████| 100/100 [11:51<00:00,  7.11s/it]"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stdout",
+          "output_type": "stream",
           "text": [
             "Iteration: 100,\tQ-Values: [0.79, 0.72, 0.79, 0.79, 0.78], Best Policy: 0.79\n"
           ]
         },
         {
-          "output_type": "stream",
           "name": "stderr",
+          "output_type": "stream",
           "text": [
             "\n"
           ]
         },
         {
-          "output_type": "display_data",
           "data": {
+            "image/png": "",
             "text/plain": [
               "<Figure size 432x288 with 1 Axes>"
-            ],
-            "image/png": "\n"
+            ]
           },
           "metadata": {
             "needs_background": "light"
-          }
+          },
+          "output_type": "display_data"
         }
+      ],
+      "source": [
+        "batch_size = 32       # size of batch the inner NN is trained with\n",
+        "learning_rate = 1e-1  # fix learning rate\n",
+        "ds = \"MNIST\"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n",
+        "toy_size = 0.02       # total proportion of training and test set we use\n",
+        "max_epochs = 100      # max number of epochs that is run if early stopping is not hit\n",
+        "early_stop_num = 10   # max number of worse validation scores before early stopping is triggered\n",
+        "num_policies = 5      # fix number of policies\n",
+        "num_sub_policies = 5  # fix number of sub-policies in a policy\n",
+        "iterations = 100      # total iterations, should be more than the number of policies\n",
+        "IsLeNet = \"SimpleNet\" # using LeNet or EasyNet or SimpleNet\n",
+        "\n",
+        "# generate random policies at start\n",
+        "policies = generate_policies(num_policies, num_sub_policies)\n",
+        "\n",
+        "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)\n",
+        "\n",
+        "plt.plot(best_q_values)\n",
+        "\n",
+        "best_q_values = np.array(best_q_values)\n",
+        "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)\n",
+        "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)"
       ]
     }
-  ]
-}
\ No newline at end of file
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "collapsed_sections": [],
+      "name": "UCB1.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "name": "python3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
diff --git a/MetaAugment/genetic_learner_results.py b/MetaAugment/genetic_learner_results.py
deleted file mode 100644
index 35d9de8df2e17748b34e6879d4a3ae75dca9d9fb..0000000000000000000000000000000000000000
--- a/MetaAugment/genetic_learner_results.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import matplotlib.pyplot as plt
-import numpy as np
-
-
-# Fixed seed (same as benchmark)
-
-# Looking at last generation can make out general trends of which transformations lead to the largest accuracies
-
-
-gen_1_acc = [0.1998, 0.1405, 0.1678, 0.9690, 0.9672, 0.9540, 0.9047, 0.9730, 0.2060, 0.9260, 0.8035, 0.9715, 0.9737, 0.14, 0.9645]
-
-gen_2_acc = [0.9218, 0.9753, 0.9758, 0.1088, 0.9710, 0.1655, 0.9735, 0.9655, 0.9740, 0.9377]
-
-gen_3_acc = [0.1445, 0.9740, 0.9643, 0.9750, 0.9492, 0.9693, 0.1262, 0.9660, 0.9760, 0.9697]
-
-gen_4_acc = [0.9697, 0.1238, 0.9613, 0.9737, 0.9603, 0.8620, 0.9712, 0.9617, 0.9737, 0.1855]
-
-gen_5_acc = [0.6445, 0.9705, 0.9668, 0.9765, 0.1142, 0.9780, 0.9700, 0.2120, 0.9555, 0.9732]
-
-gen_6_acc = [0.9710, 0.9665, 0.2077, 0.9535, 0.9765, 0.9712, 0.9697, 0.2145, 0.9523, 0.9718, 0.9718, 0.9718, 0.2180, 0.9622, 0.9785]
-
-gen_acc = [gen_1_acc, gen_2_acc, gen_3_acc, gen_4_acc, gen_5_acc, gen_6_acc]
-
-gen_acc_means = []
-gen_acc_stds = []
-
-for val in gen_acc:
-    gen_acc_means.append(np.mean(val))
-    gen_acc_stds.append(np.std(val))
-
-
-
-# Vary seed
-
-gen_1_vary = [0.1998, 0.9707, 0.9715, 0.9657, 0.8347, 0.9655, 0.1870, 0.0983, 0.3750, 0.9765, 0.9712, 0.9705, 0.9635, 0.9718, 0.1170]
-
-gen_2_vary = [0.9758, 0.9607, 0.9597, 0.9753, 0.1165, 0.1503, 0.9747, 0.1725, 0.9645, 0.2290]
-
-gen_3_vary = [0.1357, 0.9725, 0.1708, 0.9607, 0.2132, 0.9730, 0.9743, 0.9690, 0.0850, 0.9755]
-
-gen_4_vary = [0.9722, 0.9760, 0.9697, 0.1155, 0.9715, 0.9688, 0.1785, 0.9745, 0.2362, 0.9765]
-
-gen_5_vary = [0.9705, 0.2280, 0.9745, 0.1875, 0.9735, 0.9735, 0.9720, 0.9678, 0.9770, 0.1155]
-
-gen_6_vary = [0.9685, 0.9730, 0.9735, 0.9760, 0.1495, 0.9707, 0.9700, 0.9747, 0.9750, 0.1155, 0.9732, 0.9745, 0.9758, 0.9768, 0.1155]
-
-gen_vary = [gen_1_vary, gen_2_vary, gen_3_vary, gen_4_vary, gen_5_vary, gen_6_vary]
-
-gen_vary_means = []
-gen_vary_stds = []
-
-for val in gen_vary:
-    gen_vary_means.append(np.mean(val))
-    gen_vary_stds.append(np.std(val))
-
-
-
-
-
-# Multiple runs 
-
-gen_1_mult = [0.1762, 0.9575, 0.1200, 0.9660, 0.9650, 0.9570, 0.9745, 0.9700, 0.15, 0.23, 0.16, 0.186, 0.9640, 0.9650]
-
-gen_2_mult = [0.17, 0.1515, 0.1700, 0.9625, 0.9630, 0.9732, 0.9680, 0.9633, 0.9530, 0.9640]
-
-gen_3_mult = [0.9750, 0.9720, 0.9655, 0.9530, 0.9623, 0.9730, 0.9748, 0.9625, 0.9716, 0.9672]
-
-gen_4_mult = [0.9724, 0.9755, 0.9657, 0.9718, 0.9690, 0.9735, 0.9715, 0.9300, 0.9725, 0.9695]
-
-gen_5_mult = [0.9560, 0.9750, 0.8750, 0.9717, 0.9731, 0.9741, 0.9747, 0.9726, 0.9729, 0.9727]
-
-gen_6_mult = [0.9730, 0.9740, 0.9715, 0.9755, 0.9761, 0.9700, 0.9755, 0.9750, 0.9726, 0.9748, 0.9705, 0.9745, 0.9752, 0.9740, 0.9744]
-
-
-
-gen_mult = [gen_1_mult, gen_2_mult, gen_3_mult,  gen_4_mult, gen_5_mult, gen_6_mult]
-
-gen_mult_means = []
-gen_mult_stds = []
-
-for val in gen_mult:
-    gen_mult_means.append(np.mean(val))
-    gen_mult_stds.append(np.std(val))
-
-num_gen = [i for i in range(len(gen_mult))]
-
-
-# Baseline
-baseline = [0.7990 for i in range(len(gen_mult))]
-
-
-
-# plt.errorbar(num_gen, gen_acc_means, yerr=gen_acc_stds, linestyle = 'dotted', label = 'Fixed seed GA')
-# plt.errorbar(num_gen, gen_vary_means, linestyle = 'dotted', yerr=gen_vary_stds, label = 'Varying seed GA')
-# plt.errorbar(num_gen, gen_mult_means, linestyle = 'dotted', yerr=gen_mult_stds, label = 'Varying seed GA 2')
-
-plt.plot(num_gen, gen_acc_means, linestyle = 'dotted', label = 'Fixed seed GA')
-plt.plot(num_gen, gen_vary_means, linestyle = 'dotted',  label = 'Varying seed GA')
-plt.plot(num_gen, gen_mult_means, linestyle = 'dotted', label = 'Varying seed GA 2')
-
-plt.plot(num_gen, baseline, label = 'Fixed seed baseline')
-
-
-plt.xlabel('Generation', fontsize = 16)
-plt.ylabel('Validation Accuracy', fontsize = 16)
-
-plt.legend()
-
-plt.savefig('GA_results.png')
\ No newline at end of file
diff --git a/app.py b/app.py
index e0f2a3ca2891816f160d1d7bdab85e14bcbf659a..3de6b5b0c5285e24d774b12606894c2a4e6f0f88 100644
--- a/app.py
+++ b/app.py
@@ -1,9 +1,49 @@
-from flask import Flask
+# from flask import Flask
+# from auto_augmentation import create_app
+# import os
+
+# app = create_app()
+# port = int(os.environ.get("PORT", 5000))
+
+# if __name__ == '__main__':
+#     app.run(host='0.0.0.0',port=port)
+
+
+from flask import Flask, flash, request, redirect, url_for
+from werkzeug.utils import secure_filename
 from auto_augmentation import create_app
 import os
-
 app = create_app()
 port = int(os.environ.get("PORT", 5000))
 
+
+UPLOAD_FOLDER = '/datasets'
+
+app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
+ALLOWED_EXTENSIONS = {'pdf', 'py'}
+
+def allowed_file(filename):
+    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+@app.route('/user_input', methods = ['GET', 'POST'])
+def upload_file():
+    print("HELLoasdjsadojsadojsaodjsaoij")
+    if request.method == 'POST':
+        if 'file' not in request.files:
+            flash('No file part')
+            return redirect(request.url)
+        file = request.files['file']
+        if file.filename == '':
+            flash('No selected file')
+            return redirect(request.url)
+        if file and allowed_file(file.filename):
+            filename = secure_filename(file.filename)
+            file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
+            return redirect(url_for('uploaded_file', filename=filename))
+    return '''
+    
+    '''
+
+
 if __name__ == '__main__':
-    app.run(host='0.0.0.0',port=port)
\ No newline at end of file
+    app.run(debug=True)
\ No newline at end of file
diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py
index 03d33fadf5894115ddbf2c20ccc10f5f2ccc214f..77845a0260cf6c2494e8111f69fce2fdbf3124a8 100644
--- a/auto_augmentation/progress.py
+++ b/auto_augmentation/progress.py
@@ -1,5 +1,6 @@
 from flask import Blueprint, request, render_template, flash, send_file
 import subprocess
+import os
 
 import numpy as np
 import torch
@@ -15,7 +16,10 @@ from numpy import save, load
 from tqdm import trange
 torch.manual_seed(0)
 # import agents and its functions
-from MetaAugment import UCB1_JC  
+
+from MetaAugment import UCB1_JC_py as UCB1_JC
+
+
 
 bp = Blueprint("progress", __name__)
 
diff --git a/auto_augmentation/templates/home.html b/auto_augmentation/templates/home.html
index 7e4fc0670a7f7cddb9ed59c80586505a0160b70a..c40f23f825448e20ef0bf2c30dbb2377a264345c 100644
--- a/auto_augmentation/templates/home.html
+++ b/auto_augmentation/templates/home.html
@@ -59,77 +59,86 @@
   <h3>Advanced Search</h3>
   <!-- action(data augmentation) space -->
   Which data augmentation method you would like exclude? <br>
-  <input type="radio" id="ShearX"
+  <input type="checkbox" id="ShearX"
     name="action_space" value="ShearX">
   <label for="ShearX">ShearX</label><br>
 
-  <input type="radio" id="ShearY"
+  <input type="checkbox" id="ShearY"
     name="action_space" value="ShearY">
   <label for="ShearY">ShearY</label><br>
 
-  <input type="radio" id="TranslateX"
+  <input type="checkbox" id="TranslateX"
     name="action_space" value="TranslateX">
   <label for="TranslateX">TranslateX</label><br>
 
-  <input type="radio" id="TranslateY"
+  <input type="checkbox" id="TranslateY"
   name="action_space" value="TranslateY">
   <label for="TranslateY">TranslateY</label><br>
 
-  <input type="radio" id="Rotate"
+  <input type="checkbox" id="Rotate"
     name="Rotate" value="Rotate">
   <label for="Rotate">Rotate</label><br>
 
-  <input type="radio" id="Brightness"
+  <input type="checkbox" id="Brightness"
   name="action_space" value="Brightness">
   <label for="Brightness">Brightness</label><br>
 
-  <input type="radio" id="Color"
+  <input type="checkbox" id="Color"
   name="action_space" value="Color">
   <label for="Color">Color</label><br>
 
-  <input type="radio" id="Contrast"
+  <input type="checkbox" id="Contrast"
   name="action_space" value="Contrast">
   <label for="Contrast">Contrast</label><br>
 
-  <input type="radio" id="Sharpness"
+  <input type="checkbox" id="Sharpness"
   name="action_space" value="Sharpness">
   <label for="Sharpness">Sharpness</label><br>
 
-  <input type="radio" id="Posterize"
+  <input type="checkbox" id="Posterize"
   name="action_space" value="Posterize">
   <label for="Posterize">Posterize</label><br>
 
-  <input type="radio" id="Solarize"
+  <input type="checkbox" id="Solarize"
   name="action_space" value="Solarize">
   <label for="Solarize">Solarize</label><br>
 
-  <input type="radio" id="AutoContrast"
+  <input type="checkbox" id="AutoContrast"
   name="action_space" value="AutoContrast">
   <label for="AutoContrast">AutoContrast</label><br>
 
-  <input type="radio" id="Equalize"
+  <input type="checkbox" id="Equalize"
   name="action_space" value="Equalize">
   <label for="Equalize">Equalize</label><br>
 
-  <input type="radio" id="Invert"
+  <input type="checkbox" id="Invert"
   name="action_space" value="Invert">
   <label for="Invert">Invert</label><br><br><br>
 
 
+  <!-- <div id="exclude_augments" class="dropdown-check-list" tabindex="100">
+    <span class="anchor">Select data augmentation method(s) to exclude:</span>
+    <ul class="items">
+      <input type="checkbox" />Translate 
+      <input type="checkbox" />Rotate
+      <input type="checkbox" />AutoContrast 
+      <input type="checkbox" />Equalize 
+      <br>
+      <input type="checkbox" />Solarize
+      <input type="checkbox" />Posterize 
+      <input type="checkbox" />Contrast
+      <input type="checkbox" />Brightness
+    </ul>
+  </div> -->
+
+  <div id="exclude_augments">
+    <span class="anchor">Hyperparameter (Learning Rate):</span>
+    <ul class="items">
+      Automatic: <input type="checkbox" /> <div></div>
+      Manual: <input type="number" />
+    </ul>
+  </div>
 
-  <!-- action space -->
-  <!-- <label for="data_aug_method">Which data augmentation method you would like exclude?</label>
-    <select id="data_aug_method" name="data_aug_method">
-        <option value="Translate">Translate</option>
-        <option value="Rotate">Rotate</option>
-        <option value="AutoContrast">AutoContrast</option>
-        <option value="Equalize">Equalize</option>
-        <option value="Solarize">Solarize</option>
-        <option value="Posterize">Posterize</option>
-        <option value="Contrast">Contrast</option>
-        <option value="Brightness">Brightness</option>
-
-    </select><br><br> -->
   
 
   <input type="submit">