From c68dc525cd9d94ee6003b41018d6d50d1d420636 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Tue, 19 Apr 2022 03:03:29 -0700
Subject: [PATCH] Move evo learner into autoaugment_learners as an aa_learner subclass

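Moves the evolutionary AutoAugment learner out of the CP2_Max prototype script:
the GA-driven learner now lives in MetaAugment/autoaugment_learners/evo_learner.py
(as a subclass of aa_learner) and its controller network in
MetaAugment/controller_networks/evo_controller.py.

Rough usage sketch (the LeNet child network and the MNIST setup below are
assumptions carried over from the old prototype script, not part of this patch):

    import torchvision
    import torchvision.datasets as datasets

    import MetaAugment.child_networks as cn
    from MetaAugment.autoaugment_learners.evo_learner import Evolutionary_learner

    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True,
                                   download=True, transform=None)
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False,
                                  download=True,
                                  transform=torchvision.transforms.ToTensor())

    learner = Evolutionary_learner(sp_num=1, num_solutions=10, num_parents_mating=5,
                                   child_network=cn.LeNet(), toy_flag=True)
    learner.set_up_instance(train_dataset, test_dataset)
    solution, fitness, _ = learner.run_instance()
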
---
 MetaAugment/CP2_Max.py                        | 947 ------------------
 .../autoaugment_learners/evo_learner.py       | 283 ++++++
 .../controller_networks/evo_controller.py     |  39 +
 .../controller_networks/rnn_controller.py     |   7 -
 backend/.DS_Store                             | Bin 0 -> 6148 bytes
 5 files changed, 322 insertions(+), 954 deletions(-)
 delete mode 100644 MetaAugment/CP2_Max.py
 create mode 100644 MetaAugment/autoaugment_learners/evo_learner.py
 create mode 100644 MetaAugment/controller_networks/evo_controller.py
 create mode 100644 backend/.DS_Store

diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
deleted file mode 100644
index a13fd207..00000000
--- a/MetaAugment/CP2_Max.py
+++ /dev/null
@@ -1,947 +0,0 @@
-from cgi import test
-import numpy as np
-import torch
-torch.manual_seed(0)
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torchvision
-import torchvision.datasets as datasets
-from torchvision import transforms
-import torchvision.transforms.autoaugment as autoaugment
-import random
-import pygad
-import pygad.torchga as torchga
-import random
-import copy
-from torchvision.transforms import functional as F, InterpolationMode
-from typing import List, Tuple, Optional, Dict
-import heapq
-import math
-
-import math
-import torch
-
-from enum import Enum
-from torch import Tensor
-from typing import List, Tuple, Optional, Dict
-
-from torchvision.transforms import functional as F, InterpolationMode
-
-# import MetaAugment.child_networks as child_networks
-# from main import *
-# from autoaugment_learners.autoaugment import *
-
-
-# np.random.seed(0)
-# random.seed(0)
-
-
-augmentation_space = [
-            # (function_name, do_we_need_to_specify_magnitude)
-            ("ShearX", True),
-            ("ShearY", True),
-            ("TranslateX", True),
-            ("TranslateY", True),
-            ("Rotate", True),
-            ("Brightness", True),
-            ("Color", True),
-            ("Contrast", True),
-            ("Sharpness", True),
-            ("Posterize", True),
-            ("Solarize", True),
-            ("AutoContrast", False),
-            ("Equalize", False),
-            ("Invert", False),
-        ]
-
-class Learner(nn.Module):
-    def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
-        self.fun_num = fun_num
-        self.p_bins = p_bins 
-        self.m_bins = m_bins 
-        self.sub_num_pol = sub_num_pol
-
-        super().__init__()
-        self.conv1 = nn.Conv2d(1, 6, 5)
-        self.relu1 = nn.ReLU()
-        self.pool1 = nn.MaxPool2d(2)
-        self.conv2 = nn.Conv2d(6, 16, 5)
-        self.relu2 = nn.ReLU()
-        self.pool2 = nn.MaxPool2d(2)
-        self.fc1 = nn.Linear(256, 120)
-        self.relu3 = nn.ReLU()
-        self.fc2 = nn.Linear(120, 84)
-        self.relu4 = nn.ReLU()
-        self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins))
-
-# Currently using discrete outputs for the probabilities 
-
-    def forward(self, x):
-        y = self.conv1(x)
-        y = self.relu1(y)
-        y = self.pool1(y)
-        y = self.conv2(y)
-        y = self.relu2(y)
-        y = self.pool2(y)
-        y = y.view(y.shape[0], -1)
-        y = self.fc1(y)
-        y = self.relu3(y)
-        y = self.fc2(y)
-        y = self.relu4(y)
-        y = self.fc3(y)
-
-        return y
-
-
-# class LeNet(nn.Module):
-#     def __init__(self):
-#         super().__init__()
-#         self.conv1 = nn.Conv2d(1, 6, 5)
-#         self.relu1 = nn.ReLU()
-#         self.pool1 = nn.MaxPool2d(2)
-#         self.conv2 = nn.Conv2d(6, 16, 5)
-#         self.relu2 = nn.ReLU()
-#         self.pool2 = nn.MaxPool2d(2)
-#         self.fc1 = nn.Linear(256, 120)
-#         self.relu3 = nn.ReLU()
-#         self.fc2 = nn.Linear(120, 84)
-#         self.relu4 = nn.ReLU()
-#         self.fc3 = nn.Linear(84, 10)
-#         self.relu5 = nn.ReLU()
-
-#     def forward(self, x):
-#         y = self.conv1(x)
-#         y = self.relu1(y)
-#         y = self.pool1(y)
-#         y = self.conv2(y)
-#         y = self.relu2(y)
-#         y = self.pool2(y)
-#         y = y.view(y.shape[0], -1)
-#         y = self.fc1(y)
-#         y = self.relu3(y)
-#         y = self.fc2(y)
-#         y = self.relu4(y)
-#         y = self.fc3(y)
-#         return y
-
-
-class LeNet(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.fc1 = nn.Linear(784, 2048)
-        self.relu1 = nn.ReLU()
-        self.fc2 = nn.Linear(2048, 10)
-        self.relu2 = nn.ReLU()
-
-    def forward(self, x):
-        x = x.reshape((-1, 784))
-        y = self.fc1(x)
-        y = self.relu1(y)
-        y = self.fc2(y)
-        y = self.relu2(y)
-        return y
-
-
-
-# ORGANISING DATA
-
-# transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation']
-train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
-test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
-n_samples = 0.02
-# shuffle and take first n_samples  %age of training dataset
-shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
-indices_train = torch.arange(int(n_samples*len(train_dataset)))
-reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
-# shuffle and take first n_samples %age of test dataset
-shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())
-indices_test = torch.arange(int(n_samples*len(test_dataset)))
-reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
-
-train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=60000)
-
-
-
-
-
-class Evolutionary_learner():
-
-    def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
-        self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol)
-        self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
-        self.num_generations = num_generations
-        self.num_parents_mating = num_parents_mating
-        self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
-        self.p_bins = p_bins 
-        self.sub_num_pol = sub_num_pol
-        self.mag_bins = mag_bins
-        self.fun_num = fun_num
-        self.augmentation_space = augmentation_space
-        self.train_dataset = train_dataset
-        self.test_dataset = test_dataset
-
-        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
-
-        self.set_up_instance()
-
-
-    def get_full_policy(self, x):
-        """
-        Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
-        output size 5 * 2 * (self.fun_num + self.p_bins + self.mag_bins)
-
-        Parameters 
-        -----------
-        x -> PyTorch tensor
-            Input data for network 
-
-        Returns
-        ----------
-        full_policy -> [((String, float, float), (String, float, float)), ...)
-            Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
-            two transformations, with a probability and magnitude float for each
-        """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
-        full_policy = []
-        for pol in range(self.sub_num_pol):
-            int_pol = []
-            for _ in range(2):
-                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
-
-                trans, need_mag = self.augmentation_space[idx_ret]
-
-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
-                int_pol.append((trans, p_ret, mag))
-
-            full_policy.append(tuple(int_pol))
-
-        return full_policy
-
-    
-    def get_policy_cov(self, x, alpha = 0.5):
-        """
-        Selects policy using population and covariance matrices. For this method 
-        we require p_bins = 1, num_sub_pol = 1, mag_bins = 1. 
-
-        Parameters
-        ------------
-        x -> PyTorch Tensor
-            Input data for the AutoAugment network 
-
-        alpha -> Float
-            Proportion for covariance and population matrices 
-
-        Returns
-        -----------
-        Subpolicy -> [(String, float, float), (String, float, float)]
-            Subpolicy consisting of two tuples of policies, each with a string associated 
-            to a transformation, a float for a probability, and a float for a magnittude
-        """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-
-        y = self.auto_aug_agent.forward(x) # 1000 x 32
-
-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
-        concat = torch.cat((y_1, y_2), dim = 1)
-
-        cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
-        shape_store = cov_mat.shape
-
-        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
-
-
-        prob_mat = torch.zeros(shape_store)
-        for idx in range(y.shape[0]):
-            prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1
-        prob_mat = prob_mat / torch.sum(prob_mat)
-
-        cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat)
-
-        cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()
-        max_idx = torch.argmax(cov_mat)
-        val = (max_idx//shape_store[0])
-        max_idx = (val, max_idx - (val * shape_store[0]))
-
-
-        if not self.augmentation_space[max_idx[0]][1]:
-            mag1 = None
-        if not self.augmentation_space[max_idx[1]][1]:
-            mag2 = None
-    
-        for idx in range(y.shape[0]):
-            if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
-                if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
-                if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
-                counter += 1
-
-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
-        if mag1 is not None:
-            mag1 = mag1/counter 
-        if mag2 is not None:
-            mag2 = mag2/counter    
-
-        
-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
-
-
-        
-
-
-
-    def run_instance(self, return_weights = False):
-        """
-        Runs the GA instance and returns the model weights as a dictionary
-
-        Parameters
-        ------------
-        return_weights -> Bool
-            Determines if the weight of the GA network should be returned 
-        
-        Returns
-        ------------
-        If return_weights:
-            Network weights -> Dictionary
-        
-        Else:
-            Solution -> Best GA instance solution
-
-            Solution fitness -> Float
-
-            Solution_idx -> Int
-        """
-        self.ga_instance.run()
-        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
-        if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
-        else:
-            return solution, solution_fitness, solution_idx
-
-
-    def new_model(self):
-        """
-        Simple function to create a copy of the secondary model (used for classification)
-        """
-        copy_model = copy.deepcopy(self.child_network)
-        return copy_model
-
-
-    def set_up_instance(self):
-        """
-        Initialises GA instance, as well as fitness and on_generation functions
-        
-        """
-
-        def fitness_func(solution, sol_idx):
-            """
-            Defines the fitness function for the parent selection
-
-            Parameters
-            --------------
-            solution -> GA solution instance (parsed automatically)
-
-            sol_idx -> GA solution index (parsed automatically)
-
-            Returns 
-            --------------
-            fit_val -> float            
-            """
-
-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
-                                                            weights_vector=solution)
-
-            self.auto_aug_agent.load_state_dict(model_weights_dict)
-
-            for idx, (test_x, label_x) in enumerate(train_loader):
-                full_policy = self.get_policy_cov(test_x)
-
-            fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0])/
-                        + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
-
-            return fit_val
-
-        def on_generation(ga_instance):
-            """
-            Prints information of generational fitness
-
-            Parameters 
-            -------------
-            ga_instance -> GA instance
-
-            Returns
-            -------------
-            None
-            """
-            print("Generation = {generation}".format(generation=ga_instance.generations_completed))
-            print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
-            return
-
-
-        self.ga_instance = pygad.GA(num_generations=self.num_generations, 
-            num_parents_mating=self.num_parents_mating, 
-            initial_population=self.initial_population,
-            mutation_percent_genes = 0.1,
-            fitness_func=fitness_func,
-            on_generation = on_generation)
-
-
-
-
-# HEREHEREHERE0
-
-def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
-    # shuffle and take first n_samples %age of training dataset
-    shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset))
-    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
-    
-    indices_train = torch.arange(int(n_samples*len(train_dataset)))
-    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
-    
-    # shuffle and take first n_samples %age of test dataset
-    shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
-    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
-
-    big = 4 # how much bigger is the test set
-
-    indices_test = torch.arange(int(n_samples*len(test_dataset)*big))
-    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
-
-    # push into DataLoader
-    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
-    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
-
-    return train_loader, test_loader
-
-
-def train_child_network(child_network, train_loader, test_loader, sgd,
-                            cost, max_epochs=2000, early_stop_num = 5, logging=False,
-                            print_every_epoch=True):
-    if torch.cuda.is_available():
-        device = torch.device('cuda')
-    else:
-        device = torch.device('cpu')
-    child_network = child_network.to(device=device)
-    
-    best_acc=0
-    early_stop_cnt = 0
-    
-    # logging accuracy for plotting
-    acc_log = [] 
-
-    # train child_network and check validation accuracy each epoch
-    for _epoch in range(max_epochs):
-
-        # train child_network
-        child_network.train()
-        for idx, (train_x, train_label) in enumerate(train_loader):
-            # onto device
-            train_x = train_x.to(device=device, dtype=train_x.dtype)
-            train_label = train_label.to(device=device, dtype=train_label.dtype)
-
-            # label_np = np.zeros((train_label.shape[0], 10))
-
-            sgd.zero_grad()
-            predict_y = child_network(train_x.float())
-            loss = cost(predict_y, train_label.long())
-            loss.backward()
-            sgd.step()
-
-        # check validation accuracy on validation set
-        correct = 0
-        _sum = 0
-        child_network.eval()
-        with torch.no_grad():
-            for idx, (test_x, test_label) in enumerate(test_loader):
-                # onto device
-                test_x = test_x.to(device=device, dtype=test_x.dtype)
-                test_label = test_label.to(device=device, dtype=test_label.dtype)
-
-                predict_y = child_network(test_x.float()).detach()
-                predict_ys = torch.argmax(predict_y, axis=-1)
-    
-                _ = predict_ys == test_label
-                correct += torch.sum(_, axis=-1)
-
-                _sum += _.shape[0]
-        
-        # update best validation accuracy if it was higher, otherwise increase early stop count
-        acc = correct / _sum
-
-        if acc > best_acc :
-            best_acc = acc
-            early_stop_cnt = 0
-        else:
-            early_stop_cnt += 1
-
-        # exit if validation gets worse over 10 runs
-        if early_stop_cnt >= early_stop_num:
-            print('main.train_child_network best accuracy: ', best_acc)
-            break
-        
-        # if print_every_epoch:
-            # print('main.train_child_network best accuracy: ', best_acc)
-        acc_log.append(acc)
-
-    if logging:
-        return best_acc.item(), acc_log
-    return best_acc.item()
-
-def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
-
-    aa_transform = AutoAugment()
-    aa_transform.subpolicies = subpolicies
-
-    train_transform = transforms.Compose([
-                                            aa_transform,
-                                            transforms.ToTensor()
-                                        ])
-
-    train_dataset.transform = train_transform
-
-    # create toy dataset from above uploaded data
-    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
-
-    child_network = LeNet()
-    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
-    cost = nn.CrossEntropyLoss()
-
-    best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
-                                                sgd, cost, max_epochs=100, logging=True)
-
-    return best_acc, acc_log
-
-
-
-__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
-
-
-def _apply_op(img: Tensor, op_name: str, magnitude: float,
-                interpolation: InterpolationMode, fill: Optional[List[float]]):
-    if op_name == "ShearX":
-        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
-                        interpolation=interpolation, fill=fill)
-    elif op_name == "ShearY":
-        img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
-                        interpolation=interpolation, fill=fill)
-    elif op_name == "TranslateX":
-        img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
-                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
-    elif op_name == "TranslateY":
-        img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
-                        interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
-    elif op_name == "Rotate":
-        img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
-    elif op_name == "Brightness":
-        img = F.adjust_brightness(img, 1.0 + magnitude)
-    elif op_name == "Color":
-        img = F.adjust_saturation(img, 1.0 + magnitude)
-    elif op_name == "Contrast":
-        img = F.adjust_contrast(img, 1.0 + magnitude)
-    elif op_name == "Sharpness":
-        img = F.adjust_sharpness(img, 1.0 + magnitude)
-    elif op_name == "Posterize":
-        img = F.posterize(img, int(magnitude))
-    elif op_name == "Solarize":
-        img = F.solarize(img, magnitude)
-    elif op_name == "AutoContrast":
-        img = F.autocontrast(img)
-    elif op_name == "Equalize":
-        img = F.equalize(img)
-    elif op_name == "Invert":
-        img = F.invert(img)
-    elif op_name == "Identity":
-        pass
-    else:
-        raise ValueError("The provided operator {} is not recognized.".format(op_name))
-    return img
-
-
-class AutoAugmentPolicy(Enum):
-    """AutoAugment policies learned on different datasets.
-    Available policies are IMAGENET, CIFAR10 and SVHN.
-    """
-    IMAGENET = "imagenet"
-    CIFAR10 = "cifar10"
-    SVHN = "svhn"
-
-
-# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
-class AutoAugment(torch.nn.Module):
-    r"""AutoAugment data augmentation method based on
-    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
-    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
-    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
-    If img is PIL Image, it is expected to be in mode "L" or "RGB".
-
-    Args:
-        policy (AutoAugmentPolicy): Desired policy enum defined by
-            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
-        interpolation (InterpolationMode): Desired interpolation enum defined by
-            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
-            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
-        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
-            image. If given a number, the value is used for all bands respectively.
-    """
-
-    def __init__(
-        self,
-        policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
-        interpolation: InterpolationMode = InterpolationMode.NEAREST,
-        fill: Optional[List[float]] = None
-    ) -> None:
-        super().__init__()
-        self.policy = policy
-        self.interpolation = interpolation
-        self.fill = fill
-        self.subpolicies = self._get_subpolicies(policy)
-
-    def _get_subpolicies(
-        self,
-        policy: AutoAugmentPolicy
-    ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
-        if policy == AutoAugmentPolicy.IMAGENET:
-            return [
-                (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
-                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
-                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
-                (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
-                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
-                (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
-                (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
-                (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
-                (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
-                (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
-                (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
-                (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
-                (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
-                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
-                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
-                (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
-                (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
-                (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
-                (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
-                (("Color", 0.4, 0), ("Equalize", 0.6, None)),
-                (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
-                (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
-                (("Invert", 0.6, None), ("Equalize", 1.0, None)),
-                (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
-                (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
-            ]
-        elif policy == AutoAugmentPolicy.CIFAR10:
-            return [
-                (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
-                (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
-                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
-                (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
-                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
-                (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
-                (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
-                (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
-                (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
-                (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
-                (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
-                (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
-                (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
-                (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
-                (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
-                (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
-                (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
-                (("Color", 0.9, 9), ("Equalize", 0.6, None)),
-                (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
-                (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
-                (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
-                (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
-                (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
-                (("Equalize", 0.8, None), ("Invert", 0.1, None)),
-                (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
-            ]
-        elif policy == AutoAugmentPolicy.SVHN:
-            return [
-                (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
-                (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
-                (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
-                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
-                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
-                (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
-                (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
-                (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
-                (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
-                (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
-                (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
-                (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
-                (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
-                (("Invert", 0.9, None), ("Equalize", 0.6, None)),
-                (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
-                (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
-                (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
-                (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
-                (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
-                (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
-                (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
-                (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
-                (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
-                (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
-                (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
-            ]
-        else:
-            raise ValueError("The provided policy {} is not recognized.".format(policy))
-
-    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
-        return {
-            # op_name: (magnitudes, signed)
-            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
-            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
-            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
-            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
-            "AutoContrast": (torch.tensor(0.0), False),
-            "Equalize": (torch.tensor(0.0), False),
-            "Invert": (torch.tensor(0.0), False),
-        }
-
-    @staticmethod
-    def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
-        """Get parameters for autoaugment transformation
-
-        Returns:
-            params required by the autoaugment transformation
-        """
-        policy_id = int(torch.randint(transform_num, (1,)).item())
-        probs = torch.rand((2,))
-        signs = torch.randint(2, (2,))
-
-        return policy_id, probs, signs
-
-    def forward(self, img: Tensor, dis_mag = True) -> Tensor:
-        """
-            img (PIL Image or Tensor): Image to be transformed.
-
-        Returns:
-            PIL Image or Tensor: AutoAugmented image.
-        """
-        fill = self.fill
-        if isinstance(img, Tensor):
-            if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
-            elif fill is not None:
-                fill = [float(f) for f in fill]
-
-        transform_id, probs, signs = self.get_params(len(self.subpolicies))
-
-        for i, (op_name, p, magnitude) in enumerate(self.subpolicies):
-            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
-
-
-        return img
-
-    def __repr__(self) -> str:
-        return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
-
-
-class RandAugment(torch.nn.Module):
-    r"""RandAugment data augmentation method based on
-    `"RandAugment: Practical automated data augmentation with a reduced search space"
-    <https://arxiv.org/abs/1909.13719>`_.
-    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
-    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
-    If img is PIL Image, it is expected to be in mode "L" or "RGB".
-
-    Args:
-        num_ops (int): Number of augmentation transformations to apply sequentially.
-        magnitude (int): Magnitude for all the transformations.
-        num_magnitude_bins (int): The number of different magnitude values.
-        interpolation (InterpolationMode): Desired interpolation enum defined by
-            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
-            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
-        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
-            image. If given a number, the value is used for all bands respectively.
-        """
-
-    def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
-                    interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                    fill: Optional[List[float]] = None) -> None:
-        super().__init__()
-        self.num_ops = num_ops
-        self.magnitude = magnitude
-        self.num_magnitude_bins = num_magnitude_bins
-        self.interpolation = interpolation
-        self.fill = fill
-
-    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
-        return {
-            # op_name: (magnitudes, signed)
-            "Identity": (torch.tensor(0.0), False),
-            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
-            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-            "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
-            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
-            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Color": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
-            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
-            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
-            "AutoContrast": (torch.tensor(0.0), False),
-            "Equalize": (torch.tensor(0.0), False),
-        }
-
-    def forward(self, img: Tensor) -> Tensor:
-        """
-            img (PIL Image or Tensor): Image to be transformed.
-
-        Returns:
-            PIL Image or Tensor: Transformed image.
-        """
-        fill = self.fill
-        if isinstance(img, Tensor):
-            if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
-            elif fill is not None:
-                fill = [float(f) for f in fill]
-
-        for _ in range(self.num_ops):
-            op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
-            op_index = int(torch.randint(len(op_meta), (1,)).item())
-            op_name = list(op_meta.keys())[op_index]
-            magnitudes, signed = op_meta[op_name]
-            magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
-            if signed and torch.randint(2, (1,)):
-                magnitude *= -1.0
-            img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
-
-        return img
-
-    def __repr__(self) -> str:
-        s = self.__class__.__name__ + '('
-        s += 'num_ops={num_ops}'
-        s += ', magnitude={magnitude}'
-        s += ', num_magnitude_bins={num_magnitude_bins}'
-        s += ', interpolation={interpolation}'
-        s += ', fill={fill}'
-        s += ')'
-        return s.format(**self.__dict__)
-
-
-class TrivialAugmentWide(torch.nn.Module):
-    r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
-    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
-    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
-    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
-    If img is PIL Image, it is expected to be in mode "L" or "RGB".
-
-    Args:
-        num_magnitude_bins (int): The number of different magnitude values.
-        interpolation (InterpolationMode): Desired interpolation enum defined by
-            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
-            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
-        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
-            image. If given a number, the value is used for all bands respectively.
-        """
-
-    def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
-                    fill: Optional[List[float]] = None) -> None:
-        super().__init__()
-        self.num_magnitude_bins = num_magnitude_bins
-        self.interpolation = interpolation
-        self.fill = fill
-
-    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
-        return {
-            # op_name: (magnitudes, signed)
-            "Identity": (torch.tensor(0.0), False),
-            "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
-            "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
-            "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
-            "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
-            "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
-            "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
-            "Color": (torch.linspace(0.0, 0.99, num_bins), True),
-            "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
-            "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
-            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
-            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
-            "AutoContrast": (torch.tensor(0.0), False),
-            "Equalize": (torch.tensor(0.0), False),
-        }
-
-    def forward(self, img: Tensor) -> Tensor:
-        """
-            img (PIL Image or Tensor): Image to be transformed.
-
-        Returns:
-            PIL Image or Tensor: Transformed image.
-        """
-        fill = self.fill
-        if isinstance(img, Tensor):
-            if isinstance(fill, (int, float)):
-                fill = [float(fill)] * F.get_image_num_channels(img)
-            elif fill is not None:
-                fill = [float(f) for f in fill]
-
-        op_meta = self._augmentation_space(self.num_magnitude_bins)
-        op_index = int(torch.randint(len(op_meta), (1,)).item())
-        op_name = list(op_meta.keys())[op_index]
-        magnitudes, signed = op_meta[op_name]
-        magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
-            if magnitudes.ndim > 0 else 0.0
-        if signed and torch.randint(2, (1,)):
-            magnitude *= -1.0
-
-        return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
-
-    def __repr__(self) -> str:
-        s = self.__class__.__name__ + '('
-        s += 'num_magnitude_bins={num_magnitude_bins}'
-        s += ', interpolation={interpolation}'
-        s += ', fill={fill}'
-        s += ')'
-        return s.format(**self.__dict__)
-
-# HEREHEREHEREHERE1
-
-
-
-
-
-
-
-
-train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
-                            transform=None)
-test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
-                            transform=torchvision.transforms.ToTensor())
-
-
-auto_aug_agent = Learner()
-ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1, train_dataset=train_dataset, test_dataset=test_dataset)
-ev_learner.run_instance()
-
-
-solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution()
-
-print(f"Best solution : {solution}")
-print(f"Fitness value of the best solution = {solution_fitness}")
-print(f"Index of the best solution : {solution_idx}")
-# Fetch the parameters of the best solution.
-best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent,
-                                                        weights_vector=solution)
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
new file mode 100644
index 00000000..9c249c1b
--- /dev/null
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -0,0 +1,283 @@
+import torch
+torch.manual_seed(0)
+import torch.nn as nn
+import pygad
+import pygad.torchga as torchga
+import copy
+from MetaAugment.controller_networks.evo_controller import Evo_learner
+
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
+import MetaAugment.child_networks as cn
+
+
+class Evolutionary_learner(aa_learner):
+
+    def __init__(self, 
+                sp_num=1,
+                num_solutions = 10, 
+                num_generations = 5, 
+                num_parents_mating = 5,
+                learning_rate = 1e-1, 
+                max_epochs=float('inf'),
+                early_stop_num=20,
+                train_loader = None, 
+                child_network = None, 
+                p_bins = 1, 
+                m_bins = 1, 
+                discrete_p_m=False,
+                batch_size=8,
+                toy_flag=False,
+                toy_size=0.1,
+                sub_num_pol=5, 
+                fun_num = 14,
+                exclude_method=[],
+                ):
+
+        super().__init__(sp_num, 
+            fun_num, 
+            p_bins, 
+            m_bins, 
+            discrete_p_m=discrete_p_m, 
+            batch_size=batch_size, 
+            toy_flag=toy_flag, 
+            toy_size=toy_size, 
+            learning_rate=learning_rate,
+            max_epochs=max_epochs,
+            early_stop_num=early_stop_num,)
+
+
+        self.auto_aug_agent = Evo_learner(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sub_num_pol)
+        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.num_generations = num_generations
+        self.num_parents_mating = num_parents_mating
+        self.initial_population = self.torch_ga.population_weights
+        self.train_loader = train_loader
+        self.child_network = child_network
+        self.p_bins = p_bins 
+        self.sub_num_pol = sub_num_pol
+        self.m_bins = m_bins
+        self.fun_num = fun_num
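+        # keep only the transformations that were not excluded by name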
+        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+
+
+
+        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
+
+        # set_up_instance(train_dataset, test_dataset) must be called before
+        # run_instance(), since the GA fitness function needs the datasets.
+
+
+    def get_full_policy(self, x):
+        """
+        Generates the full policy (self.sub_num_pol subpolicies). The network architecture requires
+        an output of size self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins)
+
+        Parameters 
+        -----------
+        x -> PyTorch tensor
+            Input data for network 
+
+        Returns
+        ----------
+        full_policy -> [((String, float, float), (String, float, float)), ...]
+            Full policy consisting of tuples of subpolicies. Each subpolicy consists of
+            two transformations, each with a probability and magnitude float
+        """
+        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        y = self.auto_aug_agent.forward(x)
+        full_policy = []
+        for pol in range(self.sub_num_pol):
+            int_pol = []
+            for op in range(2):
+                # each subpolicy owns two consecutive 'section'-sized blocks of the
+                # output, one block per transformation
+                offset = ((2 * pol) + op) * section
+                idx_ret = torch.argmax(y[:, offset:offset + self.fun_num].mean(dim = 0))
+
+                trans, need_mag = self.augmentation_space[idx_ret]
+
+                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, offset+self.fun_num:offset+self.fun_num+self.p_bins].mean(dim = 0))
+                mag = torch.argmax(y[:, offset+self.fun_num+self.p_bins:offset+section].mean(dim = 0)) if need_mag else None
+                int_pol.append((trans, p_ret, mag))
+
+            full_policy.append(tuple(int_pol))
+
+        return full_policy
+
+    
+    def get_single_policy_cov(self, x, alpha = 0.5):
+        """
+        Selects policy using population and covariance matrices. For this method 
+        we require p_bins = 1, sub_num_pol = 1, m_bins = 1. 
+
+        Parameters
+        ------------
+        x -> PyTorch Tensor
+            Input data for the AutoAugment network 
+
+        alpha -> Float
+            Proportion for covariance and population matrices 
+
+        Returns
+        -----------
+        Subpolicy -> [(String, float, float), (String, float, float)]
+            Subpolicy consisting of two tuples, each with a string naming
+            a transformation, a float for its probability, and a float for its magnitude
+        """
+        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+
+        y = self.auto_aug_agent.forward(x) # 1000 x 32
+
+        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
+        y[:,:self.auto_aug_agent.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
+        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        concat = torch.cat((y_1, y_2), dim = 1)
+
+        cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        shape_store = cov_mat.shape
+
+        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
+
+
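+        # empirical frequency of each (first op, second op) argmax pair across the batch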
+        prob_mat = torch.zeros(shape_store)
+        for idx in range(y.shape[0]):
+            prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1
+        prob_mat = prob_mat / torch.sum(prob_mat)
+
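+        # blend the covariance structure with the empirical frequencies, weighted by alpha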
+        cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat)
+
+        cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()
+        max_idx = torch.argmax(cov_mat)
+        val = (max_idx//shape_store[0])
+        max_idx = (val, max_idx - (val * shape_store[0]))
+
+
+        if not self.augmentation_space[max_idx[0]][1]:
+            mag1 = None
+        if not self.augmentation_space[max_idx[1]][1]:
+            mag2 = None
+    
+        for idx in range(y.shape[0]):
+            if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
+                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                if mag1 is not None:
+                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                if mag2 is not None:
+                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                counter += 1
+
+        prob1 = prob1/counter if counter != 0 else 0
+        prob2 = prob2/counter if counter != 0 else 0
+        if mag1 is not None:
+            mag1 = mag1/counter if counter != 0 else 0
+        if mag2 is not None:
+            mag2 = mag2/counter if counter != 0 else 0
+
+        
+        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+
+
+    def run_instance(self, return_weights = False):
+        """
+        Runs the GA instance and returns the model weights as a dictionary
+
+        Parameters
+        ------------
+        return_weights -> Bool
+            Determines if the weight of the GA network should be returned 
+        
+        Returns
+        ------------
+        If return_weights:
+            Network weights -> Dictionary
+        
+        Else:
+            Solution -> Best GA instance solution
+
+            Solution fitness -> Float
+
+            Solution_idx -> Int
+        """
+        self.ga_instance.run()
+        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
+        if return_weights:
+            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+        else:
+            return solution, solution_fitness, solution_idx
+
+
+    def new_model(self):
+        """
+        Simple function to create a copy of the secondary model (used for classification)
+        """
+        copy_model = copy.deepcopy(self.child_network)
+        return copy_model
+
+
+    def set_up_instance(self, train_dataset, test_dataset):
+        """
+        Initialises GA instance, as well as fitness and on_generation functions
+        
+        """
+
+        def fitness_func(solution, sol_idx):
+            """
+            Defines the fitness function for the parent selection
+
+            Parameters
+            --------------
+            solution -> GA solution instance (passed automatically)
+
+            sol_idx -> GA solution index (passed automatically)
+
+            Returns 
+            --------------
+            fit_val -> float            
+            """
+
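+            # load this GA candidate's weights into the controller network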
+            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+                                                            weights_vector=solution)
+
+            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
+
+            for idx, (test_x, label_x) in enumerate(self.train_loader):
+                if self.sp_num == 1:
+                    full_policy = self.get_single_policy_cov(test_x)
+                else:                    
+                    full_policy = self.get_full_policy(test_x)
+
+
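+            # the policy is evaluated twice and the two accuracies are averaged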
+            fit_val = (self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]
+                        + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+
+            return fit_val
+
+        def on_generation(ga_instance):
+            """
+            Prints information of generational fitness
+
+            Parameters 
+            -------------
+            ga_instance -> GA instance
+
+            Returns
+            -------------
+            None
+            """
+            print("Generation = {generation}".format(generation=ga_instance.generations_completed))
+            print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
+            return
+
+
+        self.ga_instance = pygad.GA(num_generations=self.num_generations, 
+            num_parents_mating=self.num_parents_mating, 
+            initial_population=self.initial_population,
+            mutation_percent_genes = 0.1,
+            fitness_func=fitness_func,
+            on_generation = on_generation)
+
+
+
+
diff --git a/MetaAugment/controller_networks/evo_controller.py b/MetaAugment/controller_networks/evo_controller.py
new file mode 100644
index 00000000..aa7a1da2
--- /dev/null
+++ b/MetaAugment/controller_networks/evo_controller.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import math
+
+class Evo_learner(nn.Module):
+    def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
+        self.fun_num = fun_num
+        self.p_bins = p_bins 
+        self.m_bins = m_bins 
+        self.sub_num_pol = sub_num_pol
+
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
+        self.fc1 = nn.Linear(256, 120)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins))
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.relu1(y)
+        y = self.pool1(y)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.pool2(y)
+        y = y.view(y.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu3(y)
+        y = self.fc2(y)
+        y = self.relu4(y)
+        y = self.fc3(y)
+
+        return y
\ No newline at end of file
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/rnn_controller.py
index 12680eae..e01c9afd 100644
--- a/MetaAugment/controller_networks/rnn_controller.py
+++ b/MetaAugment/controller_networks/rnn_controller.py
@@ -135,13 +135,6 @@ class RNNModel(nn.Module):
                     X[i+1] = hx if self.mode == 'GRU' else hx[0]
                 
         outs = X
-    
-
-        # out = outs[-1].squeeze()
-
-        # out = self.fc(out)
-        
-        # return out
 
         return outs
     
diff --git a/backend/.DS_Store b/backend/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..570c5e5d328e0ea9d571920546e2eeb302b4777e
GIT binary patch
literal 6148
zcmeHK%}T>S5Z<-XrW7Fu6^{#E3&zqO#7l_v1&ruHr6we3FlI}Wwue&4QD4YM@p+ut
z-H64UMeGdhe)GGV{UH0p7~}pdK4HvZjM>l-IVuf;?#9rTNk-&2Mzkm<Q6?j>-%aeV
z1Ae>3<}6`3n}7fQXp*Hx>AD|$t7UEPwC%RtweN!`xd_UzoE3gJy+!Lv$}Fn&FuG2Y
z#n{<9mw6ecc`{W6X`DdH-A$Uua^cHa9v7-M&;h$+cgD`aa_RQOU@%<u#PYNc`q1sK
zRvml)==kh%{G7ey@>SExfo&zb25Wc&<!d$bV47z#e*|BhU*{1L1H=F^Kn!d$1Ljn)
zyPHe`t(_Pk1{xW_{XsxO^bA%S)z$$WUY{{;A)<heZwW+U&@)(Rga`=NrGUDWn<oa>
z<=_`4&ofwQ)a8t;nPD6=bNzVXYIg7omCm@Qk$Pf)7+7bZrA-IV{|oqK8Xx)VC1eo;
z#K1pefLo(r<in!O+4^I7c-9JN571CBuR;X`^tDR>7`Trds-TVw)FIC^SZTyj(67n?
P=^~&Ap^g~%1qMC=%|%K9

literal 0
HcmV?d00001

-- 
GitLab