diff --git a/backend/MetaAugment/Evo_learner.py b/backend/MetaAugment/Evo_learner.py
index 2f97fbd0f0d4a6c70896bbbce9383e11b2cd5ad0..71170f437b7674a15e8941d0a15c6990595ab55d 100644
--- a/backend/MetaAugment/Evo_learner.py
+++ b/backend/MetaAugment/Evo_learner.py
@@ -25,6 +25,7 @@ from torch import Tensor
 
 
 
+
 class Learner(nn.Module):
     def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
         self.fun_num = fun_num
@@ -46,6 +47,7 @@ class Learner(nn.Module):
         self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins))
         
     def forward(self, x):
+        x = x[:, 0:1, :, :]
         y = self.conv1(x)
         y = self.relu1(y)
         y = self.pool1(y)
@@ -61,12 +63,29 @@ class Learner(nn.Module):
 
         return y
 
+class LeNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(784, 2048)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(2048, 10)
+        self.relu2 = nn.ReLU()
+
+    def forward(self, x):
+        x = x.reshape((-1, 784))
+        y = self.fc1(x)
+        y = self.relu1(y)
+        y = self.fc2(y)
+        y = self.relu2(y)
+        return y
+
 
 class Evolutionary_learner():
 
     def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, batch_size=32, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, exclude_method=[], augmentation_space = None, ds=None, ds_name=None):
-        self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol)
-        self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
+        self.auto_aug_agent = network
+        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+
         self.num_generations = num_generations
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
@@ -295,8 +314,8 @@ class Evolutionary_learner():
             for idx, (test_x, label_x) in enumerate(self.train_loader):
                 full_policy = self.get_policy_cov(test_x)
 
-            fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0])/
-                        + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2
+            fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0])/
+                        + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0]) / 2
 
             return fit_val
 
@@ -382,8 +401,12 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
         correct = 0
         _sum = 0
         child_network.eval()
+        print("here0")
         with torch.no_grad():
+            print("here1")
+            print("len test_loader: ", len(test_loader))
             for idx, (test_x, test_label) in enumerate(test_loader):
+                print("here2")
                 test_x = test_x.to(device=device, dtype=test_x.dtype)
                 test_label = test_label.to(device=device, dtype=test_label.dtype)
 
@@ -394,6 +417,7 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
                 correct += torch.sum(_, axis=-1)
 
                 _sum += _.shape[0]
+                print("SUM: ", _sum)
         
         acc = correct / _sum
 
@@ -836,3 +860,4 @@ class TrivialAugmentWide(torch.nn.Module):
 
 
 
+
diff --git a/backend/auto_augmentation/progress.py b/backend/auto_augmentation/progress.py
index c33d45ca275dc305bbeca58624819d92401608d3..a83b2bb77020f71226d363e8a4a3f3176a5c660b 100644
--- a/backend/auto_augmentation/progress.py
+++ b/backend/auto_augmentation/progress.py
@@ -81,7 +81,9 @@ def response():
             policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
             q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
         elif auto_aug_leanrer == 'Evolutionary Learner':
-            learner = Evo.Evolutionary_learner(network=Evo.Learner(), fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method)
+            network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
+            child_network = Evo.LeNet()
+            learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
             learner.run_instance()
         elif auto_aug_leanrer == 'Random Searcher':
             pass 
@@ -131,4 +133,4 @@ def response():
 #             return redirect(url_for('uploaded_file', filename=filename))
 #     return '''
     
-#     '''
\ No newline at end of file
+#     '''