diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index 1986368aff7f5d42e966f61e0cf17424d0f2fb7e..66bacb32e8896f486422bd4661c4c1f5699589ce 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -13,6 +13,7 @@ import torch.optim as optim
 import torch.utils.data as data_utils
 import torchvision
 import torchvision.datasets as datasets
+import pickle
 
 from matplotlib import pyplot as plt
 from numpy import save, load
@@ -198,7 +199,8 @@ def sample_sub_policy(policies, policy, num_sub_policies):
 
 
 """Sample policy, open and apply above transformations"""
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):
+def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name=None):
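+    # ds_name: subfolder of ./MetaAugment/datasets/ to load when ds == 'Other'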
 
     # get number of policies and sub-policies
     num_policies = len(policies)
@@ -226,6 +228,8 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         # create transformations using above info
         transform = torchvision.transforms.Compose(
             [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
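+            # center-crop every image to 28x28 so all datasets feed the networks the same input size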
+            torchvision.transforms.CenterCrop(28),
             torchvision.transforms.ToTensor()])
 
         # open data and apply these transformations
@@ -244,14 +248,24 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform)
+        elif ds == 'Other':
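+            # user-supplied dataset: read it with ImageFolder, then hold out 20% of the images as a test set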
+            dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform)
+            len_train = int(0.8*len(dataset))
+            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+
+        print('train_dataset', len(train_dataset), 'test_dataset', len(test_dataset))
 
         # check sizes of images
         img_height = len(train_dataset[0][0][0])
         img_width = len(train_dataset[0][0][0][0])
         img_channels = len(train_dataset[0][0])
 
         # check output labels
-        if ds == "CIFAR10" or ds == "CIFAR100":
+        if ds == 'Other':
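+            # ImageFolder assigns one index per class directory, so class_to_idx has one entry per label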
+            num_labels = len(dataset.class_to_idx)
+        elif ds == "CIFAR10" or ds == "CIFAR100":
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
         else:
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
@@ -264,8 +278,13 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
             model = LeNet(img_height, img_width, num_labels, img_channels)
         elif IsLeNet == "EasyNet":
             model = EasyNet(img_height, img_width, num_labels, img_channels)
-        else:
+        elif IsLeNet == 'SimpleNet':
             model = SimpleNet(img_height, img_width, num_labels, img_channels)
+        else:
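+            # fall back to a previously pickled child network (assumed to live at this relative path)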
+            with open('datasets/childnetwork', 'rb') as f:
+                model = pickle.load(f)
+
-        sgd = optim.SGD(model.parameters(), lr=1e-1)
+        sgd = optim.SGD(model.parameters(), lr=learning_rate)
         cost = nn.CrossEntropyLoss()
 
@@ -319,7 +338,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         best_q_value = max(q_values)
         best_q_values.append(best_q_value)
 
-        if (policy+1) % 10 == 0:
-            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
+        # report progress after every policy evaluation
+        print("Iteration: {},\tQ-Values: {}, Best Q-Value: {}".format(policy+1, list(np.around(np.array(q_values),2)), np.around(best_q_value, 2)))
 
         # update counts