diff --git a/auto_augmentation/training.py b/auto_augmentation/training.py index 7eaac2f322cdbb5c3444bbc552b1b0340b23ab78..b8b3521a52c8806aca3830da0d838c8266a8f801 100644 --- a/auto_augmentation/training.py +++ b/auto_augmentation/training.py @@ -29,25 +29,30 @@ bp = Blueprint("training", __name__) @bp.route("/start_training", methods=["GET", "POST"]) def response(): - # hyperparameters to change - # auto_aug_learner = session + # aa learner auto_aug_learner = current_app.config.get('AAL') + # auto_aug_learner = session + # search space & problem setting + ds = current_app.config.get('ds') + ds_name = current_app.config.get('DSN') + exclude_method = current_app.config.get('exc_meth') + num_funcs = current_app.config.get('NUMFUN') num_policies = current_app.config.get('NP') num_sub_policies = current_app.config.get('NSP') - batch_size = current_app.config.get('BS') - learning_rate = current_app.config.get('LR') toy_size = current_app.config.get('TS') - max_epochs = current_app.config.get('ME') + + # child network + IsLeNet = current_app.config.get('ISLENET') + + # child network training hyperparameters + batch_size = current_app.config.get('BS') early_stop_num = current_app.config.get('ESN') iterations = current_app.config.get('IT') - IsLeNet = current_app.config.get('ISLENET') - ds_name = current_app.config.get('DSN') - num_funcs = current_app.config.get('NUMFUN') - ds = current_app.config.get('ds') - exclude_method = current_app.config.get('exc_meth') + learning_rate = current_app.config.get('LR') + max_epochs = current_app.config.get('ME') if auto_aug_learner == 'UCB': diff --git a/backend_react/react_app.py b/backend_react/react_app.py index 121760dd5dd807f5fa3190d6dc254dd7f5649937..21f5e8a2a9ae99d4b058931f2e912aa116d7ee0b 100644 --- a/backend_react/react_app.py +++ b/backend_react/react_app.py @@ -138,20 +138,28 @@ def get_form_data(): @app.route('/confirm', methods=['POST', 'GET']) def confirm(): print('inside confirm') + + # aa learner auto_aug_learner = current_app.config.get('AAL') + + # search space & problem setting + ds = current_app.config.get('ds') + ds_name = current_app.config.get('DSN') + exclude_method = current_app.config.get('exc_meth') num_policies = current_app.config.get('NP') num_sub_policies = current_app.config.get('NSP') - batch_size = current_app.config.get('BS') - learning_rate = current_app.config.get('LR') + num_funcs = current_app.config.get('NUMFUN') toy_size = current_app.config.get('TS') - max_epochs = current_app.config.get('ME') + + # child network + IsLeNet = current_app.config.get('ISLENET') + + # child network training hyperparameters + batch_size = current_app.config.get('BS') early_stop_num = current_app.config.get('ESN') iterations = current_app.config.get('IT') - IsLeNet = current_app.config.get('ISLENET') - ds_name = current_app.config.get('DSN') - num_funcs = current_app.config.get('NUMFUN') - ds = current_app.config.get('ds') - exclude_method = current_app.config.get('exc_meth') + learning_rate = current_app.config.get('LR') + max_epochs = current_app.config.get('ME') data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name, 'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, @@ -161,20 +169,28 @@ def confirm(): # ======================================================================== @app.route('/training', methods=['POST', 'GET']) def training(): + + # aa learner auto_aug_learner = current_app.config.get('AAL') + + # search space & problem setting + ds = current_app.config.get('ds') + ds_name = current_app.config.get('DSN') + exclude_method = current_app.config.get('exc_meth') + num_funcs = current_app.config.get('NUMFUN') num_policies = current_app.config.get('NP') num_sub_policies = current_app.config.get('NSP') - batch_size = current_app.config.get('BS') - learning_rate = current_app.config.get('LR') toy_size = current_app.config.get('TS') - max_epochs = current_app.config.get('ME') + + # child network + IsLeNet = current_app.config.get('ISLENET') + + # child network training hyperparameters + batch_size = current_app.config.get('BS') early_stop_num = current_app.config.get('ESN') iterations = current_app.config.get('IT') - IsLeNet = current_app.config.get('ISLENET') - ds_name = current_app.config.get('DSN') - num_funcs = current_app.config.get('NUMFUN') - ds = current_app.config.get('ds') - exclude_method = current_app.config.get('exc_meth') + learning_rate = current_app.config.get('LR') + max_epochs = current_app.config.get('ME') if auto_aug_learner == 'UCB': diff --git a/flask_mvp/auto_augmentation/progress.py b/flask_mvp/auto_augmentation/progress.py index abe15fe35fb226fe94c30d513e5634b22a72c30a..ac51b8d807cb2ad9c7c37020a7073a408020dac5 100644 --- a/flask_mvp/auto_augmentation/progress.py +++ b/flask_mvp/auto_augmentation/progress.py @@ -18,7 +18,7 @@ from tqdm import trange torch.manual_seed(0) # import agents and its functions -from MetaAugment import UCB1_JC_py as UCB1_JC +from MetaAugment.autoaugment_learners import ucb_learner from MetaAugment import Evo_learner as Evo @@ -38,21 +38,28 @@ def response(): if request.method == 'POST': + # generate random policies at start + auto_aug_learner = request.form.get("auto_aug_selection") + + # search space & problem setting + ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) + ds_up = request.files['dataset_upload'] exclude_method = request.form.getlist("action_space") num_funcs = 14 - len(exclude_method) + num_policies = 5 # fix number of policies + num_sub_policies = 5 # fix number of sub-policies in a policy + toy_size = 1 # total propeortion of training and test set we use - batch_size = 1 # size of batch the inner NN is trained with - learning_rate = 1e-1 # fix learning rate - ds = request.form.get("dataset_selection") # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) - ds_up = request.files['dataset_upload'] + # child network + IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default nw_up = childnetwork = request.files['network_upload'] - toy_size = 1 # total propeortion of training and test set we use - max_epochs = 10 # max number of epochs that is run if early stopping is not hit + + # child network training hyperparameters + batch_size = 1 # size of batch the inner NN is trained with early_stop_num = 10 # max number of worse validation scores before early stopping is triggered - num_policies = 5 # fix number of policies - num_sub_policies = 5 # fix number of sub-policies in a policy iterations = 5 # total iterations, should be more than the number of policies - IsLeNet = request.form.get("network_selection") # using LeNet or EasyNet or SimpleNet ->> default + learning_rate = 1e-1 # fix learning rate + max_epochs = 10 # max number of epochs that is run if early stopping is not hit # if user upload datasets and networks, save them in the database @@ -83,16 +90,15 @@ def response(): childnetwork = request.files['network_upload'] childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename) - # generate random policies at start - auto_aug_leanrer = request.form.get("auto_aug_selection") - if auto_aug_leanrer == 'UCB': - policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) - q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) - elif auto_aug_leanrer == 'Evolutionary Learner': + + if auto_aug_learner == 'UCB': + policies = ucb_learner.generate_policies(num_policies, num_sub_policies) + q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) + elif auto_aug_learner == 'Evolutionary Learner': learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method) learner.run_instance() - elif auto_aug_leanrer == 'Random Searcher': + elif auto_aug_learner == 'Random Searcher': # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent # This system makes more sense for the user who is not using the webapp and is instead # using the library within their code @@ -157,7 +163,7 @@ def response(): test_dataset, child_network_architecture=model, iterations=iterations) - elif auto_aug_leanrer == 'Genetic Learner': + elif auto_aug_learner == 'Genetic Learner': pass plt.figure() @@ -165,8 +171,8 @@ def response(): # if auto_aug_learner == 'UCB': - # policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) - # q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) + # policies = ucb_learner.generate_policies(num_policies, num_sub_policies) + # q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) # # plt.figure() # # plt.plot(q_values) # best_q_values = np.array(best_q_values)