diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 1ff576b28a4d39f367550bb2fa15168e0a9b99c8..18ecf751e614585c7db86902eb3cce927dd696f5 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -7,12 +7,11 @@ import pygad.torchga as torchga
 import copy
 import torch
 from MetaAugment.controller_networks.evo_controller import evo_controller
-
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
+from .aa_learner import aa_learner, augmentation_space
 
 
-class evo_learner():
+class evo_learner(aa_learner):
 
     def __init__(self, 
                 sp_num=1,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 41b8977156e9148965b0ffa6c00fe4d0a4a2595d..8862e14bce93a177a8d875cbb68203faaa0be1ff 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -5,220 +5,150 @@
 
 
 import numpy as np
-from sklearn.covariance import log_likelihood
 import torch
-torch.manual_seed(0)
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
-import torch.utils.data as data_utils
 import torchvision
-import torchvision.datasets as datasets
-import pickle
 
-from matplotlib import pyplot as plt
-from numpy import save, load
 from tqdm import trange
 
 from ..child_networks import *
-from ..main import create_toy, train_child_network
+from ..main import train_child_network
+from .randomsearch_learner import randomsearch_learner
+from .aa_learner import augmentation_space
+
+
+class ucb_learner(randomsearch_learner):
+    """
+    Tests randomly sampled policies from the search space specified by the AutoAugment
+    paper. Acts as a baseline for other aa_learner's.
+    """
+    def __init__(self,
+                # parameters that define the search space
+                sp_num=5,
+                fun_num=14,
+                p_bins=11,
+                m_bins=10,
+                discrete_p_m=True,
+                # hyperparameters for when training the child_network
+                batch_size=8,
+                toy_flag=False,
+                toy_size=0.1,
+                learning_rate=1e-1,
+                max_epochs=float('inf'),
+                early_stop_num=30,
+                # ucb_learner specific hyperparameter
+                num_policies=100
+                ):
+        
+        super().__init__(sp_num, 
+                fun_num, 
+                p_bins, 
+                m_bins, 
+                discrete_p_m=discrete_p_m,
+                batch_size=batch_size,
+                toy_flag=toy_flag,
+                toy_size=toy_size,
+                learning_rate=learning_rate,
+                max_epochs=max_epochs,
+                early_stop_num=early_stop_num,)
+        
+        self.num_policies = num_policies
 
+        # When this learner is initialized we generate `num_policies` number
+        # of random policies. 
+        # generate_new_policy is inherited from the randomsearch_learner class
+        self.policies = [self.generate_new_policy() for _ in range(self.num_policies)]
 
-# In[6]:
+        # attributes used in the UCB1 algorithm
+        self.q_values = [0]*self.num_policies
+        self.cnts = [0]*self.num_policies
+        self.q_plus_cnt = [0]*self.num_policies
+        self.total_count = 0
 
+    def learn(self, 
+            train_dataset, 
+            test_dataset, 
+            child_network_architecture, 
+            iterations=15):
 
-"""Randomly generate 10 policies"""
-"""Each policy has 5 sub-policies"""
-"""For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes"""
+        #Initialize vector weights, counts and regret
 
-def generate_policies(num_policies, num_sub_policies):
-    
-    policies = np.zeros([num_policies,num_sub_policies,6])
 
-    # Policies array will be 10x5x6
-    for policy in range(num_policies):
-        for sub_policy in range(num_sub_policies):
-            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)
-            policies[policy, sub_policy, 0] = np.random.randint(0,3)
-            policies[policy, sub_policy, 1] = np.random.randint(0,3)
-            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:
-                policies[policy, sub_policy, 1] = np.random.randint(0,3)
-
-            # pick probabilities
-            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10
-            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10
-
-            # pick magnitudes
-            for transformation in range(2):
-                if policies[policy, sub_policy, transformation] <= 1:
-                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5
-                elif policies[policy, sub_policy, transformation] == 2:
-                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10
-
-    return policies
-
-
-# In[7]:
-
-
-"""Pick policy and sub-policy"""
-"""Each row of data should have a different sub-policy but for now, this will do"""
-
-def sample_sub_policy(policies, policy, num_sub_policies):
-    sub_policy = np.random.randint(0,num_sub_policies)
-
-    degrees = 0
-    shear = 0
-    scale = 1
-
-    # check for rotations
-    if policies[policy, sub_policy][0] == 0:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            degrees = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 0:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            degrees = policies[policy, sub_policy][5]
-
-    # check for shears
-    if policies[policy, sub_policy][0] == 1:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            shear = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 1:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            shear = policies[policy, sub_policy][5]
-
-    # check for scales
-    if policies[policy, sub_policy][0] == 2:
-        if np.random.uniform() < policies[policy, sub_policy][2]:
-            scale = policies[policy, sub_policy][4]
-    elif policies[policy, sub_policy][1] == 2:
-        if np.random.uniform() < policies[policy, sub_policy][3]:
-            scale = policies[policy, sub_policy][5]
-
-    return degrees, shear, scale
-
-
-# In[8]:
-
-
-"""Sample policy, open and apply above transformations"""
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
-
-    # get number of policies and sub-policies
-    num_policies = len(policies)
-    num_sub_policies = len(policies[0])
-
-    #Initialize vector weights, counts and regret
-    q_values = [0]*num_policies
-    cnts = [0]*num_policies
-    q_plus_cnt = [0]*num_policies
-    total_count = 0
-
-    best_q_values = []
-
-    for policy in trange(iterations):
-
-        # get the action to try (either initially in order or using best q_plus_cnt value)
-        if policy >= num_policies:
-            this_policy = np.argmax(q_plus_cnt)
-        else:
-            this_policy = policy
-
-        # get info of transformation for this sub-policy
-        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)
-
-        # create transformations using above info
-        transform = torchvision.transforms.Compose(
-            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
-            torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing
-            torchvision.transforms.ToTensor()])
-
-        # open data and apply these transformations
-        if ds == "MNIST":
-            train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=transform)
-        elif ds == "KMNIST":
-            train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, download=True, transform=transform)
-        elif ds == "FashionMNIST":
-            train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, download=True, transform=transform)
-        elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, download=True, transform=transform)
-        elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, download=True, transform=transform)
-        elif ds == 'Other':
-            dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=transform)
-            len_train = int(0.8*len(dataset))
-            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
-
-        # check sizes of images
-        img_height = len(train_dataset[0][0][0])
-        img_width = len(train_dataset[0][0][0][0])
-        img_channels = len(train_dataset[0][0])
-
-
-        # check output labels
-        if ds == 'Other':
-            num_labels = len(dataset.class_to_idx)
-        elif ds == "CIFAR10" or ds == "CIFAR100":
-            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
-        else:
-            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
-
-        # create toy dataset from above uploaded data
-        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
-
-        # create model
-        if torch.cuda.is_available():
-            device='cuda'
-        else:
-            device='cpu'
-        
-        if IsLeNet == "LeNet":
-            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == "EasyNet":
-            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == 'SimpleNet':
-            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        else:
-            model = pickle.load(open(f'datasets/childnetwork', "rb"))
+        best_q_values = []
 
-        sgd = optim.SGD(model.parameters(), lr=learning_rate)
-        cost = nn.CrossEntropyLoss()
+        for this_iter in trange(iterations):
 
-        best_acc = train_child_network(model, train_loader, test_loader, sgd,
-                         cost, max_epochs, early_stop_num, early_stop_flag,
-			 average_validation, logging=False, print_every_epoch=False)
+            # get the action to try (either initially in order or using best q_plus_cnt value)
+            if this_iter >= self.num_policies:
+                this_policy = np.argmax(self.q_plus_cnt)
+            else:
+                this_policy = this_iter

-        # update q_values
-        if policy < num_policies:
-            q_values[this_policy] += best_acc
-        else:
-            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)

-        best_q_value = max(q_values)
-        best_q_values.append(best_q_value)
+            best_acc = self.test_autoaugment_policy(
+                                self.policies[this_policy],
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False
+                                )
 
-        if (policy+1) % 5 == 0:
-            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
+            # update q_values
+            if this_iter < self.num_policies:
+                self.q_values[this_policy] += best_acc
+            else:
+                self.q_values[this_policy] = (self.q_values[this_policy]*self.cnts[this_policy] + best_acc) / (self.cnts[this_policy] + 1)
 
-        # update counts
-        cnts[this_policy] += 1
-        total_count += 1
+            best_q_value = max(self.q_values)
+            best_q_values.append(best_q_value)
 
-        # update q_plus_cnt values every turn after the initial sweep through
-        if policy >= num_policies - 1:
-            for i in range(num_policies):
-                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
+            if (this_iter+1) % 5 == 0:
+                print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
+                                this_iter+1, 
+                                list(np.around(np.array(self.q_values),2)), 
+                                max(list(np.around(np.array(self.q_values),2)))
+                                )
+                    )
 
-        # yield q_values, best_q_values
-    return q_values, best_q_values
+            # update counts
+            self.cnts[this_policy] += 1
+            self.total_count += 1
+
+            # update q_plus_cnt values every turn after the initial sweep through
+            if this_iter >= self.num_policies - 1:
+                for i in range(self.num_policies):
+                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+
+            # yield q_values, best_q_values
+        return self.q_values, best_q_values
+
+
+       
+
+    
+def run_UCB1(
+            policies, 
+            batch_size, 
+            learning_rate, 
+            ds, 
+            toy_size, 
+            max_epochs, 
+            early_stop_num, 
+            early_stop_flag, 
+            average_validation, 
+            iterations, 
+            IsLeNet
+        ):
+    pass
+
+def generate_policies(
+            num_policies, 
+            sp_num
+        ):
+    pass
 
 
-# # In[9]:
 
 if __name__=='__main__':
     batch_size = 32       # size of batch the inner NN is trained with
@@ -230,18 +160,6 @@ if __name__=='__main__':
     early_stop_flag = True        # implement early stopping or not
     average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over
     num_policies = 5      # fix number of policies
-    num_sub_policies = 5  # fix number of sub-policies in a policy
+    sp_num = 5  # fix number of sub-policies in a policy
     iterations = 100      # total iterations, should be more than the number of policies
-    IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
-
-    # generate random policies at start
-    policies = generate_policies(num_policies, num_sub_policies)
-
-    q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)
-
-    plt.plot(best_q_values)
-
-    best_q_values = np.array(best_q_values)
-    save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
-    #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
-
+    IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
\ No newline at end of file
diff --git a/backend_react/react_app.py b/backend_react/react_app.py
index 7f4b78d8ff7af38f3e126ceeb9a82602632ed5e7..9fa264ae2003c422b7e928596e1fcd4efe302a22 100644
--- a/backend_react/react_app.py
+++ b/backend_react/react_app.py
@@ -1,39 +1,24 @@
 from dataclasses import dataclass
 from flask import Flask, request, current_app, render_template
 # from flask_cors import CORS
-import subprocess
 import os
 import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
+
 from numpy import save, load
-from tqdm import trange
 torch.manual_seed(0)
 
 import os
 import sys
 sys.path.insert(0, os.path.abspath('..'))
+import wapp_util
 
-# # import agents and its functions
-from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
-from MetaAugment.autoaugment_learners import evo_learner
-import MetaAugment.controller_networks as cn
-import MetaAugment.autoaugment_learners as aal
-print('@@@ import successful')
 
 # import agents and its functions
-# from ..MetaAugment import UCB1_JC_py as UCB1_JC
-# from ..MetaAugment import Evo_learner as Evo
-# print('@@@ import successful')
+from MetaAugment import UCB1_JC_py as UCB1_JC
+from MetaAugment import Evo_learner as Evo
+print('@@@ import successful')
 
 app = Flask(__name__)
 
@@ -141,27 +126,27 @@ def confirm():
 @app.route('/training', methods=['POST', 'GET'])
 def training():
 
-    # aa learner
-    auto_aug_learner = current_app.config.get('AAL')
-
-    # search space & problem setting
-    ds = current_app.config.get('ds')
-    ds_name = current_app.config.get('DSN')
-    exclude_method = current_app.config.get('exc_meth')
-    num_funcs = current_app.config.get('NUMFUN')
-    num_policies = current_app.config.get('NP')
-    num_sub_policies = current_app.config.get('NSP')
-    toy_size = current_app.config.get('TS')
+    # # aa learner
+    # auto_aug_learner = current_app.config.get('AAL')
+
+    # # search space & problem setting
+    # ds = current_app.config.get('ds')
+    # ds_name = current_app.config.get('DSN')
+    # exclude_method = current_app.config.get('exc_meth')
+    # num_funcs = current_app.config.get('NUMFUN')
+    # num_policies = current_app.config.get('NP')
+    # num_sub_policies = current_app.config.get('NSP')
+    # toy_size = current_app.config.get('TS')
     
-    # child network
-    IsLeNet = current_app.config.get('ISLENET')
+    # # child network
+    # IsLeNet = current_app.config.get('ISLENET')
 
-    # child network training hyperparameters
-    batch_size = current_app.config.get('BS')
-    early_stop_num = current_app.config.get('ESN')
-    iterations = current_app.config.get('IT')
-    learning_rate = current_app.config.get('LR')
-    max_epochs = current_app.config.get('ME')
+    # # child network training hyperparameters
+    # batch_size = current_app.config.get('BS')
+    # early_stop_num = current_app.config.get('ESN')
+    # iterations = current_app.config.get('IT')
+    # learning_rate = current_app.config.get('LR')
+    # max_epochs = current_app.config.get('ME')
 
     # default values 
     max_epochs = 10      # max number of epochs that is run if early stopping is not hit
@@ -170,46 +155,9 @@ def training():
     num_sub_policies = 5  # fix number of sub-policies in a policy
     data = current_app.config.get('data')
 
+    return {'status': 'training done!'}
 
-    if data.auto_aug_learner == 'UCB':
-        policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
-        q_values, best_q_values = UCB1_JC.run_UCB1(
-                                                policies,
-                                                data.batch_size, 
-                                                data.learning_rate, 
-                                                data.ds, 
-                                                data.toy_size, 
-                                                max_epochs, 
-                                                early_stop_num, 
-                                                data.iterations, 
-                                                data.IsLeNet, 
-                                                data.ds_name
-                                                )     
-        best_q_values = np.array(best_q_values)
-
-    elif data.auto_aug_learner == 'Evolutionary Learner':
-
-        network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        child_network = aal.evo.LeNet()
-        learner = aal.evo.evo_learner(
-                                    network=network, 
-                                    fun_num=num_funcs, 
-                                    p_bins=1, 
-                                    mag_bins=1, 
-                                    sub_num_pol=1, 
-                                    ds = ds, 
-                                    ds_name=ds_name, 
-                                    exclude_method=exclude_method, 
-                                    child_network=child_network
-                                    )
-
-        learner.run_instance()
-    elif data.auto_aug_learner == 'Random Searcher':
-        pass 
-    elif data.auto_aug_learner == 'Genetic Learner':
-        pass
 
-    return {'status': 'training done!'}
 
 
 
diff --git a/flask_mvp/app.py b/flask_mvp/app.py
index 5e39517f6ae17dc93910e02f960e7aac0074dd7d..8f71616620872dc04fd66be39753bb45a74ae2e4 100644
--- a/flask_mvp/app.py
+++ b/flask_mvp/app.py
@@ -3,7 +3,8 @@
 #     app.run(host='0.0.0.0',port=port)
 
 from numpy import broadcast
-from auto_augmentation import home, progress,result, training
+from auto_augmentation import home, progress,result
+from flask_mvp.auto_augmentation import training
 from flask_socketio import SocketIO,  send
 
 from flask import Flask, flash, request, redirect, url_for
diff --git a/flask_mvp/auto_augmentation/__init__.py b/flask_mvp/auto_augmentation/__init__.py
index 0899be3d1b979ffc3f6e5a123cdb848470b29feb..72634111728ad96c69734e682ef37bae7c112a75 100644
--- a/flask_mvp/auto_augmentation/__init__.py
+++ b/flask_mvp/auto_augmentation/__init__.py
@@ -2,7 +2,8 @@ import os
 
 from flask import Flask, render_template, request, flash
 
-from auto_augmentation import home, progress,result, training
+from auto_augmentation import home, progress,result
+from flask_mvp.auto_augmentation import training
 
 def create_app(test_config=None):
     # create and configure the app
diff --git a/flask_mvp/auto_augmentation/progress.py b/flask_mvp/auto_augmentation/progress.py
index 4c3e96b28ca42e47eada9913c5008bafb90f5ddb..9d82a63424f112db3be30ffaea5fa2239faa3417 100644
--- a/flask_mvp/auto_augmentation/progress.py
+++ b/flask_mvp/auto_augmentation/progress.py
@@ -1,32 +1,12 @@
 from flask import Blueprint, request, render_template, flash, send_file, current_app, g, session
-import subprocess
 import os
 import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
-from numpy import save, load
-from tqdm import trange
 torch.manual_seed(0)
-# import agents and its functions
 
-from MetaAugment.autoaugment_learners import ucb_learner
-# hi
-from MetaAugment import Evo_learner as Evo
-
-import MetaAugment.autoaugment_learners as aal
-from MetaAugment.main import create_toy
-import MetaAugment.child_networks as cn
-import pickle
 
+import wapp_util
 
 bp = Blueprint("progress", __name__)
 
@@ -92,100 +72,20 @@ def response():
         
 
 
-        if auto_aug_learner == 'UCB':
-            policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
-            q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
-        elif auto_aug_learner == 'Evolutionary Learner':
-            learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
-            learner.run_instance()
-        elif auto_aug_learner == 'Random Searcher':
-            # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
-            # This system makes more sense for the user who is not using the webapp and is instead
-            # using the library within their code
-            download = True
-            if ds == "MNIST":
-                train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
-                test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "KMNIST":
-                train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
-                test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "FashionMNIST":
-                train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
-                test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "CIFAR10":
-                train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
-                test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == "CIFAR100":
-                train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
-                test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
-                                                download=download, transform=torchvision.transforms.ToTensor())
-            elif ds == 'Other':
-                dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
-                len_train = int(0.8*len(dataset))
-                train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
-
-            # check sizes of images
-            img_height = len(train_dataset[0][0][0])
-            img_width = len(train_dataset[0][0][0][0])
-            img_channels = len(train_dataset[0][0])
-            # check output labels
-            if ds == 'Other':
-                num_labels = len(dataset.class_to_idx)
-            elif ds == "CIFAR10" or ds == "CIFAR100":
-                num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
-            else:
-                num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
-            # create toy dataset from above uploaded data
-            train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
-            # create model
-            if IsLeNet == "LeNet":
-                model = cn.LeNet(img_height, img_width, num_labels, img_channels)
-            elif IsLeNet == "EasyNet":
-                model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
-            elif IsLeNet == 'SimpleNet':
-                model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
-            else:
-                model = pickle.load(open(f'datasets/childnetwork', "rb"))
-
-            # use an aa_learner. in this case, a rs learner
-            agent = aal.randomsearch_learner(batch_size=batch_size,
-                                            toy_flag=True,
-                                            learning_rate=learning_rate,
-                                            toy_size=toy_size,
-                                            max_epochs=max_epochs,
-                                            early_stop_num=early_stop_num,
-                                            )
-            agent.learn(train_dataset,
-                        test_dataset,
-                        child_network_architecture=model,
-                        iterations=iterations)
-        elif auto_aug_learner == 'Genetic Learner':
-            pass
-
-        plt.figure()
-        plt.plot(q_values)
-
-
-        # if auto_aug_learner == 'UCB':
-        #     policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
-        #     q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)     
-        #     # plt.figure()
-        #     # plt.plot(q_values)
-        #     best_q_values = np.array(best_q_values)
-
-        # elif auto_aug_learner == 'Evolutionary Learner':
-        #     network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        #     child_network = Evo.LeNet()
-        #     learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
-        #     learner.run_instance()
-        # elif auto_aug_learner == 'Random Searcher':
-        #     pass 
-        # elif auto_aug_learner == 'Genetic Learner':
-        #     pass
+        learner = wapp_util.parse_users_learner_spec(auto_aug_learner, 
+                                                    ds, 
+                                                    exclude_method, 
+                                                    num_funcs, 
+                                                    num_policies, 
+                                                    num_sub_policies, 
+                                                    toy_size, 
+                                                    IsLeNet, 
+                                                    batch_size, 
+                                                    early_stop_num, 
+                                                    iterations, 
+                                                    learning_rate, 
+                                                    max_epochs, 
+                                                    ds_name)
 
 
     current_app.config['AAL'] = auto_aug_learner
diff --git a/flask_mvp/auto_augmentation/training.py b/flask_mvp/auto_augmentation/training.py
index 5e695b58a2994efb1bdc89bb363b3eddf643d9dc..e20b5867a3ebc6c8ff48ccf4e18411c7cbe3d08f 100644
--- a/flask_mvp/auto_augmentation/training.py
+++ b/flask_mvp/auto_augmentation/training.py
@@ -1,28 +1,11 @@
 from flask import Blueprint, request, render_template, flash, send_file, current_app
-import subprocess
 import os
-import zipfile
 
-import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import torch.utils.data as data_utils
-import torchvision
-import torchvision.datasets as datasets
-
-from matplotlib import pyplot as plt
-from numpy import save, load
-from tqdm import trange
 torch.manual_seed(0)
-# import agents and its functions
-
-import MetaAugment.autoaugment_learners as aal
-import MetaAugment.controller_networks as cont_n
-import MetaAugment.child_networks as cn
 
 
+import wapp_util
 
 bp = Blueprint("training", __name__)
 
@@ -56,41 +39,22 @@ def response():
     max_epochs = current_app.config.get('ME')
 
 
-    if auto_aug_learner == 'UCB':
-        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
-        q_values, best_q_values = aal.ucb_learner.run_UCB1(
-                                                policies, 
-                                                batch_size, 
-                                                learning_rate, 
-                                                ds, 
-                                                toy_size, 
-                                                max_epochs, 
-                                                early_stop_num, 
-                                                iterations, 
-                                                IsLeNet, 
-                                                ds_name
-                                                )     
-        best_q_values = np.array(best_q_values)
-
-    elif auto_aug_learner == 'Evolutionary Learner':
-        network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
-        child_network = cn.LeNet()
-        learner = aal.evo_learner(
-                                network=network, 
-                                fun_num=num_funcs, 
-                                p_bins=1, 
-                                mag_bins=1, 
-                                sub_num_pol=1, 
-                                ds = ds, 
-                                ds_name=ds_name, 
-                                exclude_method=exclude_method, 
-                                child_network=child_network
-                                )
-        learner.run_instance()
-    elif auto_aug_learner == 'Random Searcher':
-        pass 
-    elif auto_aug_learner == 'Genetic Learner':
-        pass
+    wapp_util.parse_users_learner_spec(
+            auto_aug_learner, 
+            ds, 
+            ds_name, 
+            exclude_method, 
+            num_funcs, 
+            num_policies, 
+            num_sub_policies, 
+            toy_size, 
+            IsLeNet, 
+            batch_size, 
+            early_stop_num, 
+            iterations, 
+            learning_rate, 
+            max_epochs
+            )
 
     return render_template("progress.html", auto_aug_learner=auto_aug_learner)
 
diff --git a/temp_util/parse_ds_cn_arch.py b/temp_util/parse_ds_cn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..082711acb24e952a3b79a033bfbd58bd73bcb741
--- /dev/null
+++ b/temp_util/parse_ds_cn_arch.py
@@ -0,0 +1,59 @@
+from ..child_networks import *
+from ..main import create_toy, train_child_network
+import torch
+import torchvision.datasets as datasets
+import pickle
+
+def parse_ds_cn_arch(ds, ds_name, IsLeNet, transform):
+    # open data and apply these transformations
+    if ds == "MNIST":
+        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
+        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=transform)
+    elif ds == "KMNIST":
+        train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', train=True, download=True, transform=transform)
+        test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, download=True, transform=transform)
+    elif ds == "FashionMNIST":
+        train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train=True, download=True, transform=transform)
+        test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, download=True, transform=transform)
+    elif ds == "CIFAR10":
+        train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', train=True, download=True, transform=transform)
+        test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, download=True, transform=transform)
+    elif ds == "CIFAR100":
+        train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', train=True, download=True, transform=transform)
+        test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, download=True, transform=transform)
+    elif ds == 'Other':
+        dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=transform)
+        len_train = int(0.8*len(dataset))
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+
+        # check sizes of images
+    img_height = len(train_dataset[0][0][0])
+    img_width = len(train_dataset[0][0][0][0])
+    img_channels = len(train_dataset[0][0])
+
+
+        # check output labels
+    if ds == 'Other':
+        num_labels = len(dataset.class_to_idx)
+    elif ds == "CIFAR10" or ds == "CIFAR100":
+        num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
+    else:
+        num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
+
+
+        # create model
+    if torch.cuda.is_available():
+        device='cuda'
+    else:
+        device='cpu'
+        
+    if IsLeNet == "LeNet":
+        model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+    elif IsLeNet == "EasyNet":
+        model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+    elif IsLeNet == 'SimpleNet':
+        model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+    else:
+        model = pickle.load(open(f'datasets/childnetwork', "rb"))
+
+    return train_dataset, test_dataset, model
\ No newline at end of file
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..78be118ae9f3143d907cb8b0940bc6283a3e82ac
--- /dev/null
+++ b/temp_util/wapp_util.py
@@ -0,0 +1,136 @@
+"""
+CONTAINS THE FUNCTIONS THAT THE WEBAPP CAN USE TO INTERACT WITH
+THE LIBRARY
+"""
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.datasets as datasets
+
+# # import agents and its functions
+import MetaAugment.autoaugment_learners as aal
+import MetaAugment.controller_networks as cont_n
+import MetaAugment.child_networks as cn
+from MetaAugment.main import create_toy
+
+import pickle
+
+def parse_users_learner_spec(
+            auto_aug_learner, 
+            ds, 
+            ds_name, 
+            exclude_method, 
+            num_funcs, 
+            num_policies, 
+            num_sub_policies, 
+            toy_size, 
+            IsLeNet, 
+            batch_size, 
+            early_stop_num, 
+            iterations, 
+            learning_rate, 
+            max_epochs
+            ):
+    """
+    The website receives user inputs on what they want the aa_learner
+    to be. We take those hyperparameters and return an aa_learner
+
+    """
+    if auto_aug_learner == 'UCB':
+        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
+        q_values, best_q_values = aal.ucb_learner.run_UCB1(
+                                                policies,
+                                                batch_size, 
+                                                learning_rate, 
+                                                ds, 
+                                                toy_size, 
+                                                max_epochs, 
+                                                early_stop_num, 
+                                                iterations, 
+                                                IsLeNet, 
+                                                ds_name
+                                                )     
+        best_q_values = np.array(best_q_values)
+    elif auto_aug_learner == 'Evolutionary Learner':
+        network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
+        child_network = cn.LeNet()
+        learner = aal.evo_learner(
+                                network=network, 
+                                fun_num=num_funcs, 
+                                p_bins=1, 
+                                mag_bins=1, 
+                                sub_num_pol=1, 
+                                ds = ds, 
+                                ds_name=ds_name, 
+                                exclude_method=exclude_method, 
+                                child_network=child_network
+                                )
+        learner.run_instance()
+    elif auto_aug_learner == 'Random Searcher':
+            # Unlike the 'UCB' branch above, `ds` and `IsLeNet` are processed outside of the agent.
+            # This system makes more sense for the user who is not using the webapp and is instead
+            # using the library within their code
+        download = True
+        if ds == "MNIST":
+            train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
+            test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
+                                                download=download, transform=torchvision.transforms.ToTensor())
+        elif ds == "KMNIST":
+            train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
+            test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
+                                                download=download, transform=torchvision.transforms.ToTensor())
+        elif ds == "FashionMNIST":
+            train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
+            test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
+                                                download=download, transform=torchvision.transforms.ToTensor())
+        elif ds == "CIFAR10":
+            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
+            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
+                                                download=download, transform=torchvision.transforms.ToTensor())
+        elif ds == "CIFAR100":
+            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
+            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
+                                                download=download, transform=torchvision.transforms.ToTensor())
+        elif ds == 'Other':
+            dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
+            len_train = int(0.8*len(dataset))
+            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+
+            # check sizes of images
+        img_height = len(train_dataset[0][0][0])
+        img_width = len(train_dataset[0][0][0][0])
+        img_channels = len(train_dataset[0][0])
+            # check output labels
+        if ds == 'Other':
+            num_labels = len(dataset.class_to_idx)
+        elif ds == "CIFAR10" or ds == "CIFAR100":
+            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
+        else:
+            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
+            # create toy dataset from above uploaded data
+        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
+            # create model
+        if IsLeNet == "LeNet":
+            model = cn.LeNet(img_height, img_width, num_labels, img_channels)
+        elif IsLeNet == "EasyNet":
+            model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
+        elif IsLeNet == 'SimpleNet':
+            model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
+        else:
+            model = pickle.load(open(f'datasets/childnetwork', "rb"))
+
+            # use an aa_learner. in this case, a rs learner
+        agent = aal.randomsearch_learner(batch_size=batch_size,
+                                            toy_flag=True,
+                                            learning_rate=learning_rate,
+                                            toy_size=toy_size,
+                                            max_epochs=max_epochs,
+                                            early_stop_num=early_stop_num,
+                                            )
+        agent.learn(train_dataset,
+                        test_dataset,
+                        child_network_architecture=model,
+                        iterations=iterations)
+    elif auto_aug_learner == 'Genetic Learner':
+        pass
\ No newline at end of file
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..514d78307eb553afd16521309e4273127f3fa40e
--- /dev/null
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -0,0 +1,25 @@
+import MetaAugment.autoaugment_learners as aal
+import MetaAugment.child_networks as cn
+import torch
+import torchvision
+import torchvision.datasets as datasets
+
+import random
+
+
+def test_ucb_learner():
+    policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
+    q_values, best_q_values = aal.ucb_learner.run_UCB1(
+                                            policies,
+                                            batch_size, 
+                                            learning_rate, 
+                                            ds, 
+                                            toy_size, 
+                                            max_epochs, 
+                                            early_stop_num, 
+                                            iterations, 
+                                            IsLeNet, 
+                                            ds_name
+                                            )     
+    best_q_values = np.array(best_q_values)
+    pass
\ No newline at end of file