diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index 9c16382a6e6474179465520516d0dd1a7fdab010..48a8573dc4431ab07ecb318aa945a10e1ef2d38d 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -21,7 +21,7 @@ from numpy import save, load
 from tqdm import trange
 
 from MetaAugment.child_networks import *
-from MetaAugment.main import create_toy
+from MetaAugment.main import create_toy, train_child_network
 
 
 # In[6]:
@@ -184,46 +184,9 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         sgd = optim.SGD(model.parameters(), lr=1e-1)
         cost = nn.CrossEntropyLoss()
 
-        # set variables for best validation accuracy and early stop count
-        best_acc = 0
-        early_stop_cnt = 0
-
-        # train model and check validation accuracy each epoch
-        for _epoch in range(max_epochs):
-
-            # train model
-            model.train()
-            for idx, (train_x, train_label) in enumerate(train_loader):
-                label_np = np.zeros((train_label.shape[0], num_labels))
-                sgd.zero_grad()
-                predict_y = model(train_x.float())
-                loss = cost(predict_y, train_label.long())
-                loss.backward()
-                sgd.step()
-
-            # check validation accuracy on validation set
-            correct = 0
-            _sum = 0
-            model.eval()
-            for idx, (test_x, test_label) in enumerate(test_loader):
-                predict_y = model(test_x.float()).detach()
-                predict_ys = np.argmax(predict_y, axis=-1)
-                label_np = test_label.numpy()
-                _ = predict_ys == test_label
-                correct += np.sum(_.numpy(), axis=-1)
-                _sum += _.shape[0]
-            
-            # update best validation accuracy if it was higher, otherwise increase early stop count
-            acc = correct / _sum
-            if acc > best_acc :
-                best_acc = acc
-                early_stop_cnt = 0
-            else:
-                early_stop_cnt += 1
-
-            # exit if validation gets worse over 10 runs
-            if early_stop_cnt >= early_stop_num:
-                break
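+        # train the child network with the shared helper from MetaAugment.main;
+        # it applies the same early-stopping logic as the inlined loop it replaces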
+        best_acc = train_child_network(model, train_loader, test_loader, sgd,
+                                       cost, max_epochs, early_stop_num,
+                                       logging=False, print_every_epoch=False)
 
         # update q_values
         if policy < num_policies:
@@ -253,25 +216,25 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
 # # In[9]:
 
 
-# batch_size = 32       # size of batch the inner NN is trained with
-# learning_rate = 1e-1  # fix learning rate
-# ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
-# toy_size = 0.02       # total propeortion of training and test set we use
-# max_epochs = 100      # max number of epochs that is run if early stopping is not hit
-# early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
-# num_policies = 5      # fix number of policies
-# num_sub_policies = 5  # fix number of sub-policies in a policy
-# iterations = 100      # total iterations, should be more than the number of policies
-# IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
+batch_size = 32       # size of batch the inner NN is trained with
+learning_rate = 1e-1  # fix learning rate
+ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
+toy_size = 0.02       # total proportion of the training and test sets we use
+max_epochs = 100      # max number of epochs that is run if early stopping is not hit
+early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+num_policies = 5      # fix number of policies
+num_sub_policies = 5  # fix number of sub-policies in a policy
+iterations = 100      # total iterations, should be more than the number of policies
+IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
 
-# # generate random policies at start
-# policies = generate_policies(num_policies, num_sub_policies)
+# generate random policies at start
+policies = generate_policies(num_policies, num_sub_policies)
 
-# q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
+q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
 
-# plt.plot(best_q_values)
+plt.plot(best_q_values)
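+# when run as a plain script (rather than a notebook), display the plot explicitly
+plt.show()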
 
-# best_q_values = np.array(best_q_values)
-# save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
-# #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
+best_q_values = np.array(best_q_values)
+save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
+#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
 
diff --git a/MetaAugment/child_networks/__init__.py b/MetaAugment/child_networks/__init__.py
index 88d93647b235ce75d9af9f7c77441134d3035f1e..767a4e2b646b5d7aab76087e15a308ed4ed36956 100644
--- a/MetaAugment/child_networks/__init__.py
+++ b/MetaAugment/child_networks/__init__.py
@@ -1,2 +1,67 @@
 from .lenet import *
-from .bad_lenet import *
\ No newline at end of file
+from .bad_lenet import *
+
+import torch.nn as nn
+
+class LeNet(nn.Module):
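+    """Internal NN module that trains on the dataset: a LeNet-style CNN."""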
+    def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
+        super().__init__()
+        self.conv1 = nn.Conv2d(img_channels, 6, 5)
+        self.relu1 = nn.ReLU()
+        self.pool1 = nn.MaxPool2d(2)
+        self.conv2 = nn.Conv2d(6, 16, 5)
+        self.relu2 = nn.ReLU()
+        self.pool2 = nn.MaxPool2d(2)
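+        # with the default 28x28 input: conv1 -> 24x24, pool1 -> 12x12,
+        # conv2 -> 8x8, pool2 -> 4x4, so fc1 sees 4*4*16 = 256 features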
+        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
+        self.relu3 = nn.ReLU()
+        self.fc2 = nn.Linear(120, 84)
+        self.relu4 = nn.ReLU()
+        self.fc3 = nn.Linear(84, num_labels)
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.relu1(y)
+        y = self.pool1(y)
+        y = self.conv2(y)
+        y = self.relu2(y)
+        y = self.pool2(y)
+        y = y.view(y.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu3(y)
+        y = self.fc2(y)
+        y = self.relu4(y)
+        y = self.fc3(y)
+        # return raw logits: CrossEntropyLoss applies log-softmax itself,
+        # and a trailing ReLU would clamp all negative logits to zero
+        return y
+
+
+"""Define internal NN module that trains on the dataset"""
+class EasyNet(nn.Module):
+    def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
+        super().__init__()
+        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(2048, num_labels)
+
+    def forward(self, x):
+        y = x.view(x.shape[0], -1)
+        y = self.fc1(y)
+        y = self.relu1(y)
+        y = self.fc2(y)
+        # raw logits for CrossEntropyLoss; no final ReLU
+        return y
+
+
+class SimpleNet(nn.Module):
+    """Internal NN module that trains on the dataset: a single linear layer."""
+    def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
+        super().__init__()
+        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)
+
+    def forward(self, x):
+        y = x.view(x.shape[0], -1)
+        y = self.fc1(y)
+        # raw logits for CrossEntropyLoss; no final ReLU
+        return y
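+
+
+if __name__ == "__main__":
+    # minimal smoke test (illustrative only, not part of the training pipeline):
+    # forward a dummy MNIST-shaped batch through each network and check output shapes
+    import torch
+    for net in (LeNet(), EasyNet(), SimpleNet()):
+        out = net(torch.randn(4, 1, 28, 28))
+        assert out.shape == (4, 10)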
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 9a43c222a794021c002d046adb16cf8687162261..a7842f792680564d90fcb593fe2b4c7c4ca0e0ef 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -106,9 +106,8 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
 
     if logging:
         return best_acc.item(), acc_log
-    
-    print('main.train_child_network best accuracy: ', best_acc)
-    return best_acc.item()
+    else:
+        return best_acc.item()
 
 if __name__=='__main__':
     import MetaAugment.child_networks as cn