diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index 005676f234792310fa0a11a6a170cd1822fb6e2a..e928b7de2bd152aaa05b56c8fbb8321a8681edaa 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -12,9 +12,10 @@ import random
 import pygad
 import pygad.torchga as torchga
 import random
+import copy
 
-import MetaAugment.child_networks as child_networks
-from MetaAugment.main import *
+# import MetaAugment.child_networks as child_networks
+# from MetaAugment.main import *
 
 
 np.random.seed(0)
@@ -22,7 +23,7 @@ random.seed(0)
 
 
 class Learner(nn.Module):
-    def __init__(self):
+    def __init__(self, num_transforms = 3):
         super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.relu1 = nn.ReLU()
@@ -51,16 +52,51 @@ class Learner(nn.Module):
         y = self.relu4(y)
         y = self.fc3(y)
 
-        # y = self.sig(y)
-        # print("y[3:, :] shape: ", y[:, 3:].shape)
+        return y
+
+    def get_idx(self, x):
+        y = self.forward(x)
         idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0))
         p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0))
         return (idx_ret, p_ret)
+
         # return (torch.argmax(y[0:3]), y[torch.argmax(y[3:])])
 
+class LeNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(1, 6, 5)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
+        self.fc1 = nn.Linear(256, 120)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, 10)
+        self.relu5 = nn.ReLU()
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.relu1(y)
+        y = self.pool1(y)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.pool2(y)
+        y = y.view(y.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu3(y)
+        y = self.fc2(y)
+        y = self.relu4(y)
+        y = self.fc3(y)
+        return y
+
+
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
-def train_model(transform_idx, p):
+def train_model(transform_idx, p, child_network):
     """
     Takes in the specific transformation index and probability 
     """
@@ -93,16 +129,11 @@ def train_model(transform_idx, p):
     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
-    # create toy dataset from above uploaded data
     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
 
-    # train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
-    # test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
 
-    # print("Size of training dataset:\t", len(reduced_train_dataset))
-    # print("Size of testing dataset:\t", len(reduced_test_dataset), "\n")
+    # child_network = child_networks.lenet()
 
-    child_network = child_networks.lenet()
     sgd = optim.SGD(child_network.parameters(), lr=1e-1)
     cost = nn.CrossEntropyLoss()
     epoch = 20
@@ -114,29 +145,84 @@ def train_model(transform_idx, p):
 
 
 
-def fitness_func(solution, sol_idx):
-    """
-    Defines fitness function (accuracy of the model)
-    """
-    global train_loader, meta_rl_agent
-    model_weights_dict = torchga.model_weights_as_dict(model=meta_rl_agent,
-                                                       weights_vector=solution)
-    # Use the current solution as the model parameters.
-    meta_rl_agent.load_state_dict(model_weights_dict)
-    for idx, (test_x, label_x) in enumerate(train_loader):
-        trans_idx, p = meta_rl_agent(test_x)
 
-    fit_val = train_model(trans_idx, p)
 
-    return fit_val
 
+def train_child_network(child_network, train_loader, test_loader, sgd,
+                         cost, max_epochs=100, early_stop_num = 10, logging=False):
+    best_acc=0
+    early_stop_cnt = 0
+    
+    # logging accuracy for plotting
+    acc_log = [] 
+
+    # train child_network and check validation accuracy each epoch
+    for _epoch in range(max_epochs):
+
+        # train child_network
+        child_network.train()
+        for idx, (train_x, train_label) in enumerate(train_loader):
+            label_np = np.zeros((train_label.shape[0], 10))
+            sgd.zero_grad()
+            predict_y = child_network(train_x.float())
+            loss = cost(predict_y, train_label.long())
+            loss.backward()
+            sgd.step()
+
+        # check validation accuracy on validation set
+        correct = 0
+        _sum = 0
+        child_network.eval()
+        with torch.no_grad():
+            for idx, (test_x, test_label) in enumerate(test_loader):
+                predict_y = child_network(test_x.float()).detach()
+                predict_ys = np.argmax(predict_y, axis=-1)
+                label_np = test_label.numpy()
+                _ = predict_ys == test_label
+                correct += np.sum(_.numpy(), axis=-1)
+                _sum += _.shape[0]
+        
+        # update best validation accuracy if it was higher, otherwise increase early stop count
+        acc = correct / _sum
+
+
+        if acc > best_acc :
+            best_acc = acc
+            early_stop_cnt = 0
+        else:
+            early_stop_cnt += 1
+
+        # exit if validation gets worse over 10 runs
+        if early_stop_cnt >= early_stop_num:
+            break
+        
+        # print('main.train_child_network best accuracy: ', best_acc)
+        acc_log.append(acc)
+
+    if logging:
+        return best_acc, acc_log
+    return best_acc
 
-def callback_generation(ga_instance):
-    """
-    Just prints stuff while running
-    """
-    print("Generation = {generation}".format(generation=ga_instance.generations_completed))
-    print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
+def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
+    # shuffle and take first n_samples %age of training dataset
+    shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset))
+    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
+    
+    indices_train = torch.arange(int(n_samples*len(train_dataset)))
+    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
+    
+    # shuffle and take first n_samples %age of test dataset
+    shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
+    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
+
+    indices_test = torch.arange(int(n_samples*len(test_dataset)))
+    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)
+
+    # push into DataLoader
+    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
+    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
+
+    return train_loader, test_loader
 
 
 # ORGANISING DATA
@@ -158,32 +244,74 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 
 
-# GENERATING THE GA INSTANCE
+class Evolutionary_learner():
 
-meta_rl_agent = Learner()
-torch_ga = torchga.TorchGA(model=meta_rl_agent,
-                           num_solutions=20)
+    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None):
+        self.meta_rl_agent = network
+        self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
+        self.num_generations = num_generations
+        self.num_parents_mating = num_parents_mating
+        self.initial_population = self.torch_ga.population_weights
+        self.train_loader = train_loader
+        self.backup_model = sec_model
 
-# HYPERPARAMETER FOR THE GA 
+        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
-num_generations = 100 # Number of generations.
-num_parents_mating = 20 # Number of solutions to be selected as parents in the mating pool.
-initial_population = torch_ga.population_weights
+        self.set_up_instance()
+    
 
-ga_instance = pygad.GA(num_generations=num_generations, 
-                       num_parents_mating=num_parents_mating, 
-                       initial_population=initial_population,
-                       fitness_func=fitness_func,
-                       on_generation=callback_generation)
-ga_instance.run()
+    def run_instance(self, return_weights = False):
+        self.ga_instance.run()
+        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
+        if return_weights:
+            return torchga.model_weights_as_dict(model=self.meta_rl_agent, weights_vector=solution)
+        else:
+            return solution, solution_fitness, solution_idx
 
-solution, solution_fitness, solution_idx = ga_instance.best_solution()
-print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
-print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))
-# Fetch the parameters of the best solution.
-best_solution_weights = torchga.model_weights_as_dict(model=meta_rl_agent,
-                                                      weights_vector=solution)
+    def new_model(self):
+        copy_model = copy.deepcopy(self.backup_model)
+        return copy_model
+
+
+    def set_up_instance(self):
+        def fitness_func(solution, sol_idx):
+            """
+            Defines fitness function (accuracy of the model)
+            """
+            model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent,
+                                                            weights_vector=solution)
+            self.meta_rl_agent.load_state_dict(model_weights_dict)
+            for idx, (test_x, label_x) in enumerate(train_loader):
+                trans_idx, p = self.meta_rl_agent.get_idx(test_x)
+            cop_mod = self.new_model()
+            fit_val = train_model(trans_idx, p, cop_mod)
+            cop_mod = 0
+            return fit_val
+
+        def on_generation(ga_instance):
+            """
+            Just prints stuff while running
+            """
+            print("Generation = {generation}".format(generation=self.ga_instance.generations_completed))
+            print("Fitness    = {fitness}".format(fitness=self.ga_instance.best_solution()[1]))
+            return
 
 
+        self.ga_instance = pygad.GA(num_generations=self.num_generations, 
+            num_parents_mating=self.num_parents_mating, 
+            initial_population=self.initial_population,
+            fitness_func=fitness_func,
+            on_generation = on_generation)
 
 
+meta_rl_agent = Learner()
+ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet())
+ev_learner.run_instance()
+
+
+solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution()
+print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
+print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))
+# Fetch the parameters of the best solution.
+best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.meta_rl_agent,
+                                                      weights_vector=solution)
\ No newline at end of file
diff --git a/MetaAugment/METALEANER.py b/MetaAugment/METALEANER.py
new file mode 100644
index 0000000000000000000000000000000000000000..c94246d6898ccf2d316c1dae7644513bf113149e
--- /dev/null
+++ b/MetaAugment/METALEANER.py
@@ -0,0 +1,7 @@
+
+
+
+# Neural network 
+# Input the dataset (same batch size, have to check if the input sizes are correct i.e. 28x28)
+# Output the hyperparameters --> weights of network, kernel size, number of layers, number of kernels
+# 
\ No newline at end of file