From 5e35d5c3b7a73e272674344d383bb1cbd0ed48c0 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Mon, 25 Apr 2022 10:10:51 +0100 Subject: [PATCH] move /autoaugmentation stuff into /flask_mvp/auto_autoaugmentation --- .../auto_augmentation}/templates/intro.html | 0 .../auto_augmentation}/templates/training.html | 0 .../auto_augmentation}/training.py | 17 +++++++++++++---- 3 files changed, 13 insertions(+), 4 deletions(-) rename {auto_augmentation => flask_mvp/auto_augmentation}/templates/intro.html (100%) rename {auto_augmentation => flask_mvp/auto_augmentation}/templates/training.html (100%) rename {auto_augmentation => flask_mvp/auto_augmentation}/training.py (80%) diff --git a/auto_augmentation/templates/intro.html b/flask_mvp/auto_augmentation/templates/intro.html similarity index 100% rename from auto_augmentation/templates/intro.html rename to flask_mvp/auto_augmentation/templates/intro.html diff --git a/auto_augmentation/templates/training.html b/flask_mvp/auto_augmentation/templates/training.html similarity index 100% rename from auto_augmentation/templates/training.html rename to flask_mvp/auto_augmentation/templates/training.html diff --git a/auto_augmentation/training.py b/flask_mvp/auto_augmentation/training.py similarity index 80% rename from auto_augmentation/training.py rename to flask_mvp/auto_augmentation/training.py index 7c124cd9..5e695b58 100644 --- a/auto_augmentation/training.py +++ b/flask_mvp/auto_augmentation/training.py @@ -18,7 +18,6 @@ from tqdm import trange torch.manual_seed(0) # import agents and its functions -from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC import MetaAugment.autoaugment_learners as aal import MetaAugment.controller_networks as cont_n import MetaAugment.child_networks as cn @@ -58,8 +57,8 @@ def response(): if auto_aug_learner == 'UCB': - policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) - q_values, best_q_values = UCB1_JC.run_UCB1( + policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies) + q_values, best_q_values = aal.ucb_learner.run_UCB1( policies, batch_size, learning_rate, @@ -76,7 +75,17 @@ def response(): elif auto_aug_learner == 'Evolutionary Learner': network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) child_network = cn.LeNet() - learner = aal.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) + learner = aal.evo_learner( + network=network, + fun_num=num_funcs, + p_bins=1, + mag_bins=1, + sub_num_pol=1, + ds = ds, + ds_name=ds_name, + exclude_method=exclude_method, + child_network=child_network + ) learner.run_instance() elif auto_aug_learner == 'Random Searcher': pass -- GitLab