diff --git a/.DS_Store b/.DS_Store
index 87b56ad1c0caa0cd8b0aa4497cbd4d095b75bc27..720cf3ab50cbd4bb4f33acbbc3cb3516e7778732 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/MetaAugment/Baseline_JC.ipynb b/MetaAugment/Baseline_JC.ipynb
index d0ab8ea0710b9cf9a0cdb2629e80a8036b014d47..d979dc8a67b4c0232a21967b43e340f90b08a844 100644
--- a/MetaAugment/Baseline_JC.ipynb
+++ b/MetaAugment/Baseline_JC.ipynb
@@ -171,6 +171,183 @@
     },
     {
       "cell_type": "code",
+      "execution_count": 5,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "KVhYheLfBP33",
+        "outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "9913344it [00:04, 2462502.04it/s]                             \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "29696it [00:00, 3785722.37it/s]          \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "1649664it [00:00, 3348476.95it/s]                             \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "5120it [00:00, 2935726.11it/s]          \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "9913344it [00:04, 2338660.11it/s]                             \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "29696it [00:00, 33554432.00it/s]         "
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "\n",
+            "1649664it [00:00, 2786152.46it/s]                             \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
+            "\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
+            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
+          ]
+        },
+        {
+          "name": "stderr",
+          "output_type": "stream",
+          "text": [
+            "5120it [00:00, 4789214.20it/s]          \n"
+          ]
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "Extracting ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
+            "\n",
+            "0\tBest accuracy: 18.00%\n",
+            "10\tBest accuracy: 75.50%\n",
+            "20\tBest accuracy: 78.00%\n",
+            "30\tBest accuracy: 95.00%\n",
+            "40\tBest accuracy: 95.50%\n",
+            "50\tBest accuracy: 94.00%\n",
+            "60\tBest accuracy: 85.00%\n",
+            "70\tBest accuracy: 85.50%\n",
+            "80\tBest accuracy: 62.50%\n",
+            "90\tBest accuracy: 76.00%\n",
+            "Average best accuracy: 79.86%\n",
+            "\n",
+            "0\tAverage accuracy: 93.50%\n",
+            "10\tAverage accuracy: 93.45%\n",
+            "20\tAverage accuracy: 46.95%\n",
+            "30\tAverage accuracy: 71.41%\n",
+            "40\tAverage accuracy: 73.68%\n",
+            "50\tAverage accuracy: 64.50%\n",
+            "60\tAverage accuracy: 72.50%\n",
+            "70\tAverage accuracy: 94.36%\n",
+            "80\tAverage accuracy: 84.77%\n",
+            "90\tAverage accuracy: 92.14%\n",
+            "Average average accuracy: 80.92%\n",
+            "\n"
+          ]
+        }
+      ],
       "source": [
         "batch_size = 32               # size of batch the inner NN is trained with\n",
         "toy_size = 0.02               # total propeortion of training and test set we use\n",
@@ -198,47 +375,14 @@
         "    if baselines % 10 == 0:\n",
         "        print(\"{}\\tAverage accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
         "print(\"Average average accuracy: {:.2f}%\\n\".format(np.mean(best_accuracies)*100))"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "KVhYheLfBP33",
-        "outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f"
-      },
-      "execution_count": 5,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "0\tBest accuracy: 49.00%\n",
-            "10\tBest accuracy: 86.50%\n",
-            "20\tBest accuracy: 95.00%\n",
-            "30\tBest accuracy: 54.00%\n",
-            "40\tBest accuracy: 94.00%\n",
-            "50\tBest accuracy: 93.50%\n",
-            "60\tBest accuracy: 66.50%\n",
-            "70\tBest accuracy: 94.50%\n",
-            "80\tBest accuracy: 74.50%\n",
-            "90\tBest accuracy: 74.00%\n",
-            "Average best accuracy: 79.58%\n",
-            "\n",
-            "0\tAverage accuracy: 68.95%\n",
-            "10\tAverage accuracy: 69.95%\n",
-            "20\tAverage accuracy: 85.00%\n",
-            "30\tAverage accuracy: 93.32%\n",
-            "40\tAverage accuracy: 68.00%\n",
-            "50\tAverage accuracy: 85.36%\n",
-            "60\tAverage accuracy: 92.36%\n",
-            "70\tAverage accuracy: 56.95%\n",
-            "80\tAverage accuracy: 93.59%\n",
-            "90\tAverage accuracy: 64.91%\n",
-            "Average average accuracy: 78.90%\n",
-            "\n"
-          ]
-        }
       ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {},
+      "outputs": [],
+      "source": []
     }
   ],
   "metadata": {
@@ -262,9 +406,9 @@
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
-      "version": "3.7.7"
+      "version": "3.9.7"
     }
   },
   "nbformat": 4,
   "nbformat_minor": 0
-}
\ No newline at end of file
+}
diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index c1e91a97eed2634c29e78325556bf20e52e05ca9..792e81e1f85932408755840fbcbc09612137d39e 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -1,3 +1,4 @@
+from cgi import test
 import numpy as np
 import torch
 torch.manual_seed(0)
@@ -16,15 +17,24 @@ import copy
 from torchvision.transforms import functional as F, InterpolationMode
 from typing import List, Tuple, Optional, Dict
 import heapq
+import math
 
+import math
+import torch
+
+from enum import Enum
+from torch import Tensor
+from typing import List, Tuple, Optional, Dict
 
+from torchvision.transforms import functional as F, InterpolationMode
 
-# from MetaAugment.main import *
 # import MetaAugment.child_networks as child_networks
+# from main import *
+# from autoaugment_learners.autoaugment import *
 
 
-np.random.seed(0)
-random.seed(0)
+# np.random.seed(0)
+# random.seed(0)
 
 
 augmentation_space = [
@@ -172,9 +182,10 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 
 
+
 class Evolutionary_learner():
 
-    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
+    def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
         self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol)
         self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
         self.num_generations = num_generations
@@ -211,30 +222,6 @@ class Evolutionary_learner():
         
         return policies
 
-# Every image has specific operation. Policy for every image (2 (trans., prob., mag) output)
-
-
-# RNN -> change the end -/- leave for now, ask Javier
-
-
-# Use mini-batch with current output, get mode transformation -> mean probability and magnitude
-#   Pass through each image in mini-batch to get one/two (transformation, prob., mag.) tuples
-#   Average softmax probability (get softmax of the outputs, then average them to get the probability)
-
-
-# For every batch, store all outputs. Pick top operations
-# Every image -> output 2 operation tuples e.g. 14 trans + 1 prob + 1 mag. 32 output total. 
-#   14 neuron output is then prob. of transformations (softmax + average across dim = 0)
-#   1000x28 
-#   Problem 1: have 28, if we pick argmax top 2
-
-    # For each image have 28 dim output. Calculate covariance of 1000x28 using np.cov(28_dim_vector.T)
-    # Give 28x28 covariance matrix. Pick top k pairs (corresponds to largest covariance pairs)
-    #   Once we have pairs, go back to 1000x32 output. Find cases where the largest cov. pairs are used and use those probs and mags
-
-
-# Covariance matrix -> prob. of occurance (might be bad pairs)
-# Pair criteria -> highest softmax prob and probaility of occurence
 
     def get_full_policy(self, x):
         """
@@ -257,9 +244,9 @@ class Evolutionary_learner():
             full_policy.append(tuple(int_pol))
 
         return full_policy
-
+# 
     
-    def get_policy_cov(self, x):
+    def get_policy_cov(self, x, alpha = 0.5):
         """
         Need p_bins = 1, num_sub_pol = 1, mag_bins = 1
         """
@@ -268,48 +255,55 @@ class Evolutionary_learner():
         y = self.auto_aug_agent.forward(x) # 1000 x 32
 
         y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
+        y[:,:self.auto_aug_agent.fun_num] = y_1
         y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
+        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
         concat = torch.cat((y_1, y_2), dim = 1)
 
         cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
         cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
         shape_store = cov_mat.shape
 
+        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
+
+
+        prob_mat = torch.zeros(shape_store)
+        for idx in range(y.shape[0]):
+            prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1
+        prob_mat = prob_mat / torch.sum(prob_mat)
+
+        cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat)
+
         cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()
         max_idx = torch.argmax(cov_mat)
         val = (max_idx//shape_store[0])
         max_idx = (val, max_idx - (val * shape_store[0]))
 
-        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
 
-        if self.augmentation_space[max_idx[0]]:
+        if not self.augmentation_space[max_idx[0]][1]:
             mag1 = None
-        if self.augmentation_space[max_idx[1]]:
+        if not self.augmentation_space[max_idx[1]][1]:
             mag2 = None
-
+   
         for idx in range(y.shape[0]):
-            # print("torch.argmax(y_1[idx]): ", torch.argmax(y_1[idx]))
-            # print("torch.argmax(y_2[idx]): ", torch.argmax(y_2[idx]))
-            # print("max idx0: ", max_idx[0])
-            # print("max idx1: ", max_idx[1])
-
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += y[idx, self.auto_aug_agent.fun_num+1]
-                prob2 += y[idx, section+self.auto_aug_agent.fun_num+1]
+                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += y[idx, self.auto_aug_agent.fun_num+2]
+                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
                 if mag2 is not None:
-                    mag2 += y[idx, section+self.auto_aug_agent.fun_num+2]
+                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
                 counter += 1
-        
+
         prob1 = prob1/counter if counter != 0 else 0
         prob2 = prob2/counter if counter != 0 else 0
         if mag1 is not None:
             mag1 = mag1/counter 
         if mag2 is not None:
-            mag2 = mag2/counter            
+            mag2 = mag2/counter    
+
         
-        return [(self.augmentation_space[max_idx[0]], prob1, mag1), (self.augmentation_space[max_idx[1]], prob2, mag2)]
+        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
 
 
         
@@ -342,6 +336,7 @@ class Evolutionary_learner():
             """
             Defines fitness function (accuracy of the model)
             """
+            print("FITNESS HERE")
 
             model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
                                                             weights_vector=solution)
@@ -349,14 +344,13 @@ class Evolutionary_learner():
             self.auto_aug_agent.load_state_dict(model_weights_dict)
 
             for idx, (test_x, label_x) in enumerate(train_loader):
-                # full_policy = self.get_full_policy(test_x)
                 full_policy = self.get_policy_cov(test_x)
+            print("FULL POLICY: ", full_policy)
 
-            print("full_policy: ", full_policy)
-            cop_mod = self.new_model()
 
-            fit_val = test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]
-            cop_mod = 0
+            fit_val = (test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) #+ test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
+
+            print("DONE FITNESS")
 
             return fit_val
 
@@ -372,6 +366,7 @@ class Evolutionary_learner():
         self.ga_instance = pygad.GA(num_generations=self.num_generations, 
             num_parents_mating=self.num_parents_mating, 
             initial_population=self.initial_population,
+            mutation_percent_genes = 0.1,
             fitness_func=fitness_func,
             on_generation = on_generation)
 
@@ -381,14 +376,566 @@ class Evolutionary_learner():
 
 
 
+
+
+
+
+# HEREHEREHERE0
+
+def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
+    # shuffle and take first n_samples %age of training dataset
+    shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset))
+    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
+    
+    indices_train = torch.arange(int(n_samples*len(train_dataset)))
+    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
+    
+    # shuffle and take first n_samples %age of test dataset
+    shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
+    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
+
+    big = 4 # how much bigger is the test set
+
+    indices_test = torch.arange(int(n_samples*len(test_dataset)*big))
+    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
+
+    # push into DataLoader
+    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
+
+    return train_loader, test_loader
+
+
+def train_child_network(child_network, train_loader, test_loader, sgd,
+                         cost, max_epochs=2000, early_stop_num = 5, logging=False,
+                         print_every_epoch=True):
+    if torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    child_network = child_network.to(device=device)
+    
+    best_acc=0
+    early_stop_cnt = 0
+    
+    # logging accuracy for plotting
+    acc_log = [] 
+
+    # train child_network and check validation accuracy each epoch
+    for _epoch in range(max_epochs):
+
+        # train child_network
+        child_network.train()
+        for idx, (train_x, train_label) in enumerate(train_loader):
+            # onto device
+            train_x = train_x.to(device=device, dtype=train_x.dtype)
+            train_label = train_label.to(device=device, dtype=train_label.dtype)
+
+            # label_np = np.zeros((train_label.shape[0], 10))
+
+            sgd.zero_grad()
+            predict_y = child_network(train_x.float())
+            loss = cost(predict_y, train_label.long())
+            loss.backward()
+            sgd.step()
+
+        # check validation accuracy on validation set
+        correct = 0
+        _sum = 0
+        child_network.eval()
+        with torch.no_grad():
+            for idx, (test_x, test_label) in enumerate(test_loader):
+                # onto device
+                test_x = test_x.to(device=device, dtype=test_x.dtype)
+                test_label = test_label.to(device=device, dtype=test_label.dtype)
+
+                predict_y = child_network(test_x.float()).detach()
+                predict_ys = torch.argmax(predict_y, axis=-1)
+
+                # label_np = test_label.numpy()
+
+                _ = predict_ys == test_label
+                correct += torch.sum(_, axis=-1)
+                # correct += torch.sum(_.numpy(), axis=-1)
+                _sum += _.shape[0]
+        
+        # update best validation accuracy if it was higher, otherwise increase early stop count
+        acc = correct / _sum
+
+        if acc > best_acc :
+            best_acc = acc
+            early_stop_cnt = 0
+        else:
+            early_stop_cnt += 1
+
+        # exit if validation gets worse over 10 runs
+        if early_stop_cnt >= early_stop_num:
+            print('main.train_child_network best accuracy: ', best_acc)
+            break
+        
+        # if print_every_epoch:
+            # print('main.train_child_network best accuracy: ', best_acc)
+        acc_log.append(acc)
+
+    if logging:
+        return best_acc.item(), acc_log
+    return best_acc.item()
+
+def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
+
+    aa_transform = AutoAugment()
+    aa_transform.subpolicies = subpolicies
+
+    train_transform = transforms.Compose([
+                                            aa_transform,
+                                            transforms.ToTensor()
+                                        ])
+
+    train_dataset.transform = train_transform
+
+    # create toy dataset from above uploaded data
+    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
+
+    child_network = LeNet()
+    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+    cost = nn.CrossEntropyLoss()
+
+    best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
+                                                sgd, cost, max_epochs=100, logging=True)
+
+    return best_acc, acc_log
+
+
+
+__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
+
+
+def _apply_op(img: Tensor, op_name: str, magnitude: float,
+              interpolation: InterpolationMode, fill: Optional[List[float]]):
+    if op_name == "ShearX":
+        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
+                       interpolation=interpolation, fill=fill)
+    elif op_name == "ShearY":
+        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
+                       interpolation=interpolation, fill=fill)
+    elif op_name == "TranslateX":
+        img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
+                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+    elif op_name == "TranslateY":
+        img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
+                       interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
+    elif op_name == "Rotate":
+        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
+    elif op_name == "Brightness":
+        img = F.adjust_brightness(img, 1.0 + magnitude)
+    elif op_name == "Color":
+        img = F.adjust_saturation(img, 1.0 + magnitude)
+    elif op_name == "Contrast":
+        img = F.adjust_contrast(img, 1.0 + magnitude)
+    elif op_name == "Sharpness":
+        img = F.adjust_sharpness(img, 1.0 + magnitude)
+    elif op_name == "Posterize":
+        img = F.posterize(img, int(magnitude))
+    elif op_name == "Solarize":
+        img = F.solarize(img, magnitude)
+    elif op_name == "AutoContrast":
+        img = F.autocontrast(img)
+    elif op_name == "Equalize":
+        img = F.equalize(img)
+    elif op_name == "Invert":
+        img = F.invert(img)
+    elif op_name == "Identity":
+        pass
+    else:
+        raise ValueError("The provided operator {} is not recognized.".format(op_name))
+    return img
+
+
+class AutoAugmentPolicy(Enum):
+    """AutoAugment policies learned on different datasets.
+    Available policies are IMAGENET, CIFAR10 and SVHN.
+    """
+    IMAGENET = "imagenet"
+    CIFAR10 = "cifar10"
+    SVHN = "svhn"
+
+
+# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
+class AutoAugment(torch.nn.Module):
+    r"""AutoAugment data augmentation method based on
+    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        policy (AutoAugmentPolicy): Desired policy enum defined by
+            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
+
+    def __init__(
+        self,
+        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Optional[List[float]] = None
+    ) -> None:
+        super().__init__()
+        self.policy = policy
+        self.interpolation = interpolation
+        self.fill = fill
+        self.subpolicies = self._get_subpolicies(policy)
+
+    def _get_subpolicies(
+        self,
+        policy: AutoAugmentPolicy
+    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
+        if policy == AutoAugmentPolicy.IMAGENET:
+            return [
+                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
+                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
+                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
+                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
+                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
+                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
+                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
+                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
+                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
+                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
+                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
+                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
+                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
+                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
+                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
+                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
+                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
+                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
+                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
+                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
+                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
+                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
+                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
+                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
+                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
+            ]
+        elif policy == AutoAugmentPolicy.CIFAR10:
+            return [
+                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
+                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
+                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
+                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
+                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
+                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
+                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
+                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
+                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
+                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
+                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
+                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
+                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
+                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
+                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
+                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
+                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
+                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
+                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
+                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
+                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
+                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
+                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
+                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
+            ]
+        elif policy == AutoAugmentPolicy.SVHN:
+            return [
+                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
+                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
+                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
+                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
+                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
+                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
+                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
+                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
+                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
+                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
+                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
+                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
+                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
+                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
+                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
+                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
+                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
+                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
+                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
+                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
+                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
+                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
+                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
+                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
+                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
+            ]
+        else:
+            raise ValueError("The provided policy {} is not recognized.".format(policy))
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+            "Invert": (torch.tensor(0.0), False),
+        }
+
+    @staticmethod
+    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
+        """Get parameters for autoaugment transformation
+
+        Returns:
+            params required by the autoaugment transformation
+        """
+        policy_id = int(torch.randint(transform_num, (1,)).item())
+        probs = torch.rand((2,))
+        signs = torch.randint(2, (2,))
+
+        return policy_id, probs, signs
+
+    def forward(self, img: Tensor, dis_mag = True) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: AutoAugmented image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        transform_id, probs, signs = self.get_params(len(self.subpolicies))
+        # print("transform_id, probs, signs : ", transform_id, probs, signs )
+
+        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies[transform_id]):
+        # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies):
+        #     print("op_name, p, magnitude_id: ", op_name, p, magnitude_id)
+        #     if probs[i] <= p:
+        #         op_meta = self._augmentation_space(10, F.get_image_size(img))
+        #         magnitudes, signed = op_meta[op_name]
+        #         magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
+        #         if signed and signs[i] == 0:
+        #             magnitude *= -1.0
+        #         img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        for i, (op_name, p, magnitude) in enumerate(self.subpolicies):
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+
+        return img
+
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
+
+
+class RandAugment(torch.nn.Module):
+    r"""RandAugment data augmentation method based on
+    `"RandAugment: Practical automated data augmentation with a reduced search space"
+    <https://arxiv.org/abs/1909.13719>`_.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_ops (int): Number of augmentation transformations to apply sequentially.
+        magnitude (int): Magnitude for all the transformations.
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+        """
+
+    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
+                 interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_ops = num_ops
+        self.magnitude = magnitude
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "Identity": (torch.tensor(0.0), False),
+            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        for _ in range(self.num_ops):
+            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
+            op_index = int(torch.randint(len(op_meta), (1,)).item())
+            op_name = list(op_meta.keys())[op_index]
+            magnitudes, signed = op_meta[op_name]
+            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
+            if signed and torch.randint(2, (1,)):
+                magnitude *= -1.0
+            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+        return img
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_ops={num_ops}'
+        s += ', magnitude={magnitude}'
+        s += ', num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class TrivialAugmentWide(torch.nn.Module):
+    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
+    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
+    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        num_magnitude_bins (int): The number of different magnitude values.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+        """
+
+    def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
+                 fill: Optional[List[float]] = None) -> None:
+        super().__init__()
+        self.num_magnitude_bins = num_magnitude_bins
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
+        return {
+            # op_name: (magnitudes, signed)
+            "Identity": (torch.tensor(0.0), False),
+            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
+            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
+            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
+            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
+            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
+            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
+            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
+            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
+            "AutoContrast": (torch.tensor(0.0), False),
+            "Equalize": (torch.tensor(0.0), False),
+        }
+
+    def forward(self, img: Tensor) -> Tensor:
+        """
+            img (PIL Image or Tensor): Image to be transformed.
+
+        Returns:
+            PIL Image or Tensor: Transformed image.
+        """
+        fill = self.fill
+        if isinstance(img, Tensor):
+            if isinstance(fill, (int, float)):
+                fill = [float(fill)] * F.get_image_num_channels(img)
+            elif fill is not None:
+                fill = [float(f) for f in fill]
+
+        op_meta = self._augmentation_space(self.num_magnitude_bins)
+        op_index = int(torch.randint(len(op_meta), (1,)).item())
+        op_name = list(op_meta.keys())[op_index]
+        magnitudes, signed = op_meta[op_name]
+        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
+            if magnitudes.ndim > 0 else 0.0
+        if signed and torch.randint(2, (1,)):
+            magnitude *= -1.0
+
+        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_magnitude_bins={num_magnitude_bins}'
+        s += ', interpolation={interpolation}'
+        s += ', fill={fill}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+# HEREHEREHEREHERE1
+
+
+
+
+
+
+
+
+train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
+                            transform=None)
+test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
+                            transform=torchvision.transforms.ToTensor())
+
+
 auto_aug_agent = Learner()
-ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1)
+ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1, train_dataset=train_dataset, test_dataset=test_dataset)
 ev_learner.run_instance()
 
 
 solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution()
-print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
-print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))
+
+print(f"Best solution : {solution}")
+print(f"Fitness value of the best solution = {solution_fitness}")
+print(f"Index of the best solution : {solution_idx}")
 # Fetch the parameters of the best solution.
 best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent,
                                                       weights_vector=solution)
\ No newline at end of file
diff --git a/MetaAugment/GA_results.png b/MetaAugment/GA_results.png
new file mode 100644
index 0000000000000000000000000000000000000000..62449415b64500804927328ca677c4c023085436
Binary files /dev/null and b/MetaAugment/GA_results.png differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480
Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480
Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/MetaAugment/genetic_learner_results.py b/MetaAugment/genetic_learner_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d9de8df2e17748b34e6879d4a3ae75dca9d9fb
--- /dev/null
+++ b/MetaAugment/genetic_learner_results.py
@@ -0,0 +1,109 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+# Fixed seed (same as benchmark)
+
+# Looking at last generation can make out general trends of which transformations lead to the largest accuracies
+
+
+gen_1_acc = [0.1998, 0.1405, 0.1678, 0.9690, 0.9672, 0.9540, 0.9047, 0.9730, 0.2060, 0.9260, 0.8035, 0.9715, 0.9737, 0.14, 0.9645]
+
+gen_2_acc = [0.9218, 0.9753, 0.9758, 0.1088, 0.9710, 0.1655, 0.9735, 0.9655, 0.9740, 0.9377]
+
+gen_3_acc = [0.1445, 0.9740, 0.9643, 0.9750, 0.9492, 0.9693, 0.1262, 0.9660, 0.9760, 0.9697]
+
+gen_4_acc = [0.9697, 0.1238, 0.9613, 0.9737, 0.9603, 0.8620, 0.9712, 0.9617, 0.9737, 0.1855]
+
+gen_5_acc = [0.6445, 0.9705, 0.9668, 0.9765, 0.1142, 0.9780, 0.9700, 0.2120, 0.9555, 0.9732]
+
+gen_6_acc = [0.9710, 0.9665, 0.2077, 0.9535, 0.9765, 0.9712, 0.9697, 0.2145, 0.9523, 0.9718, 0.9718, 0.9718, 0.2180, 0.9622, 0.9785]
+
+gen_acc = [gen_1_acc, gen_2_acc, gen_3_acc, gen_4_acc, gen_5_acc, gen_6_acc]
+
+gen_acc_means = []
+gen_acc_stds = []
+
+for val in gen_acc:
+    gen_acc_means.append(np.mean(val))
+    gen_acc_stds.append(np.std(val))
+
+
+
+# Vary seed
+
+gen_1_vary = [0.1998, 0.9707, 0.9715, 0.9657, 0.8347, 0.9655, 0.1870, 0.0983, 0.3750, 0.9765, 0.9712, 0.9705, 0.9635, 0.9718, 0.1170]
+
+gen_2_vary = [0.9758, 0.9607, 0.9597, 0.9753, 0.1165, 0.1503, 0.9747, 0.1725, 0.9645, 0.2290]
+
+gen_3_vary = [0.1357, 0.9725, 0.1708, 0.9607, 0.2132, 0.9730, 0.9743, 0.9690, 0.0850, 0.9755]
+
+gen_4_vary = [0.9722, 0.9760, 0.9697, 0.1155, 0.9715, 0.9688, 0.1785, 0.9745, 0.2362, 0.9765]
+
+gen_5_vary = [0.9705, 0.2280, 0.9745, 0.1875, 0.9735, 0.9735, 0.9720, 0.9678, 0.9770, 0.1155]
+
+gen_6_vary = [0.9685, 0.9730, 0.9735, 0.9760, 0.1495, 0.9707, 0.9700, 0.9747, 0.9750, 0.1155, 0.9732, 0.9745, 0.9758, 0.9768, 0.1155]
+
+gen_vary = [gen_1_vary, gen_2_vary, gen_3_vary, gen_4_vary, gen_5_vary, gen_6_vary]
+
+gen_vary_means = []
+gen_vary_stds = []
+
+for val in gen_vary:
+    gen_vary_means.append(np.mean(val))
+    gen_vary_stds.append(np.std(val))
+
+
+
+
+
+# Multiple runs 
+
+gen_1_mult = [0.1762, 0.9575, 0.1200, 0.9660, 0.9650, 0.9570, 0.9745, 0.9700, 0.15, 0.23, 0.16, 0.186, 0.9640, 0.9650]
+
+gen_2_mult = [0.17, 0.1515, 0.1700, 0.9625, 0.9630, 0.9732, 0.9680, 0.9633, 0.9530, 0.9640]
+
+gen_3_mult = [0.9750, 0.9720, 0.9655, 0.9530, 0.9623, 0.9730, 0.9748, 0.9625, 0.9716, 0.9672]
+
+gen_4_mult = [0.9724, 0.9755, 0.9657, 0.9718, 0.9690, 0.9735, 0.9715, 0.9300, 0.9725, 0.9695]
+
+gen_5_mult = [0.9560, 0.9750, 0.8750, 0.9717, 0.9731, 0.9741, 0.9747, 0.9726, 0.9729, 0.9727]
+
+gen_6_mult = [0.9730, 0.9740, 0.9715, 0.9755, 0.9761, 0.9700, 0.9755, 0.9750, 0.9726, 0.9748, 0.9705, 0.9745, 0.9752, 0.9740, 0.9744]
+
+
+
+gen_mult = [gen_1_mult, gen_2_mult, gen_3_mult,  gen_4_mult, gen_5_mult, gen_6_mult]
+
+gen_mult_means = []
+gen_mult_stds = []
+
+for val in gen_mult:
+    gen_mult_means.append(np.mean(val))
+    gen_mult_stds.append(np.std(val))
+
+num_gen = [i for i in range(len(gen_mult))]
+
+
+# Baseline
+baseline = [0.7990 for i in range(len(gen_mult))]
+
+
+
+# plt.errorbar(num_gen, gen_acc_means, yerr=gen_acc_stds, linestyle = 'dotted', label = 'Fixed seed GA')
+# plt.errorbar(num_gen, gen_vary_means, linestyle = 'dotted', yerr=gen_vary_stds, label = 'Varying seed GA')
+# plt.errorbar(num_gen, gen_mult_means, linestyle = 'dotted', yerr=gen_mult_stds, label = 'Varying seed GA 2')
+
+plt.plot(num_gen, gen_acc_means, linestyle = 'dotted', label = 'Fixed seed GA')
+plt.plot(num_gen, gen_vary_means, linestyle = 'dotted',  label = 'Varying seed GA')
+plt.plot(num_gen, gen_mult_means, linestyle = 'dotted', label = 'Varying seed GA 2')
+
+plt.plot(num_gen, baseline, label = 'Fixed seed baseline')
+
+
+plt.xlabel('Generation', fontsize = 16)
+plt.ylabel('Validation Accuracy', fontsize = 16)
+
+plt.legend()
+
+plt.savefig('GA_results.png')
\ No newline at end of file