diff --git a/auto_augmentation/templates/intro.html b/flask_mvp/auto_augmentation/templates/intro.html similarity index 100% rename from auto_augmentation/templates/intro.html rename to flask_mvp/auto_augmentation/templates/intro.html diff --git a/auto_augmentation/templates/training.html b/flask_mvp/auto_augmentation/templates/training.html similarity index 100% rename from auto_augmentation/templates/training.html rename to flask_mvp/auto_augmentation/templates/training.html diff --git a/auto_augmentation/training.py b/flask_mvp/auto_augmentation/training.py similarity index 80% rename from auto_augmentation/training.py rename to flask_mvp/auto_augmentation/training.py index 7c124cd97300132deaa360a678d3b8ce4f426765..5e695b58a2994efb1bdc89bb363b3eddf643d9dc 100644 --- a/auto_augmentation/training.py +++ b/flask_mvp/auto_augmentation/training.py @@ -18,7 +18,6 @@ from tqdm import trange torch.manual_seed(0) # import agents and its functions -from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC import MetaAugment.autoaugment_learners as aal import MetaAugment.controller_networks as cont_n import MetaAugment.child_networks as cn @@ -58,8 +57,8 @@ def response(): if auto_aug_learner == 'UCB': - policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) - q_values, best_q_values = UCB1_JC.run_UCB1( + policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies) + q_values, best_q_values = aal.ucb_learner.run_UCB1( policies, batch_size, learning_rate, @@ -76,7 +75,17 @@ def response(): elif auto_aug_learner == 'Evolutionary Learner': network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) child_network = cn.LeNet() - learner = aal.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) + learner = aal.evo_learner( + network=network, + fun_num=num_funcs, + p_bins=1, + mag_bins=1, + sub_num_pol=1, + ds = ds, + ds_name=ds_name, + exclude_method=exclude_method, + child_network=child_network + ) learner.run_instance() elif auto_aug_learner == 'Random Searcher': pass