Skip to content
Snippets Groups Projects
Commit e1804cb5 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

rearrange hyperparameters in webapp .py's

parent a3ebc162
No related branches found
No related tags found
No related merge requests found
Pipeline #271994 passed
...@@ -29,25 +29,30 @@ bp = Blueprint("training", __name__) ...@@ -29,25 +29,30 @@ bp = Blueprint("training", __name__)
@bp.route("/start_training", methods=["GET", "POST"]) @bp.route("/start_training", methods=["GET", "POST"])
def response(): def response():
# hyperparameters to change
# auto_aug_learner = session
# aa learner
auto_aug_learner = current_app.config.get('AAL') auto_aug_learner = current_app.config.get('AAL')
# auto_aug_learner = session
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_funcs = current_app.config.get('NUMFUN')
num_policies = current_app.config.get('NP') num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP') num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS')
learning_rate = current_app.config.get('LR')
toy_size = current_app.config.get('TS') toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN') early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT') iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET') learning_rate = current_app.config.get('LR')
ds_name = current_app.config.get('DSN') max_epochs = current_app.config.get('ME')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
if auto_aug_learner == 'UCB': if auto_aug_learner == 'UCB':
......
...@@ -138,20 +138,28 @@ def get_form_data(): ...@@ -138,20 +138,28 @@ def get_form_data():
@app.route('/confirm', methods=['POST', 'GET']) @app.route('/confirm', methods=['POST', 'GET'])
def confirm(): def confirm():
print('inside confirm') print('inside confirm')
# aa learner
auto_aug_learner = current_app.config.get('AAL') auto_aug_learner = current_app.config.get('AAL')
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_policies = current_app.config.get('NP') num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP') num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS') num_funcs = current_app.config.get('NUMFUN')
learning_rate = current_app.config.get('LR')
toy_size = current_app.config.get('TS') toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN') early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT') iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET') learning_rate = current_app.config.get('LR')
ds_name = current_app.config.get('DSN') max_epochs = current_app.config.get('ME')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name, data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, 'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
...@@ -161,20 +169,28 @@ def confirm(): ...@@ -161,20 +169,28 @@ def confirm():
# ======================================================================== # ========================================================================
@app.route('/training', methods=['POST', 'GET']) @app.route('/training', methods=['POST', 'GET'])
def training(): def training():
# aa learner
auto_aug_learner = current_app.config.get('AAL') auto_aug_learner = current_app.config.get('AAL')
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_funcs = current_app.config.get('NUMFUN')
num_policies = current_app.config.get('NP') num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP') num_sub_policies = current_app.config.get('NSP')
batch_size = current_app.config.get('BS')
learning_rate = current_app.config.get('LR')
toy_size = current_app.config.get('TS') toy_size = current_app.config.get('TS')
max_epochs = current_app.config.get('ME')
# child network
IsLeNet = current_app.config.get('ISLENET')
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN') early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT') iterations = current_app.config.get('IT')
IsLeNet = current_app.config.get('ISLENET') learning_rate = current_app.config.get('LR')
ds_name = current_app.config.get('DSN') max_epochs = current_app.config.get('ME')
num_funcs = current_app.config.get('NUMFUN')
ds = current_app.config.get('ds')
exclude_method = current_app.config.get('exc_meth')
if auto_aug_learner == 'UCB': if auto_aug_learner == 'UCB':
......
...@@ -18,7 +18,7 @@ from tqdm import trange ...@@ -18,7 +18,7 @@ from tqdm import trange
torch.manual_seed(0) torch.manual_seed(0)
# import agents and its functions # import agents and its functions
from MetaAugment import UCB1_JC_py as UCB1_JC from MetaAugment.autoaugment_learners import ucb_learner
from MetaAugment import Evo_learner as Evo from MetaAugment import Evo_learner as Evo
...@@ -38,21 +38,28 @@ def response(): ...@@ -38,21 +38,28 @@ def response():
if request.method == 'POST': if request.method == 'POST':
# generate random policies at start
auto_aug_learner = request.form.get("auto_aug_selection")
# search space & problem setting
ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
ds_up = request.files['dataset_upload']
exclude_method = request.form.getlist("action_space") exclude_method = request.form.getlist("action_space")
num_funcs = 14 - len(exclude_method) num_funcs = 14 - len(exclude_method)
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
toy_size = 1 # total propeortion of training and test set we use
batch_size = 1 # size of batch the inner NN is trained with # child network
learning_rate = 1e-1 # fix learning rate IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default
ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
ds_up = request.files['dataset_upload']
nw_up = childnetwork = request.files['network_upload'] nw_up = childnetwork = request.files['network_upload']
toy_size = 1 # total propeortion of training and test set we use
max_epochs = 10 # max number of epochs that is run if early stopping is not hit # child network training hyperparameters
batch_size = 1 # size of batch the inner NN is trained with
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
iterations = 5 # total iterations, should be more than the number of policies iterations = 5 # total iterations, should be more than the number of policies
IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default learning_rate = 1e-1 # fix learning rate
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
# if user upload datasets and networks, save them in the database # if user upload datasets and networks, save them in the database
...@@ -83,16 +90,15 @@ def response(): ...@@ -83,16 +90,15 @@ def response():
childnetwork = request.files['network_upload'] childnetwork = request.files['network_upload']
childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename) childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename)
# generate random policies at start
auto_aug_leanrer = request.form.get("auto_aug_selection")
if auto_aug_leanrer == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) if auto_aug_learner == 'UCB':
q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
elif auto_aug_leanrer == 'Evolutionary Learner': q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
elif auto_aug_learner == 'Evolutionary Learner':
learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method) learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
learner.run_instance() learner.run_instance()
elif auto_aug_leanrer == 'Random Searcher': elif auto_aug_learner == 'Random Searcher':
# As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
# This system makes more sense for the user who is not using the webapp and is instead # This system makes more sense for the user who is not using the webapp and is instead
# using the library within their code # using the library within their code
...@@ -157,7 +163,7 @@ def response(): ...@@ -157,7 +163,7 @@ def response():
test_dataset, test_dataset,
child_network_architecture=model, child_network_architecture=model,
iterations=iterations) iterations=iterations)
elif auto_aug_leanrer == 'Genetic Learner': elif auto_aug_learner == 'Genetic Learner':
pass pass
plt.figure() plt.figure()
...@@ -165,8 +171,8 @@ def response(): ...@@ -165,8 +171,8 @@ def response():
# if auto_aug_learner == 'UCB': # if auto_aug_learner == 'UCB':
# policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) # policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
# q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) # q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
# # plt.figure() # # plt.figure()
# # plt.plot(q_values) # # plt.plot(q_values)
# best_q_values = np.array(best_q_values) # best_q_values = np.array(best_q_values)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment