diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index f49889521253e052334500180e8f46b0b48f9d40..f829ab709f6446742a8a2771332714b9bf578f48 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -20,8 +20,8 @@ from matplotlib import pyplot as plt
 from numpy import save, load
 from tqdm import trange
 
-from MetaAugment.child_networks import *
-from MetaAugment.main import create_toy, train_child_network
+from .child_networks import *
+from .main import create_toy, train_child_network
 
 
 # In[6]:
@@ -102,7 +102,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
 
 
 """Sample policy, open and apply above transformations"""
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
+def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):
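+    """Run the UCB1 bandit over the augmentation policies.
+
+    Each iteration samples one sub-policy, trains a fresh child network on a
+    toy subset of the chosen dataset, and uses the resulting validation
+    accuracy as the reward for the UCB1 update. Returns the final Q-values
+    and the best Q-value seen after each iteration.
+    """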
 
     # get number of policies and sub-policies
     num_policies = len(policies)
@@ -130,40 +130,32 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         # create transformations using above info
         transform = torchvision.transforms.Compose(
             [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
-            torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing
             torchvision.transforms.ToTensor()])
 
         # open data and apply these transformations
         if ds == "MNIST":
-            train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "KMNIST":
-            train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "FashionMNIST":
-            train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform)
-        elif ds == 'Other':
-            dataset = datasets.ImageFolder('./MetaAugment/datasets/upload_dataset/'+ ds_name, transform=transform)
-            len_train = int(0.8*len(dataset))
-            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)
 
         # check sizes of images
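+        # train_dataset[0][0] is a channels x height x width tensor, so each len() below reads one dimension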
         img_height = len(train_dataset[0][0][0])
         img_width = len(train_dataset[0][0][0][0])
         img_channels = len(train_dataset[0][0])
 
-
         # check output labels
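+        # the CIFAR datasets store targets as a plain Python list, while the *MNIST datasets use a tensor, hence the .item() in the else branch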
-        if ds == 'Other':
-            num_labels = len(dataset.class_to_idx)
-        elif ds == "CIFAR10" or ds == "CIFAR100":
+        if ds == "CIFAR10" or ds == "CIFAR100":
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
         else:
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
@@ -172,22 +164,70 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
 
         # create model
-	device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if IsLeNet == "LeNet":
             model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         elif IsLeNet == "EasyNet":
             model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == 'SimpleNet':
-            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         else:
-            model = pickle.load(open(f'datasets/childnetwork', "rb"))
-
+            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        sgd = optim.SGD(model.parameters(), lr=1e-1)
+        sgd = optim.SGD(model.parameters(), lr=learning_rate) # use the learning_rate argument instead of a hard-coded value
         cost = nn.CrossEntropyLoss()
 
-        best_acc = train_child_network(model, train_loader, test_loader, sgd,
-                         cost, max_epochs, early_stop_num, early_stop_flag,
-			 average_validation, logging=False, print_every_epoch=False)
+        # set variables for best validation accuracy and early stop count
+        best_acc = 0
+        early_stop_cnt = 0
+        total_val = 0
+
+        # train model and check validation accuracy each epoch
+        for _epoch in range(max_epochs):
+
+            # train model
+            model.train()
+            for train_x, train_label in train_loader:
+                train_x, train_label = train_x.to(device), train_label.to(device) # move the batch to the training device
+                sgd.zero_grad()
+                predict_y = model(train_x.float())
+                loss = cost(predict_y, train_label.long())
+                loss.backward()
+                sgd.step()
+
+            # check validation accuracy on validation set
+            correct = 0
+            _sum = 0
+            model.eval()
+            for test_x, test_label in test_loader:
+                test_x, test_label = test_x.to(device), test_label.to(device) # move the batch to the eval device
+                predict_y = model(test_x.float()).detach()
+                predict_ys = torch.argmax(predict_y, dim=-1)
+                is_correct = predict_ys == test_label
+                correct += is_correct.sum().item() # bring the per-batch count back to the CPU
+                _sum += is_correct.shape[0]
+
+            acc = correct / _sum
+
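+            # accumulate accuracy over the averaging window [average_validation[0], average_validation[1]]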
+            if average_validation[0] <= _epoch <= average_validation[1]:
+                total_val += acc
+
+            # update the best validation accuracy if this epoch improved on it, otherwise increment the early stop count
+            if acc > best_acc:
+                best_acc = acc
+                early_stop_cnt = 0
+            else:
+                early_stop_cnt += 1
+
+            # exit early if validation accuracy has not improved for early_stop_num epochs and early stopping is enabled
+            if early_stop_cnt >= early_stop_num and early_stop_flag:
+                break
+
+            # with early stopping disabled, exit once the averaging window ends and report the mean accuracy over it
+            if _epoch >= average_validation[1] and not early_stop_flag:
+                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
+                break
 
         # update q_values
         if policy < num_policies:
@@ -198,7 +238,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         best_q_value = max(q_values)
         best_q_values.append(best_q_value)
 
-        if (policy+1) % 5 == 0:
+        if (policy+1) % 10 == 0:
             print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
 
         # update counts
@@ -210,7 +250,6 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
             for i in range(num_policies):
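+                # UCB1 score: empirical mean reward plus exploration bonus sqrt(2 ln(total pulls) / pulls of policy i)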
                 q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
 
-        # yield q_values, best_q_values
     return q_values, best_q_values
 
 
diff --git a/MetaAugment/__init__.py b/MetaAugment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391