diff --git a/.DS_Store b/.DS_Store
index 87b56ad1c0caa0cd8b0aa4497cbd4d095b75bc27..720cf3ab50cbd4bb4f33acbbc3cb3516e7778732 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/.flaskenv b/.flaskenv
deleted file mode 100644
index adfe7a48941f238a842cacea32c5d17cf4d9c60e..0000000000000000000000000000000000000000
--- a/.flaskenv
+++ /dev/null
@@ -1 +0,0 @@
-FLASK_APP=auto_augmentation
\ No newline at end of file
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
deleted file mode 100644
index b125750ccc15f94f11db68afd39ebf0c9d898803..0000000000000000000000000000000000000000
--- a/.gitlab-ci.yml
+++ /dev/null
@@ -1,25 +0,0 @@
-build-job:
-  stage: build
-  script:
-    - echo "Hello, I'm Building"
-    - pip install pytest
-    - pip install flask
-    - pip install pandoc # for pdf making
-    - pip install weasyprint # for pdf making
-
-test-job:
-  stage: test
-  script:
-    - echo "Now I'm Testing!"
-    - python3 -m tests.test_query_processor
-
-deploy-job:
-  stage: deploy
-  script:
-    - echo "Now I'm Deploying to VM!"
-    - python3 -m venv venv
-    - . venv/bin/activate
-    - pip install -r requirements.txt
-    - flask run &
-    - echo "Now I'm Deploying to Heroku!"
-    - dpl --provider=heroku --app=metarl --api-key=5ccc3ae7-725e-4f9f-b441-0c9a28ebdc1b
diff --git a/0_01pkls/gru_learner.pkl b/0_01pkls/gru_learner.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..47c9510b1d190ea73626c71dd0ca280235e1cf91
Binary files /dev/null and b/0_01pkls/gru_learner.pkl differ
diff --git a/0_01pkls/gru_logs.pkl b/0_01pkls/gru_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..ff7d5de2822d940b91c0572a96671cece7c9dae6
Binary files /dev/null and b/0_01pkls/gru_logs.pkl differ
diff --git a/0_01pkls/randomsearch_logs.pkl b/0_01pkls/randomsearch_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..cf05e52402e9a36ddb91fbf4cd3b2d8c39aa27bd
Binary files /dev/null and b/0_01pkls/randomsearch_logs.pkl differ
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 3ee025e601252084a1b6fcad8a6f72b1030a4d2a..0000000000000000000000000000000000000000
--- a/Dockerfile
+++ /dev/null
@@ -1,13 +0,0 @@
-FROM python:3
-
-RUN pip3 install virtualenv
-
-RUN python3 -m venv venvs
-
-COPY requirements.txt requirements.txt
-
-RUN pip3 install -r requirements.txt
-
-COPY . .
-
-CMD ["flask", "run"]
\ No newline at end of file
diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index aacec6a801babddef3c86d37bac4d22a4c00f8e7..792e81e1f85932408755840fbcbc09612137d39e 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -1,3 +1,4 @@
+from cgi import test
 import numpy as np
 import torch
 torch.manual_seed(0)
@@ -13,19 +14,54 @@ import pygad
 import pygad.torchga as torchga
 import random
 import copy
+from torchvision.transforms import functional as F, InterpolationMode
+from typing import List, Tuple, Optional, Dict
+import heapq
+import math
 
-from MetaAugment.main import *
-
-# import MetaAugment.child_networks as child_networks
-# from MetaAugment.main import *
+import math
+import torch
 
+from enum import Enum
+from torch import Tensor
+from typing import List, Tuple, Optional, Dict
 
-np.random.seed(0)
-random.seed(0)
+from torchvision.transforms import functional as F, InterpolationMode
 
+# import MetaAugment.child_networks as child_networks
+# from main import *
+# from autoaugment_learners.autoaugment import *
+
+
+# np.random.seed(0)
+# random.seed(0)
+
+
+augmentation_space = [
+            # (function_name, do_we_need_to_specify_magnitude)
+            ("ShearX", True),
+            ("ShearY", True),
+            ("TranslateX", True),
+            ("TranslateY", True),
+            ("Rotate", True),
+            ("Brightness", True),
+            ("Color", True),
+            ("Contrast", True),
+            ("Sharpness", True),
+            ("Posterize", True),
+            ("Solarize", True),
+            ("AutoContrast", False),
+            ("Equalize", False),
+            ("Invert", False),
+        ]
 
 class Learner(nn.Module):
-    def __init__(self, num_transforms = 3):
+    def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
+        self.fun_num = fun_num
+        self.p_bins = p_bins 
+        self.m_bins = m_bins 
+        self.sub_num_pol = sub_num_pol
+
         super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.relu1 = nn.ReLU()
@@ -37,10 +73,9 @@ class Learner(nn.Module):
         self.relu3 = nn.ReLU()
         self.fc2 = nn.Linear(120, 84)
         self.relu4 = nn.ReLU()
-        self.fc3 = nn.Linear(84, num_transforms + 21)
-        # self.sig = nn.Sigmoid()
-# Currently using discrete outputs for the probabilities 
+        self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins))
 
+# Currently using discrete outputs for the probabilities 
 
     def forward(self, x):
         y = self.conv1(x)
@@ -58,13 +93,6 @@ class Learner(nn.Module):
 
         return y
 
-    def get_idx(self, x):
-        y = self.forward(x)
-        idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0))
-        p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0))
-        return (idx_ret, p_ret)
-
-        # return (torch.argmax(y[0:3]), y[torch.argmax(y[3:])])
 
 class LeNet(nn.Module):
     def __init__(self):
@@ -100,53 +128,36 @@ class LeNet(nn.Module):
 
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
-def train_model(transform_idx, p, child_network):
-    """
-    Takes in the specific transformation index and probability 
-    """
+# def train_model(full_policy, child_network):
+#     """
+#     Takes in the specific transformation index and probability 
+#     """
 
-    if transform_idx == 0:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomVerticalFlip(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
-    elif transform_idx == 1:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomHorizontalFlip(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
-    else:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomGrayscale(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
+#     # transformation = generate_policy(5, ps, mags)
 
-    batch_size = 32
-    n_samples = 0.005
+#     train_transform = transforms.Compose([
+#                                             full_policy,
+#                                             transforms.ToTensor()
+#                                         ])
 
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
-    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
+#     batch_size = 32
+#     n_samples = 0.005
 
-    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
+#     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
+#     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
+#     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
 
-    # child_network = child_networks.lenet()
 
-    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
-    cost = nn.CrossEntropyLoss()
-    epoch = 20
+#     sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+#     cost = nn.CrossEntropyLoss()
+#     epoch = 20
 
 
-    best_acc = train_child_network(child_network, train_loader, test_loader,
-                                     sgd, cost, max_epochs=100, print_every_epoch=False)
+#     best_acc = train_child_network(child_network, train_loader, test_loader,
+#                                      sgd, cost, max_epochs=100, print_every_epoch=False)
 
-    return best_acc
+#     return best_acc
 
 
 
@@ -171,74 +182,760 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 
 
+
 class Evolutionary_learner():
 
-    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None):
-        self.meta_rl_agent = network
+    def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
+        self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol)
         self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
         self.num_generations = num_generations
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
         self.train_loader = train_loader
-        self.backup_model = sec_model
+        self.child_network = child_network
+        self.p_bins = p_bins 
+        self.sub_num_pol = sub_num_pol
+        self.mag_bins = mag_bins
+        self.fun_num = fun_num
+        self.augmentation_space = augmentation_space
+        self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
         self.set_up_instance()
     
 
+    def generate_policy(self, sp_num, ps, mags):
+        """
+        
+        """
+        policies = []
+        for subpol in range(sp_num):
+            sub = []
+            for idx in range(2):
+                transformation = augmentation_space[(2*subpol) + idx]
+                p = ps[(2*subpol) + idx]
+                mag = mags[(2*subpol) + idx]
+                sub.append((transformation, p, mag))
+            policies.append(tuple(sub))
+        
+        return policies
+
+
+    def get_full_policy(self, x):
+        """
+        Generates the full policy (5 x 2 subpolicies)
+        """
+        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        y = self.auto_aug_agent.forward(x)
+        full_policy = []
+        for pol in range(self.sub_num_pol):
+            int_pol = []
+            for _ in range(2):
+                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
+
+                trans, need_mag = self.augmentation_space[idx_ret]
+
+                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                int_pol.append((trans, p_ret, mag))
+
+            full_policy.append(tuple(int_pol))
+
+        return full_policy
+# 
+    
+    def get_policy_cov(self, x, alpha = 0.5):
+        """
+        Need p_bins = 1, num_sub_pol = 1, mag_bins = 1
+        """
+        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+
+        y = self.auto_aug_agent.forward(x) # 1000 x 32
+
+        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
+        y[:,:self.auto_aug_agent.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
+        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        concat = torch.cat((y_1, y_2), dim = 1)
+
+        cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        shape_store = cov_mat.shape
+
+        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
+
+
+        prob_mat = torch.zeros(shape_store)
+        for idx in range(y.shape[0]):
+            prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1
+        prob_mat = prob_mat / torch.sum(prob_mat)
+
+        cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat)
+
+        cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()
+        max_idx = torch.argmax(cov_mat)
+        val = (max_idx//shape_store[0])
+        max_idx = (val, max_idx - (val * shape_store[0]))
+
+
+        if not self.augmentation_space[max_idx[0]][1]:
+            mag1 = None
+        if not self.augmentation_space[max_idx[1]][1]:
+            mag2 = None
+   
+        for idx in range(y.shape[0]):
+            if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
+                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                if mag1 is not None:
+                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                if mag2 is not None:
+                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                counter += 1
+
+        prob1 = prob1/counter if counter != 0 else 0
+        prob2 = prob2/counter if counter != 0 else 0
+        if mag1 is not None:
+            mag1 = mag1/counter 
+        if mag2 is not None:
+            mag2 = mag2/counter    
+
+        
+        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+
+
+        
+
+
+
     def run_instance(self, return_weights = False):
+        """
+        Runs the GA instance and returns the model weights as a dictionary
+        """
         self.ga_instance.run()
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
-            return torchga.model_weights_as_dict(model=self.meta_rl_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx
 
+
     def new_model(self):
-        copy_model = copy.deepcopy(self.backup_model)
+        """
+        Simple function to create a copy of the secondary model (used for classification)
+        """
+        copy_model = copy.deepcopy(self.child_network)
         return copy_model
 
 
     def set_up_instance(self):
+
         def fitness_func(solution, sol_idx):
             """
             Defines fitness function (accuracy of the model)
             """
-            model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent,
+            print("FITNESS HERE")
+
+            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
                                                             weights_vector=solution)
-            self.meta_rl_agent.load_state_dict(model_weights_dict)
+
+            self.auto_aug_agent.load_state_dict(model_weights_dict)
+
             for idx, (test_x, label_x) in enumerate(train_loader):
-                trans_idx, p = self.meta_rl_agent.get_idx(test_x)
-            cop_mod = self.new_model()
-            fit_val = train_model(trans_idx, p, cop_mod)
-            cop_mod = 0
+                full_policy = self.get_policy_cov(test_x)
+            print("FULL POLICY: ", full_policy)
+
+
+            fit_val = (test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) #+ test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
+
+            print("DONE FITNESS")
+
             return fit_val
 
         def on_generation(ga_instance):
             """
             Just prints stuff while running
             """
-            print("Generation = {generation}".format(generation=self.ga_instance.generations_completed))
-            print("Fitness    = {fitness}".format(fitness=self.ga_instance.best_solution()[1]))
+            print("Generation = {generation}".format(generation=ga_instance.generations_completed))
+            print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
             return
 
 
         self.ga_instance = pygad.GA(num_generations=self.num_generations, 
             num_parents_mating=self.num_parents_mating, 
             initial_population=self.initial_population,
+            mutation_percent_genes = 0.1,
             fitness_func=fitness_func,
             on_generation = on_generation)
 
 
-meta_rl_agent = Learner()
-ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet())
+
+
+
+
+
+
+
+
+
+# HEREHEREHERE0
+
+def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
+    # shuffle and take first n_samples %age of training dataset
+    shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset))
+    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
+    
+    indices_train = torch.arange(int(n_samples*len(train_dataset)))
+    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
+    
+    # shuffle and take first n_samples %age of test dataset
+    shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
+    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
+
+    big = 4 # how much bigger is the test set
+
+    indices_test = torch.arange(int(n_samples*len(test_dataset)*big))
+    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
+
+    # push into DataLoader
+    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
+
+    return train_loader, test_loader
+
+
+def train_child_network(child_network, train_loader, test_loader, sgd,
+                         cost, max_epochs=2000, early_stop_num = 5, logging=False,
+                         print_every_epoch=True):
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    child_network = child_network.to(device=device)
+    
+    best_acc=0
+    early_stop_cnt = 0
+    
+    # logging accuracy for plotting
+    acc_log = [] 
+
+    # train child_network and check validation accuracy each epoch
+    for _epoch in range(max_epochs):
+
+        # train child_network
+        child_network.train()
+        for idx, (train_x, train_label) in enumerate(train_loader):
+            # onto device
+            train_x = train_x.to(device=device, dtype=train_x.dtype)
+            train_label = train_label.to(device=device, dtype=train_label.dtype)
+
+            # label_np = np.zeros((train_label.shape[0], 10))
+
+            sgd.zero_grad()
+            predict_y = child_network(train_x.float())
+            loss = cost(predict_y, train_label.long())
+            loss.backward()
+            sgd.step()
+
+        # check validation accuracy on validation set
+        correct = 0
+        _sum = 0
+        child_network.eval()
+        with torch.no_grad():
+            for idx, (test_x, test_label) in enumerate(test_loader):
+                # onto device
+                test_x = test_x.to(device=device, dtype=test_x.dtype)
+                test_label = test_label.to(device=device, dtype=test_label.dtype)
+
+                predict_y = child_network(test_x.float()).detach()
+                predict_ys = torch.argmax(predict_y, axis=-1)
+
+                # label_np = test_label.numpy()
+
+                _ = predict_ys == test_label
+                correct += torch.sum(_, axis=-1)
+                # correct += torch.sum(_.numpy(), axis=-1)
+                _sum += _.shape[0]
+        
+        # update best validation accuracy if it was higher, otherwise increase early stop count
+        acc = correct / _sum
+
+        if acc > best_acc :
+            best_acc = acc
+            early_stop_cnt = 0
+        else:
+            early_stop_cnt += 1
+
+        # exit if validation gets worse over 10 runs
+        if early_stop_cnt >= early_stop_num:
+            print('main.train_child_network best accuracy: ', best_acc)
+            break
+        
+        # if print_every_epoch:
+            # print('main.train_child_network best accuracy: ', best_acc)
+        acc_log.append(acc)
+
+    if logging:
+        return best_acc.item(), acc_log
+    return best_acc.item()
+
+def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
+
+    aa_transform = AutoAugment()
+    aa_transform.subpolicies = subpolicies
+
+    train_transform = transforms.Compose([
+                                            aa_transform,
+                                            transforms.ToTensor()
+                                        ])
+
+    train_dataset.transform = train_transform
+
+    # create toy dataset from above uploaded data
+    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
+
+    child_network = LeNet()
+    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+    cost = nn.CrossEntropyLoss()
+
+    best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
+                                                sgd, cost, max_epochs=100, logging=True)
+
+    return best_acc, acc_log
+
+
+
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
+
+
+def _apply_op(img: Tensor, op_name: str, magnitude: float,
+              interpolation: InterpolationMode, fill: Optional[List[float]]):
+    if op_name == "ShearX":
+        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
+                       interpolation=interpolation, fill=fill)
+    elif op_name == "ShearY":
+        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
+                       interpolation=interpolation, fill=fill)
+    elif op_name == "TranslateX":
+        img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
+                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+    elif op_name == "TranslateY":
+        img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
+                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+    elif op_name == "Rotate":
+        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
+    elif op_name == "Brightness":
+        img = F.adjust_brightness(img, 1.0 + magnitude)
+    elif op_name == "Color":
+        img = F.adjust_saturation(img, 1.0 + magnitude)
+    elif op_name == "Contrast":
+        img = F.adjust_contrast(img, 1.0 + magnitude)
+    elif op_name == "Sharpness":
+        img = F.adjust_sharpness(img, 1.0 + magnitude)
+    elif op_name == "Posterize":
+        img = F.posterize(img, int(magnitude))
+    elif op_name == "Solarize":
+        img = F.solarize(img, magnitude)
+    elif op_name == "AutoContrast":
+        img = F.autocontrast(img)
+    elif op_name == "Equalize":
+        img = F.equalize(img)
+    elif op_name == "Invert":
+        img = F.invert(img)
+    elif op_name == "Identity":
+        pass
+    else:
+        raise ValueError("The provided operator {} is not recognized.".format(op_name))
+    return img
+
+
+class AutoAugmentPolicy(Enum):
+    """AutoAugment policies learned on different datasets.
+    Available policies are IMAGENET, CIFAR10 and SVHN.
+    """
+    IMAGENET = "imagenet"
+    CIFAR10 = "cifar10"
+    SVHN = "svhn"
+
+
+# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
+class AutoAugment(torch.nn.Module):
+    r"""AutoAugment data augmentation method based on
+    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        policy (AutoAugmentPolicy): Desired policy enum defined by
+            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
+
+    def __init__(
+        self,
+        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None
+    ) -> None:
+        super().__init__()
+        self.policy = policy
+        self.interpolation = interpolation
+        self.fill = fill
+        self.subpolicies = self._get_subpolicies(policy)
+
+    def _get_subpolicies(
+        self,
+        policy: AutoAugmentPolicy
+    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
+        if policy == AutoAugmentPolicy.IMAGENET:
+            return [
+                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
+                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
+                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
+                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
+                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
+                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
+                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
+                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
+                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
+                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
+                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
+                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
+                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
+                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
+                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
+                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
+                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
+                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
+                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
+                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
+                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
+                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
+                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
+                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
+                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
+            ]
+        elif policy == AutoAugmentPolicy.CIFAR10:
+            return [
+                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
+                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
+                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
+                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
+                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
+                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
+                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
+                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
+                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
+                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
+                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
+                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
+                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
+                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
+                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
+                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
+                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
+                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
+                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
+                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
+                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
+                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
+                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
+                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
+            ]
+        elif policy == AutoAugmentPolicy.SVHN:
+            return [
+                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
+                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
+                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
+                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
+                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
+                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
+                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
+                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
+                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
+                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
+                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
+                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
+                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
+                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
+                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
+                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
+                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
+                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
+                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
+                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
+                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
+                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
+                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
+                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
+                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
+            ]
+        else:
+            raise ValueError("The provided policy {} is not recognized.".format(policy))
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+            "Invert": (torch.tensor(0.0), False),
+        }
+
+    @staticmethod
+    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
+        """Get parameters for autoaugment transformation
+
+        Returns:
+            params required by the autoaugment transformation
+        """
+        policy_id = int(torch.randint(transform_num, (1,)).item())
+        probs = torch.rand((2,))
+        signs = torch.randint(2, (2,))
+
+        return policy_id, probs, signs
+
+    def forward(self, img: Tensor, dis_mag = True) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: AutoAugmented image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        transform_id, probs, signs = self.get_params(len(self.subpolicies))
+        # print("transform_id, probs, signs : ", transform_id, probs, signs )
+
+        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies[transform_id]):
+        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies):
+        #     print("op_name, p, magnitude_id: ", op_name, p, magnitude_id)
+        #     if probs[i] <= p:
+        #         op_meta = self._augmentation_space(10, F.get_image_size(img))
+        #         magnitudes, signed = op_meta[op_name]
+        #         magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
+        #         if signed and signs[i] == 0:
+        #             magnitude *= -1.0
+        #         img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        for i, (op_name, p, magnitude) in enumerate(self.subpolicies):
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+
+        return img
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
+
+
+class RandAugment(torch.nn.Module):
+    r"""RandAugment data augmentation method based on
+    `"RandAugment: Practical automated data augmentation with a reduced search space"
+    <https://arxiv.org/abs/1909.13719>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_ops (int): Number of augmentation transformations to apply sequentially.
+        magnitude (int): Magnitude for all the transformations.
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+        """
+
+    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
+                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_ops = num_ops
+        self.magnitude = magnitude
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "Identity": (torch.tensor(0.0), False),
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        for _ in range(self.num_ops):
+            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+            op_index = int(torch.randint(len(op_meta), (1,)).item())
+            op_name = list(op_meta.keys())[op_index]
+            magnitudes, signed = op_meta[op_name]
+            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
+            if signed and torch.randint(2, (1,)):
+                magnitude *= -1.0
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        return img
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_ops={num_ops}'
+        s += ', magnitude={magnitude}'
+        s += ', num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class TrivialAugmentWide(torch.nn.Module):
+    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
+    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+        """
+
+    def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "Identity": (torch.tensor(0.0), False),
+            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
+            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        op_meta = self._augmentation_space(self.num_magnitude_bins)
+        op_index = int(torch.randint(len(op_meta), (1,)).item())
+        op_name = list(op_meta.keys())[op_index]
+        magnitudes, signed = op_meta[op_name]
+        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
+            if magnitudes.ndim > 0 else 0.0
+        if signed and torch.randint(2, (1,)):
+            magnitude *= -1.0
+
+        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+# HEREHEREHEREHERE1
+
+
+
+
+
+
+
+
+train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
+                            transform=None)
+test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
+                            transform=torchvision.transforms.ToTensor())
+
+
+auto_aug_agent = Learner()
+ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1, train_dataset=train_dataset, test_dataset=test_dataset)
 ev_learner.run_instance()
 
 
 solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution()
-print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
-print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))
+
+print(f"Best solution : {solution}")
+print(f"Fitness value of the best solution = {solution_fitness}")
+print(f"Index of the best solution : {solution_idx}")
 # Fetch the parameters of the best solution.
-best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.meta_rl_agent,
+best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent,
                                                       weights_vector=solution)
\ No newline at end of file
diff --git a/MetaAugment/GA_results.png b/MetaAugment/GA_results.png
new file mode 100644
index 0000000000000000000000000000000000000000..62449415b64500804927328ca677c4c023085436
Binary files /dev/null and b/MetaAugment/GA_results.png differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/__pycache__/main.cpython-38.pyc b/MetaAugment/__pycache__/main.cpython-38.pyc
deleted file mode 100644
index 5dcce355ebb82a6ff165d8bf473700ee3f54eae9..0000000000000000000000000000000000000000
Binary files a/MetaAugment/__pycache__/main.cpython-38.pyc and /dev/null differ
diff --git a/MetaAugment/autoaugment_learners/__init__.py b/MetaAugment/autoaugment_learners/__init__.py
index 1c7b4de8c7c374fe4eee3502f7086dceb22a7d9b..149af5b8165316ae1e00e6fc1c476338324b73cc 100644
--- a/MetaAugment/autoaugment_learners/__init__.py
+++ b/MetaAugment/autoaugment_learners/__init__.py
@@ -1 +1,3 @@
-from .randomsearch_learner import *
\ No newline at end of file
+from .aa_learner import *
+from .randomsearch_learner import *
+from .gru_learner import *
diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 0215eea6a7ebb714133e9d1888fd9045947aa3e1..1de37eef1fe0fddcff4bc7eb5d4b5d7d6eaa3f4b 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -1,20 +1,19 @@
-# The parent class for all other autoaugment learners``
+# The parent class for all other autoaugment learners
 
 import torch
-import numpy as np
-from MetaAugment.main import *
-import MetaAugment.child_networks as cn
-import torchvision.transforms as transforms
-from MetaAugment.autoaugment_learners.autoaugment import *
+import torch.nn as nn
+import torch.optim as optim
+from MetaAugment.main import train_child_network, create_toy
+from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 
-import torchvision.transforms.autoaugment as torchaa
-from torchvision.transforms import functional as F, InterpolationMode
+import torchvision.transforms as transforms
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+
 
 # We will use this augmentation_space temporarily. Later on we will need to 
 # make sure we are able to add other image functions if the users want.
-num_bins = 10
 augmentation_space = [
             # (function_name, do_we_need_to_specify_magnitude)
             ("ShearX", True),
@@ -34,11 +33,9 @@ augmentation_space = [
         ]
 
 
-# TODO: Right now the aa_learner is identical to randomsearch_learner. Change
-# this so that it can act as a superclass to all other augment learners
 class aa_learner:
     def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
-        '''
+        """
         Args:
             spdim (int): number of subpolicies per policy
             fun_num (int): number of image functions in our search space
@@ -49,12 +46,15 @@ class aa_learner:
                                     magnitude as discrete variables as the out put of the 
                                     controller (A controller can be a neural network, genetic
                                     algorithm, etc.)
-        '''
+
+        """
         self.sp_num = sp_num
         self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
 
+        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
+
         # should we repre
         self.discrete_p_m = discrete_p_m
 
@@ -62,8 +62,8 @@ class aa_learner:
         self.history = []
 
 
-    def translate_operation_tensor(self, operation_tensor):
-        '''
+    def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
+        """
         takes in a tensor representing an operation and returns an actual operation which
         is in the form of:
             ("Invert", 0.8, None)
@@ -72,69 +72,167 @@ class aa_learner:
 
         Args:
             operation_tensor (tensor): 
-                                - If discrete_p_m is True, we expect to take in a tensor with
+                                We expect this tensor to already have been softmaxed.
+                                Furthermore,
+                                - If self.discrete_p_m is True, we expect to take in a tensor with
                                 dimension (self.fun_num + self.p_bins + self.m_bins)
-                                - If discrete_p_m is False, we expect to take in a tensor with
+                                - If self.discrete_p_m is False, we expect to take in a tensor with
                                 dimension (self.fun_num + 1 + 1)
-            continuous_p_m (boolean): whether the operation_tensor has continuous representations
-                                    of probability and magnitude
-        '''
+
+            return_log_prob (boolesn): 
+                                When this is on, we return which indices (of fun, prob, mag) were
+                                chosen (either randomly or deterministically, depending on argmax).
+                                This is used, for example, in the gru_learner to calculate the
+                                probability of the actions were chosen, which is then logged, then
+                                differentiated.
+
+            argmax (boolean): 
+                            Whether we are taking the argmax of the softmaxed tensors. 
+                            If this is False, we treat the softmaxed outputs as multinomial pdf's.
+
+        Returns:
+            operation (list of tuples):
+                                An operation in the format that can be directly put into an
+                                AutoAugment object.
+            log_prob (float):
+                            Used in reinforcement learning updates, such as proximal policy update
+                            in the gru_learner.
+                            Can only be used when self.discrete_p_m.
+                            We add the logged values of the indices of the image_function,
+                            probability, and magnitude chosen.
+                            This corresponds to multiplying the non-logged values, then logging
+                            it.                  
+        """
+
+        if (not self.discrete_p_m) and return_log_prob:
+            raise ValueError("You are not supposed to use return_log_prob=True when the agent's \
+                            self.discrete_p_m is False!")
+
+        # make sure shape is correct
+        assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
+
         # if probability and magnitude are represented as discrete variables
         if self.discrete_p_m:
-            fun_t = operation_tensor[ : self.fun_num]
-            prob_t = operation_tensor[self.fun_num : self.fun_num+self.p_bins]
-            mag_t = operation_tensor[-self.m_bins : ]
+            fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])
+
+            # make sure they are of right size
+            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
+            assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
+            assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
 
-            fun = torch.argmax(fun_t)
-            prob = torch.argmax(prob_t) # 0 <= p <= 10
-            mag = torch.argmax(mag_t) # 0 <= m <= 9
 
-            function = augmentation_space[fun][0]
-            prob = prob/10
+            if argmax==True:
+                fun_idx = torch.argmax(fun_t).item()
+                prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
+                mag = torch.argmax(mag_t).item() # 0 <= m <= 9
+            elif argmax==False:
+                # we need these to add up to 1 to be valid pdf's of multinomials
+                assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
+                assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
+                assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
+
+                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
+                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
+                mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
+
+            function = augmentation_space[fun_idx][0]
+            prob = prob_idx/10
+
+            indices = (fun_idx, prob_idx, mag)
+
+            # log probability is the sum of the log of the softmax values of the indices 
+            # (of fun_t, prob_t, mag_t) that we have chosen
+            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])
 
 
         # if probability and magnitude are represented as continuous variables
         else:
-            fun_t = operation_tensor[:self.fun_num]
-            p = operation_tensor[-2].item() # 0 < p < 1
-            m = operation_tensor[-1].item() # 0 < m < 9
-
-            fun = torch.argmax(fun_t)
+            fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
+            prob = prob.item()
+            # 0 =< prob =< 1
+            mag = mag.item()
+            # 0 =< mag =< 9
 
-            function = augmentation_space[fun][0]
-            prob = round(p, 1) # round to nearest first decimal digit
-            mag = round(m) # round to nearest integer
+            # make sure the shape is correct
+            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
+            
+            if argmax==True:
+                fun_idx = torch.argmax(fun_t)
+            elif argmax==False:
+                assert torch.sum(fun_t).isclose(torch.ones(1))
+                fun_idx = torch.multinomial(fun_t, 1).item()
+            prob = round(prob, 1) # round to nearest first decimal digit
+            mag = round(mag) # round to nearest integer
+            
+        function = augmentation_space[fun_idx][0]
 
+        assert 0 <= prob <= 1
+        assert 0 <= mag <= self.m_bins-1
+        
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun][0] == True: # if the image function has a magnitude
-            return (function, prob, mag)
+        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
+            operation = (function, prob, mag)
         else:
-            return (function, prob, None)
-
+            operation =  (function, prob, None)
+        
+        if return_log_prob:
+            return operation, log_prob
+        else:
+            return operation
+        
 
     def generate_new_policy(self):
-        '''
-        Generate a new random policy in the form of
-            [
-            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
-            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
-            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
-            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
-            ]
-        '''
+        """
+        Generate a new policy which can be fed into an AutoAugment object 
+        by calling:
+            AutoAugment.subpolicies = policy
+        
+        Args:
+            none
+        
+        Returns:
+            new_policy (list[tuple]):
+                        A new policy generated by the controller. It
+                        has the form of:
+                            [
+                            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
+                            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
+                            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+                            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
+                            ]
+                        This object can be fed into an AutoAUgment object
+                        by calling: AutoAugment.subpolicies = policy
+        """
+
         raise NotImplementedError('generate_new_policy not implemented in aa_learner')
 
 
     def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
-        '''
-        Does the loop which is seen in Figure 1 in the AutoAugment paper.
-        In other words, repeat:
+        """
+        Runs the main loop (of finding a good policy for the given child network,
+        training dataset, and test(validation) dataset)
+
+        Does the loop which is seen in Figure 1 in the AutoAugment paper
+        which is:
             1. <generate a random policy>
             2. <see how good that policy is>
             3. <save how good the policy is in a list/dictionary>
-        until a certain condition (either specified by the user or pre-specified) is met
-        '''
+        
+        Args:
+            train_dataset (torchvision.dataset.vision.VisionDataset)
+            test_dataset (torchvision.dataset.vision.VisionDataset)
+            child_network_architecture (type): NOTE THAT THIS VARIABLE IS NOT
+                                    A nn.module object. Therefore, this needs
+                                    to be, say, `models.LeNet` instead of 
+                                    `models.LeNet()`.
+            toy_flag (boolean): whether we want to obtain a toy version of 
+                            train_dataset and test_dataset and use those.
+
+        Returns:
+            none
+        """
 
+        # This is dummy code
         # test out 15 random policies
         for _ in range(15):
             policy = self.generate_new_policy()
@@ -147,12 +245,26 @@ class aa_learner:
             self.history.append((policy, reward))
     
 
-    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, toy_flag):
-        '''
+    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, 
+                                toy_flag, logging=False):
+        """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
         child network).
-        '''
+
+        Args: 
+            policy (list[tuple]): A list of tuples representing a policy.
+            child_network (nn.module)
+            train_dataset (torchvision.dataset.vision.VisionDataset)
+            test_dataset (torchvision.dataset.vision.VisionDataset)
+            toy_flag (boolean): Whether we want to obtain a toy version of 
+                            train_dataset and test_dataset and use those.
+            logging (boolean): Whether we want to save logs
+        
+        Returns:
+            accuracy (float): best accuracy reached in any
+        """
+
         # We need to define an object aa_transform which takes in the image and 
         # transforms it with the policy (specified in its .policies attribute)
         # in its forward pass
@@ -170,16 +282,55 @@ class aa_learner:
         train_loader, test_loader = create_toy(train_dataset,
                                                 test_dataset,
                                                 batch_size=32,
-                                                n_samples=0.01,
+                                                n_samples=0.5,
                                                 seed=100)
-
+        
         # train the child network with the dataloaders equipped with our specific policy
         accuracy = train_child_network(child_network, 
                                     train_loader, 
                                     test_loader, 
-                                    sgd = optim.SGD(child_network.parameters(), lr=1e-1),
+                                    sgd = optim.SGD(child_network.parameters(), lr=3e-1),
+                                    # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                     cost = nn.CrossEntropyLoss(),
-                                    max_epochs = 100, 
+                                    max_epochs = 3000000, 
                                     early_stop_num = 15, 
-                                    logging = False)
-        return accuracy
\ No newline at end of file
+                                    logging = logging,
+                                    print_every_epoch=True)
+        
+        # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
+        return accuracy
+    
+
+    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=5):
+        """
+        I made this to plot a couple of accuracy graphs to help manually tune my gradient 
+        optimizer hyperparameters.
+
+        Saves a plot of `n` training accuracy graphs overlapped.
+        """
+        
+        acc_lists = []
+
+        # This is dummy code
+        # test out `n` random policies
+        for _ in range(n):
+            policy = self.generate_new_policy()
+
+            pprint(policy)
+            child_network = child_network_architecture()
+            reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                test_dataset, toy_flag, logging=True)
+
+            self.history.append((policy, reward))
+            acc_lists.append(acc_list)
+
+        for acc_list in acc_lists:
+            plt.plot(acc_list)
+        plt.title('I ran 5 random policies to see if there is any sign of \
+                    catastrophic failure during training. If there are \
+                    any lines which reach significantly lower (>10%) \
+                    accuracies, you might want to tune the hyperparameters')
+        plt.xlabel('epoch')
+        plt.ylabel('accuracy')
+        plt.show()
+        plt.savefig('training_graphs_without_policies')
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/ac_learner.py b/MetaAugment/autoaugment_learners/ac_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..30c0b6cd9e404f4699f409e5e9c07838ce6bfbb2
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/ac_learner.py
@@ -0,0 +1,357 @@
+# %%
+import numpy as np
+import matplotlib.pyplot as plt 
+from itertools import count
+
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from torch.distributions import Categorical
+from torch.utils.data import TensorDataset, DataLoader
+
+
+from collections import namedtuple, deque
+import math
+import random
+
+from MetaAugment.main import *
+
+
+batch_size = 128
+
+test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
+train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
+test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+print('test_loader', len(test_loader))
+print('train_loader',len(train_loader))
+
+def create_toy(train_dataset, test_dataset, batch_size, n_samples):
+    # shuffle and take first n_samples %age of training dataset
+    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
+    indices_train = torch.arange(int(n_samples*len(train_dataset)))
+    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
+    # shuffle and take first n_samples %age of test dataset
+    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())
+    indices_test = torch.arange(int(n_samples*len(test_dataset)))
+    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
+
+    # push into DataLoader
+    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
+
+    return train_loader, test_loader
+
+# train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 10)
+
+
+class LeNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
+        self.fc1 = nn.Linear(256, 120)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, 10)
+        self.relu5 = nn.ReLU()
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.relu1(y)
+        y = self.pool1(y)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.pool2(y)
+        y = y.view(y.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu3(y)
+        y = self.fc2(y)
+        y = self.relu4(y)
+        y = self.fc3(y)
+        y = self.relu5(y)
+        return y
+
+# %% [markdown]
+# ## collect reward
+
+# %%
+
+def collect_reward(train_loader, test_loader, max_epochs=100, early_stop_num=10):
+    child_network = LeNet() 
+    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+    cost = nn.CrossEntropyLoss()
+    best_acc=0
+    early_stop_cnt = 0
+    
+    # train child_network and check validation accuracy each epoch
+    print('max_epochs', max_epochs)
+    for _epoch in range(max_epochs):
+        print('_epoch', _epoch)
+        # train child_network
+        child_network.train()
+        for t, (train_x, train_label) in enumerate(train_loader):
+            label_np = np.zeros((train_label.shape[0], 10))
+            sgd.zero_grad()
+            predict_y = child_network(train_x.float())
+            loss = cost(predict_y, train_label.long())
+            loss.backward()
+            sgd.step()
+
+        # check validation accuracy on validation set
+        correct = 0
+        _sum = 0
+        child_network.eval()
+        for idx, (test_x, test_label) in enumerate(test_loader):
+            predict_y = child_network(test_x.float()).detach()
+            predict_ys = np.argmax(predict_y, axis=-1)
+            label_np = test_label.numpy()
+            _ = predict_ys == test_label
+            correct += np.sum(_.numpy(), axis=-1)
+            _sum += _.shape[0]
+        
+        # update best validation accuracy if it was higher, otherwise increase early stop count
+        acc = correct / _sum
+
+        if acc > best_acc :
+            best_acc = acc
+            early_stop_cnt = 0
+        else:
+            early_stop_cnt += 1
+
+        # exit if validation gets worse over 10 runs
+        if early_stop_cnt >= early_stop_num:
+            break
+
+        # if _epoch%30 == 0:
+        #     print('child_network accuracy: ', best_acc)
+        
+    return best_acc
+
+
+# %%
+for t, (train_x, train_label) in enumerate(test_loader):
+    print(train_x.shape)
+    print(train_label)
+    break
+len(test_loader)
+
+# %%
+collect_reward(train_loader, test_loader)
+
+
+# %% [markdown]
+# ## Policy network
+
+# %%
+class Policy(nn.Module):
+    """
+    implements both actor and critic in one model
+    """
+    def __init__(self):
+        super(Policy, self).__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5 , stride=2)
+        self.conv2 = nn.Conv2d(6, 12, 5, stride=2)
+        self.maxpool = nn.MaxPool2d(4)
+
+        # actor's layer
+        self.action_head = nn.Linear(12, 2)
+
+        # critic's layer
+        self.value_head = nn.Linear(12, 1)
+
+        # action & reward buffer
+        self.saved_actions = []
+        self.rewards = []
+
+    def forward(self, x):
+        """
+        forward of both actor and critic
+        """
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = self.maxpool(x)
+        x = x.view(x.size(0), -1)
+        # print('x', x.shape)
+
+        # actor: choses action to take from state s_t 
+        # by returning probability of each action
+        # print('self.action_head(x)', self.action_head(x).shape)
+        action_prob = F.softmax(self.action_head(x), dim=-1)
+        # print('action_prob', action_prob.shape)
+
+        # critic: evaluates being in the state s_t
+        state_values = self.value_head(x)
+
+        # return values for both actor and critic as a tuple of 2 values:
+        # 1. a list with the probability of each action over the action space
+        # 2. the value from state s_t 
+        return action_prob, state_values
+
+
+# %%
+test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
+train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
+test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+policy_model = Policy()
+# for t, (x, y) in enumerate(train_loader):
+#     # print(x.shape)
+#     policy_model(x)
+
+# %% [markdown]
+# ## select action
+
+# %%
+SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
+
+def select_action(train_loader, policy_model):
+    probs_list = []
+    value_list = []
+    for t, (x, y) in enumerate(train_loader):
+        probs_i, state_value_i = policy_model(x)
+        probs_list += [probs_i]
+        value_list += [state_value_i]
+
+    probs = torch.mean(torch.cat(probs_list), dim=0)
+    state_value = torch.mean(torch.cat(value_list))
+    # print('probs_i', probs_i)
+    # print('probs', probs)
+    # create a categorical distribution over the list of probabilities of actions
+    m = Categorical(probs)
+    # print('m', m)
+    # and sample an action using the distribution
+    action = m.sample()
+    # print('action', action)
+
+    # save to action buffer
+    policy_model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
+
+    # the action to take (left or right)
+    return action.item()
+
+
+# %%
+torch.tensor([1, 2, 3])
+
+# %% [markdown]
+# ## take action
+
+# %%
+def take_action(action_idx):
+    # Define actions (data augmentation policy) --- can be improved
+    action_list = [
+    torchvision.transforms.Compose([torchvision.transforms.RandomVerticalFlip(),
+        torchvision.transforms.ToTensor()]),
+    torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),
+        torchvision.transforms.ToTensor()]),
+    torchvision.transforms.Compose([torchvision.transforms.RandomGrayscale(),
+        torchvision.transforms.ToTensor()]),
+    torchvision.transforms.Compose([torchvision.transforms.RandomAffine(30),
+        torchvision.transforms.ToTensor()])]
+
+    # transform   
+    transform = action_list[action_idx]
+    test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=transform)
+    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=transform)
+    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples=0.0002)
+    return train_loader, test_loader
+
+
+# %% [markdown]
+# ## finish episode
+
+# %%
+policy_model = Policy()
+optimizer = optim.Adam(policy_model.parameters(), lr=3e-2)
+eps = np.finfo(np.float32).eps.item()
+gamma = 0.9
+def finish_episode():
+    """
+    Training code. Calculates actor and critic loss and performs backprop.
+    """
+    R = 0
+    saved_actions = policy_model.saved_actions
+    policy_losses = [] # list to save actor (policy) loss
+    value_losses = [] # list to save critic (value) loss
+    returns = [] # list to save the true values
+
+    # calculate the true value using rewards returned from the environment
+    for r in policy_model.rewards[::-1]:
+        # calculate the discounted value
+        R = r + gamma * R
+        returns.insert(0, R)
+
+    returns = torch.tensor(returns)
+    returns = (returns - returns.mean()) / (returns.std() + eps)
+
+    for (log_prob, value), R in zip(saved_actions, returns):
+        advantage = R - value.item()
+
+        # calculate actor (policy) loss 
+        policy_losses.append(-log_prob * advantage)
+
+        # calculate critic (value) loss using L1 smooth loss
+        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))
+
+    # reset gradients
+    optimizer.zero_grad()
+
+    # sum up all the values of policy_losses and value_losses
+    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
+
+    # perform backprop
+    loss.backward()
+    optimizer.step()
+
+    # reset rewards and action buffer
+    del policy_model.rewards[:]
+    del policy_model.saved_actions[:]
+
+# %% [markdown]
+# ## run
+
+# %%
+
+running_reward = 10
+episodes_num = 100
+policy_model = Policy()
+for i_episode in range(episodes_num) :
+    # initiate a new state
+    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
+    # train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
+    train_loader_state = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+    # select action from policy
+    action_idx = select_action(train_loader, policy_model)
+    print('>>> action_idx', action_idx)
+
+    # take the action -> apply data augmentation
+    train_loader, test_loader = take_action(action_idx)
+    reward = collect_reward(train_loader, test_loader)
+    print('>>> reward', reward)
+
+    # if args.render:
+    #     env.render()
+
+    policy_model.rewards.append(reward)
+
+    # perform backprop
+    finish_episode()
+
+    # # log result
+    if i_episode % 10 == 0:
+        print('Episode {}\tLast reward (val accuracy): {:.2f}'.format(i_episode, reward))
+
+# %%
+
+
+
diff --git a/MetaAugment/autoaugment_learners/actor_critic.ipynb b/MetaAugment/autoaugment_learners/actor_critic.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..26ab1bb761f4b66fc2d8e141e774d5952639929f
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/actor_critic.ipynb
@@ -0,0 +1,678 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt \n",
+    "from itertools import count\n",
+    "\n",
+    "import torch\n",
+    "import torch.optim as optim\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torchvision\n",
+    "from torch.distributions import Categorical\n",
+    "from torch.utils.data import TensorDataset, DataLoader\n",
+    "\n",
+    "\n",
+    "from collections import namedtuple, deque\n",
+    "import math\n",
+    "import random\n",
+    "\n",
+    "from MetaAugment.main import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'/Users/miawang/Library/CloudStorage/OneDrive-ImperialCollegeLondon/MSc AI/SE group project/actor critic RL'"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pwd"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "test_loader 79\n",
+      "train_loader 469\n"
+     ]
+    }
+   ],
+   "source": [
+    "## download data\n",
+    "batch_size = 128\n",
+    "\n",
+    "test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)\n",
+    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+    "print('test_loader', len(test_loader))\n",
+    "print('train_loader',len(train_loader))\n",
+    "\n",
+    "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
+    "    # shuffle and take first n_samples %age of training dataset\n",
+    "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())\n",
+    "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
+    "    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)\n",
+    "    # shuffle and take first n_samples %age of test dataset\n",
+    "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())\n",
+    "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
+    "    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)\n",
+    "\n",
+    "    # push into DataLoader\n",
+    "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
+    "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
+    "\n",
+    "    return train_loader, test_loader\n",
+    "\n",
+    "# train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 10)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### CNN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "class LeNet(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "        self.pool1 = nn.MaxPool2d(2)\n",
+    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+    "        self.relu2 = nn.ReLU()\n",
+    "        self.pool2 = nn.MaxPool2d(2)\n",
+    "        self.fc1 = nn.Linear(256, 120)\n",
+    "        self.relu3 = nn.ReLU()\n",
+    "        self.fc2 = nn.Linear(120, 84)\n",
+    "        self.relu4 = nn.ReLU()\n",
+    "        self.fc3 = nn.Linear(84, 10)\n",
+    "        self.relu5 = nn.ReLU()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = self.conv1(x)\n",
+    "        y = self.relu1(y)\n",
+    "        y = self.pool1(y)\n",
+    "        y = self.conv2(y)\n",
+    "        y = self.relu2(y)\n",
+    "        y = self.pool2(y)\n",
+    "        y = y.view(y.shape[0], -1)\n",
+    "        y = self.fc1(y)\n",
+    "        y = self.relu3(y)\n",
+    "        y = self.fc2(y)\n",
+    "        y = self.relu4(y)\n",
+    "        y = self.fc3(y)\n",
+    "        y = self.relu5(y)\n",
+    "        return y"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## collect reward"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def collect_reward(train_loader, test_loader, max_epochs=100, early_stop_num=10):\n",
+    "    child_network = LeNet() \n",
+    "    sgd = optim.SGD(child_network.parameters(), lr=1e-1)\n",
+    "    cost = nn.CrossEntropyLoss()\n",
+    "    best_acc=0\n",
+    "    early_stop_cnt = 0\n",
+    "    \n",
+    "    # train child_network and check validation accuracy each epoch\n",
+    "    print('max_epochs', max_epochs)\n",
+    "    for _epoch in range(max_epochs):\n",
+    "        print('_epoch', _epoch)\n",
+    "        # train child_network\n",
+    "        child_network.train()\n",
+    "        for t, (train_x, train_label) in enumerate(train_loader):\n",
+    "            label_np = np.zeros((train_label.shape[0], 10))\n",
+    "            sgd.zero_grad()\n",
+    "            predict_y = child_network(train_x.float())\n",
+    "            loss = cost(predict_y, train_label.long())\n",
+    "            loss.backward()\n",
+    "            sgd.step()\n",
+    "\n",
+    "        # check validation accuracy on validation set\n",
+    "        correct = 0\n",
+    "        _sum = 0\n",
+    "        child_network.eval()\n",
+    "        for idx, (test_x, test_label) in enumerate(test_loader):\n",
+    "            predict_y = child_network(test_x.float()).detach()\n",
+    "            predict_ys = np.argmax(predict_y, axis=-1)\n",
+    "            label_np = test_label.numpy()\n",
+    "            _ = predict_ys == test_label\n",
+    "            correct += np.sum(_.numpy(), axis=-1)\n",
+    "            _sum += _.shape[0]\n",
+    "        \n",
+    "        # update best validation accuracy if it was higher, otherwise increase early stop count\n",
+    "        acc = correct / _sum\n",
+    "\n",
+    "        if acc > best_acc :\n",
+    "            best_acc = acc\n",
+    "            early_stop_cnt = 0\n",
+    "        else:\n",
+    "            early_stop_cnt += 1\n",
+    "\n",
+    "        # exit if validation gets worse over 10 runs\n",
+    "        if early_stop_cnt >= early_stop_num:\n",
+    "            break\n",
+    "\n",
+    "        # if _epoch%30 == 0:\n",
+    "        #     print('child_network accuracy: ', best_acc)\n",
+    "        \n",
+    "    return best_acc\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "torch.Size([128, 1, 28, 28])\n",
+      "tensor([3, 2, 7, 6, 9, 4, 1, 1, 8, 8, 8, 7, 4, 8, 0, 2, 2, 7, 7, 9, 3, 9, 7, 0,\n",
+      "        1, 3, 6, 5, 9, 8, 0, 4, 4, 1, 0, 0, 3, 8, 1, 2, 5, 5, 2, 0, 9, 5, 7, 4,\n",
+      "        0, 5, 5, 2, 0, 6, 7, 2, 5, 5, 0, 1, 3, 2, 0, 1, 3, 4, 8, 1, 0, 9, 5, 7,\n",
+      "        8, 8, 8, 8, 3, 6, 8, 0, 9, 0, 6, 2, 0, 5, 3, 3, 0, 8, 7, 4, 1, 0, 6, 3,\n",
+      "        9, 5, 1, 7, 3, 0, 7, 0, 2, 2, 4, 7, 8, 6, 1, 6, 2, 7, 3, 9, 9, 9, 6, 3,\n",
+      "        9, 8, 6, 6, 3, 5, 7, 8])\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "79"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "for t, (train_x, train_label) in enumerate(test_loader):\n",
+    "    print(train_x.shape)\n",
+    "    print(train_label)\n",
+    "    break\n",
+    "len(test_loader)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "max_epochs 100\n",
+      "_epoch 0\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-29-646b050414f6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mcollect_reward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m<ipython-input-27-e6965b79120c>\u001b[0m in \u001b[0;36mcollect_reward\u001b[0;34m(train_loader, test_loader, max_epochs, early_stop_num)\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0;31m# train child_network\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m         \u001b[0mchild_network\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_label\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m             \u001b[0mlabel_np\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_label\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m             \u001b[0msgd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    519\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_sampler_iter\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    520\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_reset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 521\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    522\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_num_yielded\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_kind\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0m_DatasetKind\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mIterable\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m_next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    559\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_next_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    560\u001b[0m         \u001b[0mindex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 561\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dataset_fetcher\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    562\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    563\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36mfetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfetch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_collation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m             \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpossibly_batched_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m    132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 134\u001b[0;31m             \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    135\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    136\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtarget_transform\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/transforms/transforms.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, pic)\u001b[0m\n\u001b[1;32m     96\u001b[0m             \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mConverted\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     97\u001b[0m         \"\"\"\n\u001b[0;32m---> 98\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    100\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/transforms/functional.py\u001b[0m in \u001b[0;36mto_tensor\u001b[0;34m(pic)\u001b[0m\n\u001b[1;32m    146\u001b[0m     \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpic\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetbands\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    147\u001b[0m     \u001b[0;31m# put it from HWC to CHW format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m     \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontiguous\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    149\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mByteTensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    150\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdefault_float_dtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m255\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "collect_reward(train_loader, test_loader)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Policy network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Policy(nn.Module):\n",
+    "    \"\"\"\n",
+    "    implements both actor and critic in one model\n",
+    "    \"\"\"\n",
+    "    def __init__(self):\n",
+    "        super(Policy, self).__init__()\n",
+    "        self.conv1 = nn.Conv2d(1, 6, 5 , stride=2)\n",
+    "        self.conv2 = nn.Conv2d(6, 12, 5, stride=2)\n",
+    "        self.maxpool = nn.MaxPool2d(4)\n",
+    "\n",
+    "        # actor's layer\n",
+    "        self.action_head = nn.Linear(12, 2)\n",
+    "\n",
+    "        # critic's layer\n",
+    "        self.value_head = nn.Linear(12, 1)\n",
+    "\n",
+    "        # action & reward buffer\n",
+    "        self.saved_actions = []\n",
+    "        self.rewards = []\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"\n",
+    "        forward of both actor and critic\n",
+    "        \"\"\"\n",
+    "        x = F.relu(self.conv1(x))\n",
+    "        x = F.relu(self.conv2(x))\n",
+    "        x = self.maxpool(x)\n",
+    "        x = x.view(x.size(0), -1)\n",
+    "        # print('x', x.shape)\n",
+    "\n",
+    "        # actor: choses action to take from state s_t \n",
+    "        # by returning probability of each action\n",
+    "        # print('self.action_head(x)', self.action_head(x).shape)\n",
+    "        action_prob = F.softmax(self.action_head(x), dim=-1)\n",
+    "        # print('action_prob', action_prob.shape)\n",
+    "\n",
+    "        # critic: evaluates being in the state s_t\n",
+    "        state_values = self.value_head(x)\n",
+    "\n",
+    "        # return values for both actor and critic as a tuple of 2 values:\n",
+    "        # 1. a list with the probability of each action over the action space\n",
+    "        # 2. the value from state s_t \n",
+    "        return action_prob, state_values\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)\n",
+    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+    "\n",
+    "policy_model = Policy()\n",
+    "# for t, (x, y) in enumerate(train_loader):\n",
+    "#     # print(x.shape)\n",
+    "#     policy_model(x)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## select action"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])\n",
+    "\n",
+    "def select_action(train_loader, policy_model):\n",
+    "    probs_list = []\n",
+    "    value_list = []\n",
+    "    for t, (x, y) in enumerate(train_loader):\n",
+    "        probs_i, state_value_i = policy_model(x)\n",
+    "        probs_list += [probs_i]\n",
+    "        value_list += [state_value_i]\n",
+    "\n",
+    "    probs = torch.mean(torch.cat(probs_list), dim=0)\n",
+    "    state_value = torch.mean(torch.cat(value_list))\n",
+    "    # print('probs_i', probs_i)\n",
+    "    # print('probs', probs)\n",
+    "    # create a categorical distribution over the list of probabilities of actions\n",
+    "    m = Categorical(probs)\n",
+    "    # print('m', m)\n",
+    "    # and sample an action using the distribution\n",
+    "    action = m.sample()\n",
+    "    # print('action', action)\n",
+    "\n",
+    "    # save to action buffer\n",
+    "    policy_model.saved_actions.append(SavedAction(m.log_prob(action), state_value))\n",
+    "\n",
+    "    # the action to take (left or right)\n",
+    "    return action.item()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1, 2, 3])"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.tensor([1, 2, 3])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## take action"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def take_action(action_idx):\n",
+    "    # Define actions (data augmentation policy) --- can be improved\n",
+    "    action_list = [\n",
+    "    torchvision.transforms.Compose([torchvision.transforms.RandomVerticalFlip(),\n",
+    "        torchvision.transforms.ToTensor()]),\n",
+    "    torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),\n",
+    "        torchvision.transforms.ToTensor()]),\n",
+    "    torchvision.transforms.Compose([torchvision.transforms.RandomGrayscale(),\n",
+    "        torchvision.transforms.ToTensor()]),\n",
+    "    torchvision.transforms.Compose([torchvision.transforms.RandomAffine(30),\n",
+    "        torchvision.transforms.ToTensor()])]\n",
+    "\n",
+    "    # transform   \n",
+    "    transform = action_list[action_idx]\n",
+    "    test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=transform)\n",
+    "    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=transform)\n",
+    "    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples=0.0002)\n",
+    "    return train_loader, test_loader\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## finish episode"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "policy_model = Policy()\n",
+    "optimizer = optim.Adam(policy_model.parameters(), lr=3e-2)\n",
+    "eps = np.finfo(np.float32).eps.item()\n",
+    "gamma = 0.9\n",
+    "def finish_episode():\n",
+    "    \"\"\"\n",
+    "    Training code. Calculates actor and critic loss and performs backprop.\n",
+    "    \"\"\"\n",
+    "    R = 0\n",
+    "    saved_actions = policy_model.saved_actions\n",
+    "    policy_losses = [] # list to save actor (policy) loss\n",
+    "    value_losses = [] # list to save critic (value) loss\n",
+    "    returns = [] # list to save the true values\n",
+    "\n",
+    "    # calculate the true value using rewards returned from the environment\n",
+    "    for r in policy_model.rewards[::-1]:\n",
+    "        # calculate the discounted value\n",
+    "        R = r + gamma * R\n",
+    "        returns.insert(0, R)\n",
+    "\n",
+    "    returns = torch.tensor(returns)\n",
+    "    returns = (returns - returns.mean()) / (returns.std() + eps)\n",
+    "\n",
+    "    for (log_prob, value), R in zip(saved_actions, returns):\n",
+    "        advantage = R - value.item()\n",
+    "\n",
+    "        # calculate actor (policy) loss \n",
+    "        policy_losses.append(-log_prob * advantage)\n",
+    "\n",
+    "        # calculate critic (value) loss using L1 smooth loss\n",
+    "        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))\n",
+    "\n",
+    "    # reset gradients\n",
+    "    optimizer.zero_grad()\n",
+    "\n",
+    "    # sum up all the values of policy_losses and value_losses\n",
+    "    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()\n",
+    "\n",
+    "    # perform backprop\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "\n",
+    "    # reset rewards and action buffer\n",
+    "    del policy_model.rewards[:]\n",
+    "    del policy_model.saved_actions[:]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## run"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "<ipython-input-10-e36f3237ebab>:31: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
+      "  value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Episode 0\tLast reward (val accuracy): 0.00\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0.5\n",
+      ">>> reward 0.5\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0.5\n",
+      ">>> reward 0.5\n",
+      ">>> action_idx 0\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 0\n",
+      "child_network accuracy:  0.5\n",
+      ">>> reward 0.5\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0.5\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 0\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      "Episode 10\tLast reward (val accuracy): 0.00\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 0\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 0\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0\n",
+      ">>> action_idx 1\n",
+      "child_network accuracy:  0\n",
+      ">>> reward 0.5\n",
+      ">>> action_idx 1\n"
+     ]
+    },
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-11-7ab134985f61>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m     \u001b[0;31m# take the action -> apply data augmentation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m     \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtake_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m     \u001b[0mreward\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcollect_reward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'>>> reward'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m<ipython-input-9-4b54ec3ed4a2>\u001b[0m in \u001b[0;36mtake_action\u001b[0;34m(action_idx)\u001b[0m\n\u001b[1;32m     14\u001b[0m     \u001b[0mtransform\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0maction_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maction_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m     \u001b[0mtest_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMNIST\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'test_dataset/'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m     \u001b[0mtrain_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorchvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMNIST\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'test_dataset/'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m     \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_toy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_samples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0002\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[1;32m     91\u001b[0m                                ' You can use download=True to download it')\n\u001b[1;32m     92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_load_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_check_legacy_exist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36m_load_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    110\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_load_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    111\u001b[0m         \u001b[0mimage_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"{'train' if self.train else 't10k'}-images-idx3-ubyte\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 112\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_image_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw_folder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    113\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    114\u001b[0m         \u001b[0mlabel_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"{'train' if self.train else 't10k'}-labels-idx1-ubyte\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36mread_image_file\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m    507\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    508\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mread_image_file\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 509\u001b[0;31m     \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mread_sn3_pascalvincent_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstrict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    510\u001b[0m     \u001b[0;32massert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muint8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    511\u001b[0m     \u001b[0;32massert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndimension\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py\u001b[0m in \u001b[0;36mread_sn3_pascalvincent_tensor\u001b[0;34m(path, strict)\u001b[0m\n\u001b[1;32m    485\u001b[0m     \u001b[0;31m# read\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 487\u001b[0;31m         \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    488\u001b[0m     \u001b[0;31m# parse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m     \u001b[0mmagic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "running_reward = 10\n",
+    "episodes_num = 100\n",
+    "policy_model = Policy()\n",
+    "for i_episode in range(episodes_num) :\n",
+    "    # initiate a new state\n",
+    "    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "    # train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())\n",
+    "    train_loader_state = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+    "\n",
+    "    # select action from policy\n",
+    "    action_idx = select_action(train_loader, policy_model)\n",
+    "    print('>>> action_idx', action_idx)\n",
+    "\n",
+    "    # take the action -> apply data augmentation\n",
+    "    train_loader, test_loader = take_action(action_idx)\n",
+    "    reward = collect_reward(train_loader, test_loader)\n",
+    "    print('>>> reward', reward)\n",
+    "\n",
+    "    # if args.render:\n",
+    "    #     env.render()\n",
+    "\n",
+    "    policy_model.rewards.append(reward)\n",
+    "\n",
+    "    # perform backprop\n",
+    "    finish_episode()\n",
+    "\n",
+    "    # # log result\n",
+    "    if i_episode % 10 == 0:\n",
+    "        print('Episode {}\\tLast reward (val accuracy): {:.2f}'.format(i_episode, reward))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.8.9 64-bit",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.8"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/MetaAugment/autoaugment_learners/baseline.py b/MetaAugment/autoaugment_learners/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e33d4e1e33887bb81c6e7634697cc1a0e4840987
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/baseline.py
@@ -0,0 +1,30 @@
+import MetaAugment.child_networks as cn
+from pprint import pprint
+import torchvision.datasets as datasets
+import torchvision
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+import pickle
+
+train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
+                                train=True, download=True, transform=None)
+test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', 
+                        train=False, download=True, transform=torchvision.transforms.ToTensor())
+child_network = cn.bad_lenet
+
+aalearner = aa_learner(discrete_p_m=True)
+
+# this policy is same as identity function, because probabaility and magnitude are both zero
+null_policy = [(("Contrast", 0.0, 0.0), ("Contrast", 0.0, 0.0))]
+
+
+with open('bad_lenet_baseline.txt', 'w') as file:
+    file.write('')
+
+for _ in range(100):
+    acc = aalearner.test_autoaugment_policy(null_policy, child_network(), train_dataset, test_dataset, 
+                                toy_flag=True, logging=False)
+    with open('bad_lenet_baseline.txt', 'a') as file:
+        file.write(str(acc))
+        file.write('\n')
+
+pprint(aalearner.history)
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d8dcb983cb38b2dcf946aef8dab46556ca6bb1
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -0,0 +1,206 @@
+import torch
+
+import MetaAugment.child_networks as cn
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+from MetaAugment.controller_networks.rnn_controller import RNNModel
+
+from pprint import pprint
+import pickle
+
+
+
+# We will use this augmentation_space temporarily. Later on we will need to 
+# make sure we are able to add other image functions if the users want.
+augmentation_space = [
+            # (function_name, do_we_need_to_specify_magnitude)
+            ("ShearX", True),
+            ("ShearY", True),
+            ("TranslateX", True),
+            ("TranslateY", True),
+            ("Rotate", True),
+            ("Brightness", True),
+            ("Color", True),
+            ("Contrast", True),
+            ("Sharpness", True),
+            ("Posterize", True),
+            ("Solarize", True),
+            ("AutoContrast", False),
+            ("Equalize", False),
+            ("Invert", False),
+        ]
+
+
+class gru_learner(aa_learner):
+    """
+    An AutoAugment learner with a GRU controller 
+
+    The original AutoAugment paper(http://arxiv.org/abs/1805.09501) 
+    uses a LSTM controller updated via Proximal Policy Optimization.
+    (See Section 3 of AutoAugment paper)
+
+    The GRU has been shown to be as powerful of a sequential neural
+    network as the LSTM whilst training and testing much faster
+    (https://arxiv.org/abs/1412.3555), which is why we substituted
+    the LSTM for the GRU.
+    """
+
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, alpha=0.2):
+        """
+        Args:
+            alpha (float): Exploration parameter. It is multiplied to 
+                    operation tensors before they're softmaxed. 
+                    The lower this value, the more smoothed the output
+                    of the softmaxed will be, hence more exploration.
+        """
+        
+        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
+        self.alpha = alpha
+
+        self.rnn_output_size = fun_num+p_bins+m_bins
+        self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size, 
+                                    num_layers=2, bias=True)
+        self.softmax = torch.nn.Softmax(dim=0)
+
+
+    def generate_new_policy(self):
+        """
+        The GRU controller pops out a new policy.
+
+        At each time step, the GRU outputs a 
+        (fun_num + p_bins + m_bins, ) dimensional tensor which 
+        contains information regarding which 'image function' to use,
+        which value of 'probability(prob)' and 'magnitude(mag)' to use.
+
+        We run the GRU for 10 timesteps to obtain 10 of such tensors.
+
+        We then softmax the parts of the tensor which represents the
+        choice of function, prob, and mag seperately, so that the
+        resulting tensor's values sums up to 3.
+
+        Then we input each tensor into self.translate_operation_tensor
+        with parameter (return_log_prob=True), which outputs a tuple
+        in the form of ('img_function_name', prob, mag) and a float
+        representing the log probability that we chose the chosen 
+        func, prob and mag. 
+
+        We add up the log probabilities of each operation.
+
+        We turn the operations into a list of 5 tuples such as:
+            [
+            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
+            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
+            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
+            ]
+        This list can then be input into an AutoAugment object
+        as is done in self.learn()
+        
+        We return the list and the sum of the log probs
+        """
+
+        log_prob = 0
+
+        # we need a random input to put in
+        random_input = torch.zeros(self.rnn_output_size, requires_grad=False)
+
+        # 2*self.sp_num because we need 2 operations for every subpolicy
+        vectors = self.controller(input=random_input, time_steps=2*self.sp_num)
+
+        # softmax the funcion vector, probability vector, and magnitude vector
+        # of each timestep
+        softmaxed_vectors = []
+        for vector in vectors:
+            fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
+            fun_t = self.softmax(fun_t * self.alpha)
+            prob_t = self.softmax(prob_t * self.alpha)
+            mag_t = self.softmax(mag_t * self.alpha)
+            softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
+            softmaxed_vectors.append(softmaxed_vector)
+            
+        new_policy = []
+
+        for subpolicy_idx in range(self.sp_num):
+            # the vector corresponding to the first operation of this subpolicy
+            op1 = softmaxed_vectors[2*subpolicy_idx]
+            # the vector corresponding to the second operation of this subpolicy
+            op2 = softmaxed_vectors[2*subpolicy_idx+1]
+
+            # translate both vectors
+            op1, log_prob1 = self.translate_operation_tensor(op1, return_log_prob=True)
+            op2, log_prob2 = self.translate_operation_tensor(op2, return_log_prob=True)
+            
+            new_policy.append((op1,op2))
+            log_prob += (log_prob1+log_prob2)
+        
+        return new_policy, log_prob
+
+
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
+        # optimizer for training the GRU controller
+        cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
+
+        m = 8 # minibatch size
+        b = 0.88 # b is the running exponential mean of the rewards, used for training stability
+               # (see section 3.2 of https://arxiv.org/abs/1611.01578)
+
+        for _ in range(1000):
+            cont_optim.zero_grad()
+
+            # obj(objective) is $ \sum_{k=1}^m (reward_k-b) \sum_{t=1}^T log(P(a_t|a_{(t-1):1};\theta_c))$,
+            # which is used in PPO
+            obj = 0
+
+            # sum up the rewards within a minibatch in order to update the running mean, 'b'
+            mb_rewards_sum = 0
+
+            for k in range(m):
+                # log_prob is $\sum_{t=1}^T log(P(a_t|a_{(t-1):1};\theta_c))$, used in PPO
+                policy, log_prob = self.generate_new_policy()
+
+                pprint(policy)
+                child_network = child_network_architecture()
+                reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                    test_dataset, toy_flag)
+                mb_rewards_sum += reward
+
+                # log
+                self.history.append((policy, reward))
+
+                # gradient accumulation
+                obj += (reward-b)*log_prob
+            
+            # update running mean of rewards
+            b = 0.7*b + 0.3*(mb_rewards_sum/m)
+
+            (-obj).backward() # We put a minus because we want to maximize the objective, not 
+                              # minimize it.
+            cont_optim.step()
+
+            # save the history every 1 epochs as a pickle
+            with open('gru_logs.pkl', 'wb') as file:
+                pickle.dump(self.history, file)
+            with open('gru_learner.pkl', 'wb') as file:
+                pickle.dump(self, file)
+             
+
+
+
+if __name__=='__main__':
+
+    # We can initialize the train_dataset with its transform as None.
+    # Later on, we will change this object's transform attribute to the policy
+    # that we want to test
+    import torchvision.datasets as datasets
+    import torchvision
+    torch.manual_seed(0)
+
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, 
+                                transform=None)
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
+                                transform=torchvision.transforms.ToTensor())
+    child_network = cn.lenet
+
+
+    learner = gru_learner(discrete_p_m=False)
+    learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    pprint(learner.history)
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 980757da9a9510d7ae6d45022101143b1f4d8599..48f1b6f439f589d3fc8843dee2b303302a773662 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -1,21 +1,17 @@
 import torch
 import numpy as np
-import torchvision.transforms as transforms
-import torchvision.transforms.autoaugment as torchaa
-from torchvision.transforms import functional as F, InterpolationMode
 
-from MetaAugment.main import *
 import MetaAugment.child_networks as cn
-from MetaAugment.autoaugment_learners.autoaugment import *
-from MetaAugment.autoaugment_learners.aa_learner import *
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+import pickle
 
 
 
 # We will use this augmentation_space temporarily. Later on we will need to 
 # make sure we are able to add other image functions if the users want.
-num_bins = 10
 augmentation_space = [
             # (function_name, do_we_need_to_specify_magnitude)
             ("ShearX", True),
@@ -36,55 +32,52 @@ augmentation_space = [
 
 class randomsearch_learner(aa_learner):
     def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
-        '''
-        Args:
-            spdim: number of subpolicies per policy
-            fun_num: number of image functions in our search space
-            p_bins: number of bins we divide the interval [0,1] for probabilities
-            m_bins: number of bins we divide the magnitude space
-        '''
         super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
-
-        # TODO: We should probably use a different way to store results than self.history
-        self.history = []
+        
 
     def generate_new_discrete_operation(self):
-        '''
+        """
         generate a new random operation in the form of a tensor of dimension:
             (fun_num + 11 + 10)
 
+        Used only when self.discrete_p_m=True
+
         The first fun_num dimensions is a 1-hot encoding to specify which function to use.
         The next 11 dimensions specify which 'probability' to choose.
             (0.0, 0.1, ..., 1.0)
         The next 10 dimensions specify which 'magnitude' to choose.
             (0, 1, ..., 9)
-        '''
+        """
+
         random_fun = np.random.randint(0, self.fun_num)
         random_prob = np.random.randint(0, self.p_bins)
         random_mag = np.random.randint(0, self.m_bins)
         
         fun_t= torch.zeros(self.fun_num)
-        fun_t[random_fun] = 1
+        fun_t[random_fun] = 1.0
         prob_t = torch.zeros(self.p_bins)
-        prob_t[random_prob] = 1
+        prob_t[random_prob] = 1.0
         mag_t = torch.zeros(self.m_bins)
-        mag_t[random_mag] = 1
+        mag_t[random_mag] = 1.0
 
         return torch.cat([fun_t, prob_t, mag_t])
 
 
     def generate_new_continuous_operation(self):
-        '''
+        """
         Returns operation_tensor, which is a tensor representation of a random operation with
         dimension:
             (fun_num + 1 + 1)
 
+        Used only when self.discrete_p_m=False.
+
         The first fun_num dimensions is a 1-hot encoding to specify which function to use.
         The next 1 dimensions specify which 'probability' to choose.
             0 < x < 1
         The next 1 dimensions specify which 'magnitude' to choose.
             0 < x < 9
-        '''
+        """
+
         fun_p_m = torch.zeros(self.fun_num + 2)
         
         # pick a random image function
@@ -92,21 +85,17 @@ class randomsearch_learner(aa_learner):
         fun_p_m[random_fun] = 1
 
         fun_p_m[-2] = np.random.uniform() # 0<prob<1
-        fun_p_m[-1] = np.random.uniform() * (self.m_bins-1) # 0<mag<9
+        fun_p_m[-1] = np.random.uniform() * (self.m_bins-0.0000001) - 0.4999999 # -0.5<mag<9.5
         
         return fun_p_m
 
 
     def generate_new_policy(self):
-        '''
-        Generate a new random policy in the form of
-            [
-            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
-            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
-            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
-            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
-            ]
-        '''
+        """
+        Generates a new policy, with the elements chosen at random
+        (unifom random distribution).
+        """
+
         new_policy = []
         
         for _ in range(self.sp_num): # generate sp_num subpolicies for each policy
@@ -129,15 +118,8 @@ class randomsearch_learner(aa_learner):
 
 
     def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
-        '''
-        Does the loop which is seen in Figure 1 in the AutoAugment paper.
-        In other words, repeat:
-            1. <generate a random policy>
-            2. <see how good that policy is>
-            3. <save how good the policy is in a list/dictionary>
-        '''
         # test out 15 random policies
-        for _ in range(15):
+        for _ in range(1500):
             policy = self.generate_new_policy()
 
             pprint(policy)
@@ -147,19 +129,27 @@ class randomsearch_learner(aa_learner):
 
             self.history.append((policy, reward))
 
+            # save the history every 10 epochs as a pickle
+            if _%10==1:
+                with open('randomsearch_logs.pkl', 'wb') as file:
+                    pickle.dump(self.history, file)
+    
+
 
-if __name__=='__main__':
 
+if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
     # that we want to test
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
-                                transform=None)
-    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
-                                transform=torchvision.transforms.ToTensor())
-    child_network = cn.lenet
-
+    import torchvision.datasets as datasets
+    import torchvision
     
-    rs_learner = randomsearch_learner(discrete_p_m=False)
+    train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
+                                    train=True, download=True, transform=None)
+    test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', 
+                            train=False, download=True, transform=torchvision.transforms.ToTensor())
+    child_network = cn.bad_lenet
+
+    rs_learner = randomsearch_learner(discrete_p_m=True)
     rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(rs_learner.history)
\ No newline at end of file
diff --git a/MetaAugment/child_networks/bad_lenet.py b/MetaAugment/child_networks/bad_lenet.py
index 296192cb7746ae3de1aaf0e9954ad9c60dcaae78..c85d432f6834df29d7a695b3540fcd8475309691 100644
--- a/MetaAugment/child_networks/bad_lenet.py
+++ b/MetaAugment/child_networks/bad_lenet.py
@@ -1,21 +1,56 @@
 import torch.nn as nn
 
 
+# class Bad_LeNet(nn.Module):
+#     # 1. I reduced the channel sizes of the convolutional layers
+#     # 2. I reduced the number of fully ocnnected layers from 3 to 2
+#     # 
+#     # no. of weights: 25*2 + 25*2*4 + 16*4*10 = 250+640 = 890
+#     def __init__(self):
+#         super().__init__()
+#         self.conv1 = nn.Conv2d(1, 2, 5)
+#         self.relu1 = nn.ReLU()
+#         self.pool1 = nn.MaxPool2d(2)
+#         self.conv2 = nn.Conv2d(2, 4, 5)
+#         self.relu2 = nn.ReLU()
+#         self.pool2 = nn.MaxPool2d(2)
+#         self.fc1 = nn.Linear(16*4,  10)
+#         self.relu3 = nn.ReLU()
+
+
+#     def forward(self, x):
+#         y = self.conv1(x)
+#         y = self.relu1(y)
+#         y = self.pool1(y)
+#         y = self.conv2(y)
+#         y = self.relu2(y)
+#         y = self.pool2(y)
+#         y = y.view(y.shape[0], -1)
+#         y = self.fc1(y)
+#         y = self.relu3(y)
+#         return y
+
 class Bad_LeNet(nn.Module):
+    # 1. I reduced the channel sizes of the convolutional layers
+    # 2. I reduced the number of fully connected layers from 3 to 2
+    # 
+    # no. of weights: 25*2 + 25*2*3 + 4*3*10 = 50+150+120 = 320
     def __init__(self):
         super().__init__()
-        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.conv1 = nn.Conv2d(1, 2, 5)
         self.relu1 = nn.ReLU()
         self.pool1 = nn.MaxPool2d(2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.conv2 = nn.Conv2d(2, 3, 5)
         self.relu2 = nn.ReLU()
-        self.pool2 = nn.MaxPool2d(2)
-        self.fc1 = nn.Linear(256, 120)
+        self.pool2 = nn.MaxPool2d(4)
+        self.fc1 = nn.Linear(4*3,  10)
         self.relu3 = nn.ReLU()
-        self.fc2 = nn.Linear(120, 84)
-        self.relu4 = nn.ReLU()
-        self.fc3 = nn.Linear(84, 10)
-        self.relu5 = nn.ReLU()
+        
+        # self.fc2 = nn.Linear(20, 14)
+        # self.relu4 = nn.ReLU()
+        # self.fc3 = nn.Linear(14, 10)
+        # self.relu5 = nn.ReLU()
+
 
     def forward(self, x):
         y = self.conv1(x)
@@ -27,10 +62,10 @@ class Bad_LeNet(nn.Module):
         y = y.view(y.shape[0], -1)
         y = self.fc1(y)
         y = self.relu3(y)
-        y = self.fc2(y)
-        y = self.relu4(y)
-        y = self.fc3(y)
-        y = self.relu5(y)
+        # y = self.fc2(y)
+        # y = self.relu4(y)
+        # y = self.fc3(y)
+        # y = self.relu5(y)
         return y
 
 
diff --git a/MetaAugment/child_networks/lenet.py b/MetaAugment/child_networks/lenet.py
index 5546bfa76f3529f074f024dc1a8b81307d27eec0..e4c1cb6efcf397186400336004fad4ba831bcbc1 100644
--- a/MetaAugment/child_networks/lenet.py
+++ b/MetaAugment/child_networks/lenet.py
@@ -2,6 +2,7 @@ import torch.nn as nn
 
 
 class LeNet(nn.Module):
+    # no. of params: 25*6 + 25*6*16 + 256*120 + 120*84 + 84*10 = > 30,000
     def __init__(self):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/rnn_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..12680eae88cbda7f93949f30ffd619ec65f46069
--- /dev/null
+++ b/MetaAugment/controller_networks/rnn_controller.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+import math
+
+class LSTMCell(nn.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        super(LSTMCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
+        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
+        
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        std = 1.0 / math.sqrt(self.hidden_size)
+        for w in self.parameters():
+            w.data.uniform_(-std, std)
+
+    def forward(self, input, hx=None):
+        if hx is None:
+            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
+            hx = (hx, hx)
+            
+        # We used hx to pack both the hidden and cell states
+        hx, cx = hx
+        
+        hi = self.x2h(input) + self.h2h(hx)
+        i, f, o, g = torch.chunk(hi, 4, dim=-1)
+        i = torch.sigmoid(i)
+        f = torch.sigmoid(f)
+        o = torch.sigmoid(o)
+        g = torch.tanh(g)
+        cy = f * cx + i * g  
+        hy = o * torch.tanh(cy)
+
+        return (hy, cy)
+
+    
+class GRUCell(nn.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        super(GRUCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        self.x2h = nn.Linear(input_size, 2 * hidden_size, bias=bias)
+        self.h2h = nn.Linear(hidden_size, 2 * hidden_size, bias=bias)
+
+        self.x2r = nn.Linear(input_size, hidden_size, bias=bias)
+        self.h2r = nn.Linear(hidden_size, hidden_size, bias=bias)
+        self.reset_parameters()
+        
+
+    def reset_parameters(self):
+        std = 1.0 / math.sqrt(self.hidden_size)
+        for w in self.parameters():
+            w.data.uniform_(-std, std)
+
+    def forward(self, input, hx=None):
+        if hx is None:
+            hx = input.new_zeros(self.hidden_size, requires_grad=False)
+
+        z, r = torch.chunk(self.x2h(input) + self.h2h(hx), 2, -1)
+        z = torch.sigmoid(z)
+        r = torch.sigmoid(r)
+        g = torch.tanh(self.h2r(hx)*r + self.x2r(input))
+        hy = z * hx + (1 - z) * g
+        
+        return hy
+
+
+class RNNModel(nn.Module):
+    def __init__(self, mode, output_size, num_layers, bias):
+        super(RNNModel, self).__init__()
+        self.mode = mode
+        self.input_size = output_size
+        self.hidden_size = output_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.output_size = output_size
+        
+        self.rnn_cell_list = nn.ModuleList()
+        
+        if mode == 'LSTM':
+    
+            self.rnn_cell_list.append(LSTMCell(self.input_size,
+                                              self.hidden_size,
+                                              self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(LSTMCell(self.hidden_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+            
+
+        elif mode == 'GRU':
+    
+            self.rnn_cell_list.append(GRUCell(self.input_size,
+                                              self.hidden_size,
+                                              self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(GRUCell(self.hidden_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+
+        else:
+            raise ValueError("Invalid RNN mode selected.")
+
+
+        self.att_fc = nn.Linear(self.hidden_size, 1)
+        self.fc = nn.Linear(self.hidden_size, self.output_size)
+
+        
+    def forward(self, input, time_steps=10, hx=None):
+        # The 'input' is the input x into the first timestep
+        # I think this should be a random vector
+        assert input.shape == (self.output_size, )
+
+        outs = []
+        h0 = [None] * self.num_layers if hx is None else list(hx)
+    
+
+        X = [None] * time_steps
+        X[0] = input # first input is 'input'
+        for layer_idx, layer_cell in enumerate(self.rnn_cell_list):
+            hx = h0[layer_idx]
+            for i in range(time_steps):
+                hx = layer_cell(X[i], hx)
+                
+                # we feed in this timestep's output into the next timestep's input
+                # except if we are at the last timestep
+                if i != time_steps-1:
+                    X[i+1] = hx if self.mode == 'GRU' else hx[0]
+                
+        outs = X
+    
+
+        # out = outs[-1].squeeze()
+
+        # out = self.fc(out)
+        
+        # return out
+
+        return outs
+    
+
+class BidirRecurrentModel(nn.Module):
+    def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
+        super(BidirRecurrentModel, self).__init__()
+        self.mode = mode
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.output_size = output_size
+        
+        self.rnn_cell_list = nn.ModuleList()
+        self.rnn_cell_list_rev = nn.ModuleList()
+        
+        if mode == 'LSTM':
+            self.rnn_cell_list.append(LSTMCell(self.input_size,
+                                               self.hidden_size,
+                                               self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(LSTMCell(self.hidden_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+                                                  
+            self.rnn_cell_list_rev.append(LSTMCell(self.input_size,
+                                                   self.hidden_size,
+                                                   self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list_rev.append(LSTMCell(self.hidden_size,
+                                                       self.hidden_size,
+                                                       self.bias))
+                                                  
+        elif mode == 'GRU':
+            self.rnn_cell_list.append(GRUCell(self.input_size,
+                                              self.hidden_size,
+                                              self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(GRUCell(self.hidden_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+            
+            self.rnn_cell_list_rev.append(GRUCell(self.input_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list_rev.append(GRUCell(self.hidden_size,
+                                                      self.hidden_size,
+                                                      self.bias))
+ 
+        else:
+            raise ValueError("Invalid RNN mode selected.")
+
+
+        self.fc = nn.Linear(2 * self.hidden_size, self.output_size)
+        
+        
+    def forward(self, input, hx=None):
+        assert NotImplementedError('right now this forward function is written for classification. \
+                                You should modify it for our purpose, like the RNNModel was.')
+        outs = []
+        outs_rev = []
+        
+        X = list(input.permute(1, 0, 2))
+        X_rev = list(input.permute(1, 0, 2))
+        X_rev.reverse()
+        hi = [None] * self.num_layers if hx is None else list(hx)
+        hi_rev = [None] * self.num_layers if hx is None else list(hx)
+        for j in range(self.num_layers):
+            hx = hi[j]
+            hx_rev = hi_rev[j]
+            for i in range(input.shape[1]):
+                hx = self.rnn_cell_list[j](X[i], hx)
+                X[i] = hx if self.mode != 'LSTM' else hx[0]
+                hx_rev = self.rnn_cell_list_rev[j](X_rev[i], hx_rev)
+                X_rev[i] = hx_rev if self.mode != 'LSTM' else hx_rev[0]
+        outs = X 
+        outs_rev = X_rev
+
+        out = outs[-1].squeeze()
+        out_rev = outs_rev[0].squeeze()
+        out = torch.cat((out, out_rev), 1)
+
+        out = self.fc(out)
+        return out
\ No newline at end of file
diff --git a/MetaAugment/genetic_learner_results.py b/MetaAugment/genetic_learner_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d9de8df2e17748b34e6879d4a3ae75dca9d9fb
--- /dev/null
+++ b/MetaAugment/genetic_learner_results.py
@@ -0,0 +1,109 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+# Fixed seed (same as benchmark)
+
+# Looking at last generation can make out general trends of which transformations lead to the largest accuracies
+
+
+gen_1_acc = [0.1998, 0.1405, 0.1678, 0.9690, 0.9672, 0.9540, 0.9047, 0.9730, 0.2060, 0.9260, 0.8035, 0.9715, 0.9737, 0.14, 0.9645]
+
+gen_2_acc = [0.9218, 0.9753, 0.9758, 0.1088, 0.9710, 0.1655, 0.9735, 0.9655, 0.9740, 0.9377]
+
+gen_3_acc = [0.1445, 0.9740, 0.9643, 0.9750, 0.9492, 0.9693, 0.1262, 0.9660, 0.9760, 0.9697]
+
+gen_4_acc = [0.9697, 0.1238, 0.9613, 0.9737, 0.9603, 0.8620, 0.9712, 0.9617, 0.9737, 0.1855]
+
+gen_5_acc = [0.6445, 0.9705, 0.9668, 0.9765, 0.1142, 0.9780, 0.9700, 0.2120, 0.9555, 0.9732]
+
+gen_6_acc = [0.9710, 0.9665, 0.2077, 0.9535, 0.9765, 0.9712, 0.9697, 0.2145, 0.9523, 0.9718, 0.9718, 0.9718, 0.2180, 0.9622, 0.9785]
+
+gen_acc = [gen_1_acc, gen_2_acc, gen_3_acc, gen_4_acc, gen_5_acc, gen_6_acc]
+
+gen_acc_means = []
+gen_acc_stds = []
+
+for val in gen_acc:
+    gen_acc_means.append(np.mean(val))
+    gen_acc_stds.append(np.std(val))
+
+
+
+# Vary seed
+
+gen_1_vary = [0.1998, 0.9707, 0.9715, 0.9657, 0.8347, 0.9655, 0.1870, 0.0983, 0.3750, 0.9765, 0.9712, 0.9705, 0.9635, 0.9718, 0.1170]
+
+gen_2_vary = [0.9758, 0.9607, 0.9597, 0.9753, 0.1165, 0.1503, 0.9747, 0.1725, 0.9645, 0.2290]
+
+gen_3_vary = [0.1357, 0.9725, 0.1708, 0.9607, 0.2132, 0.9730, 0.9743, 0.9690, 0.0850, 0.9755]
+
+gen_4_vary = [0.9722, 0.9760, 0.9697, 0.1155, 0.9715, 0.9688, 0.1785, 0.9745, 0.2362, 0.9765]
+
+gen_5_vary = [0.9705, 0.2280, 0.9745, 0.1875, 0.9735, 0.9735, 0.9720, 0.9678, 0.9770, 0.1155]
+
+gen_6_vary = [0.9685, 0.9730, 0.9735, 0.9760, 0.1495, 0.9707, 0.9700, 0.9747, 0.9750, 0.1155, 0.9732, 0.9745, 0.9758, 0.9768, 0.1155]
+
+gen_vary = [gen_1_vary, gen_2_vary, gen_3_vary, gen_4_vary, gen_5_vary, gen_6_vary]
+
+gen_vary_means = []
+gen_vary_stds = []
+
+for val in gen_vary:
+    gen_vary_means.append(np.mean(val))
+    gen_vary_stds.append(np.std(val))
+
+
+
+
+
+# Multiple runs 
+
+gen_1_mult = [0.1762, 0.9575, 0.1200, 0.9660, 0.9650, 0.9570, 0.9745, 0.9700, 0.15, 0.23, 0.16, 0.186, 0.9640, 0.9650]
+
+gen_2_mult = [0.17, 0.1515, 0.1700, 0.9625, 0.9630, 0.9732, 0.9680, 0.9633, 0.9530, 0.9640]
+
+gen_3_mult = [0.9750, 0.9720, 0.9655, 0.9530, 0.9623, 0.9730, 0.9748, 0.9625, 0.9716, 0.9672]
+
+gen_4_mult = [0.9724, 0.9755, 0.9657, 0.9718, 0.9690, 0.9735, 0.9715, 0.9300, 0.9725, 0.9695]
+
+gen_5_mult = [0.9560, 0.9750, 0.8750, 0.9717, 0.9731, 0.9741, 0.9747, 0.9726, 0.9729, 0.9727]
+
+gen_6_mult = [0.9730, 0.9740, 0.9715, 0.9755, 0.9761, 0.9700, 0.9755, 0.9750, 0.9726, 0.9748, 0.9705, 0.9745, 0.9752, 0.9740, 0.9744]
+
+
+
+gen_mult = [gen_1_mult, gen_2_mult, gen_3_mult,  gen_4_mult, gen_5_mult, gen_6_mult]
+
+gen_mult_means = []
+gen_mult_stds = []
+
+for val in gen_mult:
+    gen_mult_means.append(np.mean(val))
+    gen_mult_stds.append(np.std(val))
+
+num_gen = [i for i in range(len(gen_mult))]
+
+
+# Baseline
+baseline = [0.7990 for i in range(len(gen_mult))]
+
+
+
+# plt.errorbar(num_gen, gen_acc_means, yerr=gen_acc_stds, linestyle = 'dotted', label = 'Fixed seed GA')
+# plt.errorbar(num_gen, gen_vary_means, linestyle = 'dotted', yerr=gen_vary_stds, label = 'Varying seed GA')
+# plt.errorbar(num_gen, gen_mult_means, linestyle = 'dotted', yerr=gen_mult_stds, label = 'Varying seed GA 2')
+
+plt.plot(num_gen, gen_acc_means, linestyle = 'dotted', label = 'Fixed seed GA')
+plt.plot(num_gen, gen_vary_means, linestyle = 'dotted',  label = 'Varying seed GA')
+plt.plot(num_gen, gen_mult_means, linestyle = 'dotted', label = 'Varying seed GA 2')
+
+plt.plot(num_gen, baseline, label = 'Fixed seed baseline')
+
+
+plt.xlabel('Generation', fontsize = 16)
+plt.ylabel('Validation Accuracy', fontsize = 16)
+
+plt.legend()
+
+plt.savefig('GA_results.png')
\ No newline at end of file
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 0fd76bcf189a297f1d8decd88f39851f4ce3433c..0c5cdeae9020cda7d9923d345e87f3ea93ac2595 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -1,12 +1,9 @@
 import numpy as np
 import torch
-torch.manual_seed(0)
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 import torchvision
 import torchvision.datasets as datasets
-import torchvision.transforms.autoaugment as autoaugment
 #import MetaAugment.AutoAugmentDemo.ops as ops # 
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
@@ -25,7 +22,7 @@ def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
     shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
     shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
 
-    big = 4 # how much bigger is the test set
+    big = 1 # how much bigger is the test set
 
     indices_test = torch.arange(int(n_samples*len(test_dataset)*big))
     reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
@@ -109,6 +106,8 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
 
     if logging:
         return best_acc.item(), acc_log
+    
+    print('main.train_child_network best accuracy: ', best_acc)
     return best_acc.item()
 
 if __name__=='__main__':
diff --git a/Procfile b/Procfile
deleted file mode 100644
index 05b126a4f3af13308397226e5c9cd881e2913083..0000000000000000000000000000000000000000
--- a/Procfile
+++ /dev/null
@@ -1 +0,0 @@
-web: flask run --host=0.0.0.0 --port=$PORT
diff --git a/auto_augmentation/.DS_Store b/auto_augmentation/.DS_Store
deleted file mode 100644
index d425cd4dcb0261152a0f748a0c6c05c35c554588..0000000000000000000000000000000000000000
Binary files a/auto_augmentation/.DS_Store and /dev/null differ
diff --git a/auto_augmentation/__init__.py b/auto_augmentation/__init__.py
deleted file mode 100644
index da6ac6a46769c882614ff0940d1535877c4728d7..0000000000000000000000000000000000000000
--- a/auto_augmentation/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import os
-
-from flask import Flask, render_template, request, flash
-
-from auto_augmentation import home, progress,result
-
-def create_app(test_config=None):
-    # create and configure the app
-    app = Flask(__name__, instance_relative_config=True)
-    app.config.from_mapping(
-        SECRET_KEY='dev',
-    )
-
-    if test_config is None:
-        # load the instance config, if it exists, when not testing
-        app.config.from_pyfile('config.py', silent=True)
-    else:
-        # load the test config if passed in
-        app.config.from_mapping(test_config)
-
-    # ensure the instance folder exists
-    os.makedirs(app.instance_path, exist_ok=True)
-
-    from auto_augmentation import download_file
-
-    app.register_blueprint(home.bp)
-    app.register_blueprint(progress.bp)
-    app.register_blueprint(result.bp)
-    app.register_blueprint(download_file.bp)
-    
-
-    return app
diff --git a/auto_augmentation/download_file.py b/auto_augmentation/download_file.py
deleted file mode 100644
index 35b9f5a9fdb8b37351cffdc8ccf8d7f51cb131ba..0000000000000000000000000000000000000000
--- a/auto_augmentation/download_file.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from flask import Blueprint, request, render_template, flash, send_file
-
-bp = Blueprint("download_file", __name__)
-
-@bp.route("/download_file", methods=["GET"])
-@bp.route("/download", methods=["GET", "POST"])
-def download():    
-    # Setup for the 'return send_file()' function call at the end of this function
-    path = 'templates/CNN.zip' # e.g. 'templates/download.markdown'
-
-    return send_file(path,
-                    as_attachment=True)
diff --git a/auto_augmentation/home.py b/auto_augmentation/home.py
deleted file mode 100644
index 7b14acb5ecee2aea47d184117701959714d10894..0000000000000000000000000000000000000000
--- a/auto_augmentation/home.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from flask import Blueprint, render_template
-
-bp = Blueprint("home", __name__)
-
-@bp.route("/")
-def index():
-    return render_template("home.html")
diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py
deleted file mode 100644
index b95acdc19ded31e44c9a66dc922bc4eade7ecac6..0000000000000000000000000000000000000000
--- a/auto_augmentation/progress.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from flask import Blueprint, request, render_template, flash, send_file
-import subprocess
-
-bp = Blueprint("progress", __name__)
-
-@bp.route("/user_input", methods=["GET", "POST"])
-def response():
-    
-    return render_template("progress.html")
\ No newline at end of file
diff --git a/auto_augmentation/result.py b/auto_augmentation/result.py
deleted file mode 100644
index 965af5a298f5c22b51b2562a25a25769a0a2d96c..0000000000000000000000000000000000000000
--- a/auto_augmentation/result.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from flask import Blueprint, request, render_template, flash, send_file
-import subprocess
-
-bp = Blueprint("result", __name__)
-
-@bp.route("/show_result", methods=["GET", "POST"])
-def response():
-    
-    return render_template("result.html")
\ No newline at end of file
diff --git a/auto_augmentation/static/.DS_Store b/auto_augmentation/static/.DS_Store
deleted file mode 100644
index cbf9ce2f5606f2ec8e9da4a923b1306d7d64d602..0000000000000000000000000000000000000000
Binary files a/auto_augmentation/static/.DS_Store and /dev/null differ
diff --git a/auto_augmentation/static/image/data_augment_cat.jpeg b/auto_augmentation/static/image/data_augment_cat.jpeg
deleted file mode 100644
index 900b738c8375325e26d21c43eb3a5ccf3b3c2827..0000000000000000000000000000000000000000
Binary files a/auto_augmentation/static/image/data_augment_cat.jpeg and /dev/null differ
diff --git a/auto_augmentation/static/image/training_plot.png b/auto_augmentation/static/image/training_plot.png
deleted file mode 100644
index 1128a55103c47d7ec36b43bb83a1698ae01c44b3..0000000000000000000000000000000000000000
Binary files a/auto_augmentation/static/image/training_plot.png and /dev/null differ
diff --git a/auto_augmentation/templates/basic.html b/auto_augmentation/templates/basic.html
deleted file mode 100644
index db609ea5f49ca101d3394b2e1219c1488160908b..0000000000000000000000000000000000000000
--- a/auto_augmentation/templates/basic.html
+++ /dev/null
@@ -1,12 +0,0 @@
-<!doctype html>
-<html>
-  <head>
-    {% block head %}
-    <title>{% block title %}{% endblock %} - Meta Reinforcement Learning for Data Augmentation</title>
-    {% endblock %}
-  </head>
-  <body>
-    {% block body %}{% endblock %}
-  </body>
-</html>
-
diff --git a/auto_augmentation/templates/home.html b/auto_augmentation/templates/home.html
deleted file mode 100644
index a1e7d3d03070c13fc055656ee64443424fb85609..0000000000000000000000000000000000000000
--- a/auto_augmentation/templates/home.html
+++ /dev/null
@@ -1,42 +0,0 @@
-{% extends "basic.html" %}
-{% block title%}Home{% endblock %}
-{% block body %}
-<h1>Meta Reinforcement Learning for Data Augmentation</h1>
-
-
-<form action="/user_input">
-  <!-- upload dataset -->
-  <label for="dataset">Please upload your dataset here:</label>
-  <input type="file" name="dataset" class="upload"><br><br>
-
-  <!-- radio button -->
-  What task is your dataset used for?<br>
-  <input type="radio" id="outputtype1"
-    name="output" value="binary_cls">
-  <label for="outputtype1">Binary Classification</label><br>
-
-  <input type="radio" id="outputtype2"
-    name="output" value="multi_cls">
-  <label for="outputtype2">Multi-classification</label><br>
-
-  <input type="radio" id="outputtype3"
-  name="output" value="regression">
-  <label for="outputtype3">Linear Regression</label><br><br>
-
-  <label for="data_aug_method">Which data augmentation method you would like exclude?</label>
-    <select id="data_aug_method" name="data_aug_method">
-        <option value="Translate">Translate</option>
-        <option value="Rotate">Rotate</option>
-        <option value="AutoContrast">AutoContrast</option>
-        <option value="Equalize">Equalize</option>
-        <option value="Solarize">Solarize</option>
-        <option value="Posterize">Posterize</option>
-        <option value="Contrast">Contrast</option>
-        <option value="Brightness">Brightness</option>
-
-    </select><br><br>
-
-  <input type="submit">
-</form>
-  
-{% endblock %}
diff --git a/auto_augmentation/templates/progress.html b/auto_augmentation/templates/progress.html
deleted file mode 100644
index ea4c33d713224bd889f147ddcb5f8d2e6fb0b6f7..0000000000000000000000000000000000000000
--- a/auto_augmentation/templates/progress.html
+++ /dev/null
@@ -1,14 +0,0 @@
-{% extends "basic.html" %}
-{% block title%}Progress{% endblock %}
-{% block body %}
-Training the model...
-
-<div>
-    <img src="{{url_for('static', filename='image/training_plot.png')}}" class="img-thumbnail" />
-    <form action="/show_result">
-        <input type="submit" value='Show Result'>
-    </form>
-
-</div>
- 
-{% endblock %}
\ No newline at end of file
diff --git a/auto_augmentation/templates/result.html b/auto_augmentation/templates/result.html
deleted file mode 100644
index 6e127fdb1ed52f2abefc166f23dc7b8d25062f48..0000000000000000000000000000000000000000
--- a/auto_augmentation/templates/result.html
+++ /dev/null
@@ -1,18 +0,0 @@
-{% extends "basic.html" %}
-{% block title %}Result{% endblock %}
-{% block body %}
-
-<div>  
-  <b>Accuracy before data augmentation is: 64.6%</b><br>
-  <b>Accuracy after data augmentation is: 79.3%</b>
-
-</div>
-
-<div>
-  <form action="/download">
-  <input type="submit" value='Download CNN'>
-  </form>
-</div>
-
-{% endblock %}
-
diff --git a/bad_lenet_baseline.txt b/bad_lenet_baseline.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ee426570305c6b7ab6eee3db0d80cd7b3d7d604
--- /dev/null
+++ b/bad_lenet_baseline.txt
@@ -0,0 +1,2 @@
+0.4399999976158142
+0.550000011920929
diff --git a/check_pickles.py b/check_pickles.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29e77309e2c36f8ab2b59af45420c23542767ce
--- /dev/null
+++ b/check_pickles.py
@@ -0,0 +1,12 @@
+import pickle
+from pprint import pprint
+
+with open('randomsearch_logs.pkl', 'rb') as file:
+    list = pickle.load(file)
+
+print(len(list))
+
+with open('gru_logs.pkl','rb') as file:
+    list = pickle.load(file)
+
+print(len(list))
diff --git a/conftest.py b/conftest.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d0c3cbf1020d5c292abdedf27627c6abe25e2293
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = source
+BUILDDIR      = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.https://www.sphinx-doc.org/
+	exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d49d8b583463c4a3e00080bdffaab6809a2b3277
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,69 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+# this tells sphinx that our MetaAugment folder is two folder levels
+# outside the /docs folder
+sys.path.insert(0, os.path.abspath('../..'))
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'metarl'
+copyright = '2022, metarl_team'
+author = 'metarl_team'
+
+# The full version, including alpha/beta/rc tags
+release = '0.0'
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.autosummary',
+    'sphinx.ext.coverage', 
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+]
+
+# turn on sphinx.ext.autosummary
+autosummary_generate = False
+
+# turn on sphinx.ext.coverage
+coverage_show_missing_items = True
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'alabaster'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
\ No newline at end of file
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..be94200eadb75e6cbb8e5547227d8f76c2878e52
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,33 @@
+Welcome to metarl's documentation!
+==================================
+
+.. toctree::
+   :maxdepth: 3
+   :caption: Contents:
+
+   usage/installation
+
+.. autoclass:: MetaAugment.autoaugment_learners.aa_learner.aa_learner
+   
+.. autoclass:: MetaAugment.autoaugment_learners.randomsearch_learner.randomsearch_learner
+.. autoclass:: MetaAugment.autoaugment_learners.gru_learner.gru_learner
+
+.. automodule:: MetaAugment.controller_networks
+   :members:
+
+.. automodule:: MetaAugment.child_networks
+   :members:
+
+
+   
+
+
+
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/source/usage/installation.rst b/docs/source/usage/installation.rst
new file mode 100644
index 0000000000000000000000000000000000000000..706976672b7d14a3caa233b88ef90a24caefb86e
--- /dev/null
+++ b/docs/source/usage/installation.rst
@@ -0,0 +1 @@
+explain how to install MetaAugment here
\ No newline at end of file
diff --git a/gru_learner.pkl b/gru_learner.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..4eaf3a8d472dbd5ab986535cedfa35586761f1bc
Binary files /dev/null and b/gru_learner.pkl differ
diff --git a/gru_logs.pkl b/gru_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..0f05ed6ae9e8455b59460fd109da56f2ea422f43
Binary files /dev/null and b/gru_logs.pkl differ
diff --git a/heroku.yml b/heroku.yml
deleted file mode 100644
index 8a254e2e411577dcc5abec859535d28502c75e38..0000000000000000000000000000000000000000
--- a/heroku.yml
+++ /dev/null
@@ -1,3 +0,0 @@
-build:
-  docker:
-    web: Dockerfile  # path to your Dockerfile 
diff --git a/old_templates/home.html b/old_templates/home.html
deleted file mode 100644
index 25108159cdf9dc16ccaa77ade8fcb76e3dddf73c..0000000000000000000000000000000000000000
--- a/old_templates/home.html
+++ /dev/null
@@ -1,62 +0,0 @@
-{% extends "layout.html" %}
-
-{% block content %}
-
-    <!-- Starts image section -->
-    <div class="row justify-content-md-center mb-4">
-        <h2 class='text-primary'>Data Augmentation with Meta Reinforcement Learning</h2>
-    </div>
-    
-    <div>
-      <img src="{{url_for('autoaugmentation/static', filename='images/data augment cat.jpeg')}}" class="img-thumbnail" />
-      <div class="caption">
-        <p><strong>Data Augmentation</strong></p>
-      </div>
-    </div>
-    <!-- Ends image section -->
-    
-    <!-- Starts upload section -->
-
-    <section>
-
-      <div class="container-fluid details">
-        <form action="/predict" method="post" enctype="multipart/form-data" onsubmit="showloading()">
-
-            <input type="file" name="image" class="upload"><br><br>
-            
-            <label for="user_classify">Which classfiication can best describe the uploading picutre:</label>
-            <select id="user_classify" name="user_classify">
-                <option value="Not sure">Not sure</option>
-                <option value="Diseased Leaf">Diseased Leaf</option>
-                <option value="Diseased Plant">Diseased Plant</option>
-                <option value="Healthy Leaf">Healthy Leaf</option>
-                <option value="Healthy Plant">Healthy Plant</option>
-            </select><br><br>
-
-            <label class="camera distance" for="Myheight"> Input yout height(m):</label>
-            <input type="text" id="Myheight" name="Myheight"><br>
-
-            <label class="camera distance" for="camera_dist_left"> Camera predicted distance from the left pillar(m):</label>
-            <input type="text" id="camera_dist_left" name="camera_dist_left"><br>
-
-            <label class="camera distance" for="camera_dist_right"> Camera predicted distance from the right pillar(m):</label>
-            <input type="text" id="camera_dist_right" name="camera_dist_right"><br><br>
-
-            <h5>More information about this cotton plant (optional)</h5>
-            <label class='measurements' for='height'>Cotton plant height: </label>
-            <input type='text' id='height' name = 'height'><br>
-            <label class='measurements' for='width'>Cotton plant width:</label>
-            <input type='text' id='width' name = 'width'><br>
-            <label class='notes' for='added_notes'>Additional notes: </label>
-            <input type='text' id='added_notes' name = 'added_notes'><br><br>
-
-            <input type="submit" value="Submit and Predict!"> <br><br><br>
-        </form>
-      </div>
-    </section>
-    
-    <!-- Ends upload section -->
-
-
-
-{% endblock %}
\ No newline at end of file
diff --git a/old_templates/layout.html b/old_templates/layout.html
deleted file mode 100644
index 501b9e0767a95365af0317748cff5cad6e094b94..0000000000000000000000000000000000000000
--- a/old_templates/layout.html
+++ /dev/null
@@ -1,33 +0,0 @@
-<!doctype html>
-<html>
-  
-  <head>
-   
-    <meta charset="utf-8">
-    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
-    
-    <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css">
-    <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
-    <script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.4.1/js/bootstrap.min.js"></script>
-    
-    
-    <title> Cotton Leaf Disease Detection </title>
-    
-  </head>
-  
-  <body>
-      <div class="container pt-3">
-      
-        <div id="content">{% block content %}{% endblock %}</div>
-      
-        <div id="footer">
-          {% block footer %}
-          <div class="row">
-
-          </div>
-          {% endblock %}
-        </div>
-      </div>
-    
-  </body>
-</html>
\ No newline at end of file
diff --git a/old_templates/result.html b/old_templates/result.html
deleted file mode 100644
index 4eca0b216866b09a2b48b0a1a69ee675874040e7..0000000000000000000000000000000000000000
--- a/old_templates/result.html
+++ /dev/null
@@ -1,24 +0,0 @@
-{% extends "basic.html" %}
-{% block title %}Result{% endblock %}
-{% block body %}
-Response to query:<br>
-
-<div>  
-  <form action="/choose_file">
-    <input type="radio" id="downloadTypeChoice1"
-      name="Filetype" value="html">
-    <label for="downloadChoice1">HTML</label><br>
-  
-    <input type="radio" id="downloadTypeChoice2"
-      name="Filetype" value="markdown">
-    <label for="downloadTypeChoice2">Markdown</label><br>
-  
-    <input type="radio" id="downloadTypeChoice3"
-    name="Filetype" value="pdf">
-    <label for="downloadTypeChoice3">PDF</label><br>
-  
-    <input type="submit">
-  </form>
-</div>
-
-{% endblock %}
diff --git a/plot_pickles.py b/plot_pickles.py
new file mode 100644
index 0000000000000000000000000000000000000000..90462cfa84744dab03058add1d47d270f17bcd56
--- /dev/null
+++ b/plot_pickles.py
@@ -0,0 +1,47 @@
+import pickle
+from pprint import pprint
+import matplotlib.pyplot as plt
+from torch import gru
+
+def get_maxacc(log):
+    output = []
+    maxacc = 0
+    for policy, acc in log:
+        maxacc = max(maxacc, acc)
+        output.append(maxacc)
+    return output
+
+with open('randomsearch_logs.pkl', 'rb') as file:
+    rs_list = pickle.load(file)
+
+with open('gru_logs.pkl', 'rb') as file:
+    gru_list = pickle.load(file)
+
+
+plt.plot(get_maxacc(rs_list), label='randomsearcher')
+plt.plot(get_maxacc(gru_list), label='gru learner')
+plt.title('Comparing two agents')
+plt.ylabel('best accuracy to date')
+plt.xlabel('number of policies tested')
+plt.legend()
+plt.show()
+
+plt.plot([acc for pol,acc in rs_list], label='randomsearcher')
+plt.plot([acc for pol,acc in gru_list], label='gru learner')
+plt.title('Comparing two agents')
+plt.ylabel('best accuracy to date')
+plt.xlabel('number of policies tested')
+plt.legend()
+plt.show()
+
+
+def get_best5(log):
+    l = sorted(log, reverse=True, key=lambda x:x[1])
+    return (l[:5])
+
+def get_worst5(log):
+    l = sorted(log, key=lambda x:x[1])
+    return (l[:5])
+
+pprint(get_best5(rs_list))
+pprint(get_best5(gru_list))
\ No newline at end of file
diff --git a/progress.html b/progress.html
deleted file mode 100644
index 9bad71aadb77c26e3887dfddaa00e38e25d0336a..0000000000000000000000000000000000000000
--- a/progress.html
+++ /dev/null
@@ -1,8 +0,0 @@
-{% extends "structure.html" %}
-{% block title%}Home{% endblock %}
-{% block body %}
-<h1>Loading</h1>
-      <progress value = "65" max = "100"/>
-
-{% endblock %}
-
diff --git a/randomsearch_logs.pkl b/randomsearch_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..5673b3dea80e0e57e55629ef67f264c27888ff00
Binary files /dev/null and b/randomsearch_logs.pkl differ
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 76b3ed54867825d4aadb3d219c376ee10576d1f3..0000000000000000000000000000000000000000
--- a/requirements.txt
+++ /dev/null
@@ -1,18 +0,0 @@
-attrs
-click==8.0.3
-Flask==2.0.2
-iniconfig==1.1.1
-itsdangerous==2.0.1
-Jinja2==3.0.3
-MarkupSafe==2.0.1
-packaging==21.3
-pandoc==2.0.1
-pdflatex
-pluggy==1.0.0
-py==1.11.0
-pyparsing==3.0.6
-pytest
-python-dotenv==0.19.2
-toml==0.10.2
-Werkzeug==2.0.2
-weasyprint==51
\ No newline at end of file
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 606849326a4002007fd42060b51e69a19c18675c..0000000000000000000000000000000000000000
--- a/setup.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from setuptools import setup
-
-setup()
diff --git a/stdout.txt b/stdout.txt
deleted file mode 100755
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000