diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index f829ab709f6446742a8a2771332714b9bf578f48..959d55230a8ad0bfa7c4973c8b0857d93c67213f 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -102,7 +102,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
 
 
 """Sample policy, open and apply above transformations"""
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):
+def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
 
     # get number of policies and sub-policies
     num_policies = len(policies)
@@ -130,32 +130,40 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         # create transformations using above info
         transform = torchvision.transforms.Compose(
             [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
+            torchvision.transforms.CenterCrop(28), # TODO(review): temporary for testing only — remove before merging; also crops 32x32 CIFAR inputs down to 28x28
             torchvision.transforms.ToTensor()])
 
         # open data and apply these transformations
         if ds == "MNIST":
-            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=transform)
         elif ds == "KMNIST":
-            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, download=True, transform=transform)
         elif ds == "FashionMNIST":
-            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, download=True, transform=transform)
+        elif ds == "Other":  # user-uploaded dataset laid out for torchvision ImageFolder
+            dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=transform)  # NOTE(review): ds_name defaults to None — this would TypeError; require it when ds == "Other"
+            len_train = int(0.8*len(dataset))
+            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])  # NOTE(review): unseeded 80/20 split — nondeterministic across runs; consider generator=torch.Generator().manual_seed(...)
 
         # check sizes of images
         img_height = len(train_dataset[0][0][0])
         img_width = len(train_dataset[0][0][0][0])
         img_channels = len(train_dataset[0][0])
 
+
         # check output labels
-        if ds == "CIFAR10" or ds == "CIFAR100":
+        if ds == 'Other':
+            num_labels = len(dataset.class_to_idx)
+        elif ds == "CIFAR10" or ds == "CIFAR100":
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
         else:
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
@@ -164,70 +172,26 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
 
         # create model
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if torch.cuda.is_available():
+            device='cuda'
+        else:
+            device='cpu'
+        
         if IsLeNet == "LeNet":
             model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         elif IsLeNet == "EasyNet":
             model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        else:
+        elif IsLeNet == "SimpleNet":
             model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        sgd = optim.SGD(model.parameters(), lr=1e-1)
+        else:  # NOTE(review): unknown IsLeNet values previously fell through to SimpleNet; now they load a pickled child network — confirm intended
+            model = pickle.load(open("datasets/childnetwork", "rb"))  # NOTE(review): stray f-prefix removed; file handle is never closed (use "with"); pickle.load is unsafe on untrusted files
+
+        sgd = optim.SGD(model.parameters(), lr=learning_rate)
         cost = nn.CrossEntropyLoss()
 
-        # set variables for best validation accuracy and early stop count
-        best_acc = 0
-        early_stop_cnt = 0
-        total_val = 0
-
-        # train model and check validation accuracy each epoch
-        for _epoch in range(max_epochs):
-
-            # train model
-            model.train()
-            for idx, (train_x, train_label) in enumerate(train_loader):
-                train_x, train_label = train_x.to(device), train_label.to(device) # new code
-                label_np = np.zeros((train_label.shape[0], num_labels))
-                sgd.zero_grad()
-                predict_y = model(train_x.float())
-                loss = cost(predict_y, train_label.long())
-                loss.backward()
-                sgd.step()
-
-            # check validation accuracy on validation set
-            correct = 0
-            _sum = 0
-            model.eval()
-            for idx, (test_x, test_label) in enumerate(test_loader):
-                test_x, test_label = test_x.to(device), test_label.to(device) # new code
-                predict_y = model(test_x.float()).detach()
-                #predict_ys = np.argmax(predict_y, axis=-1)
-                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch
-                #label_np = test_label.numpy()
-                _ = predict_ys == test_label
-                #correct += np.sum(_.numpy(), axis=-1)
-                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()
-                _sum += _.shape[0]
-            
-            acc = correct / _sum
-
-            if average_validation[0] <= _epoch <= average_validation[1]:
-                total_val += acc
-
-            # update best validation accuracy if it was higher, otherwise increase early stop count
-            if acc > best_acc :
-                best_acc = acc
-                early_stop_cnt = 0
-            else:
-                early_stop_cnt += 1
-
-            # exit if validation gets worse over 10 runs and using early stopping
-            if early_stop_cnt >= early_stop_num and early_stop_flag:
-                break
-
-            # exit if using fixed epoch length
-            if _epoch >= average_validation[1] and not early_stop_flag:
-                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
-                break
+        best_acc = train_child_network(model, train_loader, test_loader, sgd,
+                                       cost, max_epochs, early_stop_num, early_stop_flag,
+                                       average_validation, logging=False, print_every_epoch=False)
 
         # update q_values
         if policy < num_policies:
@@ -238,7 +202,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         best_q_value = max(q_values)
         best_q_values.append(best_q_value)
 
-        if (policy+1) % 10 == 0:
+        if (policy+1) % 5 == 0:
             print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
 
         # update counts
@@ -250,6 +214,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
             for i in range(num_policies):
                 q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
 
+        # TODO(review): replace the final return with "yield q_values, best_q_values" only if a streaming/progress API is actually wanted — delete this otherwise
     return q_values, best_q_values