Skip to content
Snippets Groups Projects
Commit e1804cb5 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

rearrange hyperparameters in webapp .py's

parent a3ebc162
No related branches found
No related tags found
No related merge requests found
Pipeline #271994 passed
......@@ -29,25 +29,30 @@ bp = Blueprint("training", __name__)
@bp.route("/start_training", methods=["GET", "POST"])
def response():
# hyperparameters to change
# auto_aug_learner = session
# aa learner
auto_aug_learner = current_app.config.get('AAL')
# auto_aug_learner = session
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_funcs = current_app.config.get('NUMFUN')
num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS')
learning_rate = current_app.config.get('LR')
toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET')
ds_name = current_app.config.get('DSN')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
learning_rate = current_app.config.get('LR')
max_epochs = current_app.config.get('ME')
if auto_aug_learner == 'UCB':
......
......@@ -138,20 +138,28 @@ def get_form_data():
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
print('inside confirm')
# aa learner
auto_aug_learner = current_app.config.get('AAL')
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS')
learning_rate = current_app.config.get('LR')
num_funcs = current_app.config.get('NUMFUN')
toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET')
ds_name = current_app.config.get('DSN')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
learning_rate = current_app.config.get('LR')
max_epochs = current_app.config.get('ME')
data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
......@@ -161,20 +169,28 @@ def confirm():
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
def training():
# aa learner
auto_aug_learner = current_app.config.get('AAL')
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_funcs = current_app.config.get('NUMFUN')
num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS')
learning_rate = current_app.config.get('LR')
toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET')
ds_name = current_app.config.get('DSN')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
learning_rate = current_app.config.get('LR')
max_epochs = current_app.config.get('ME')
if auto_aug_learner == 'UCB':
......
......@@ -18,7 +18,7 @@ from tqdm import trange
torch.manual_seed(0)
# import agents and its functions
from MetaAugment import UCB1_JC_py as UCB1_JC
from MetaAugment.autoaugment_learners import ucb_learner
from MetaAugment import Evo_learner as Evo
......@@ -38,21 +38,28 @@ def response():
if request.method == 'POST':
# generate random policies at start
auto_aug_learner = request.form.get("auto_aug_selection")
# search space & problem setting
ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
ds_up = request.files['dataset_upload']
exclude_method = request.form.getlist("action_space")
num_funcs = 14 - len(exclude_method)
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
toy_size = 1 # total propeortion of training and test set we use
batch_size = 1 # size of batch the inner NN is trained with
learning_rate = 1e-1 # fix learning rate
ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
ds_up = request.files['dataset_upload']
# child network
IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default
nw_up = childnetwork = request.files['network_upload']
toy_size = 1 # total propeortion of training and test set we use
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
# child network training hyperparameters
batch_size = 1 # size of batch the inner NN is trained with
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
iterations = 5 # total iterations, should be more than the number of policies
IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default
learning_rate = 1e-1 # fix learning rate
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
# if user upload datasets and networks, save them in the database
......@@ -83,16 +90,15 @@ def response():
childnetwork = request.files['network_upload']
childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename)
# generate random policies at start
auto_aug_leanrer = request.form.get("auto_aug_selection")
if auto_aug_leanrer == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
elif auto_aug_leanrer == 'Evolutionary Learner':
if auto_aug_learner == 'UCB':
policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
elif auto_aug_learner == 'Evolutionary Learner':
learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
learner.run_instance()
elif auto_aug_leanrer == 'Random Searcher':
elif auto_aug_learner == 'Random Searcher':
# As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
# This system makes more sense for the user who is not using the webapp and is instead
# using the library within their code
......@@ -157,7 +163,7 @@ def response():
test_dataset,
child_network_architecture=model,
iterations=iterations)
elif auto_aug_leanrer == 'Genetic Learner':
elif auto_aug_learner == 'Genetic Learner':
pass
plt.figure()
......@@ -165,8 +171,8 @@ def response():
# if auto_aug_learner == 'UCB':
# policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
# q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
# policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
# q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
# # plt.figure()
# # plt.plot(q_values)
# best_q_values = np.array(best_q_values)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment