From 5e35d5c3b7a73e272674344d383bb1cbd0ed48c0 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Mon, 25 Apr 2022 10:10:51 +0100
Subject: [PATCH] move /auto_augmentation files into
 /flask_mvp/auto_augmentation

---
 .../auto_augmentation}/templates/intro.html     |  0
 .../auto_augmentation}/templates/training.html  |  0
 .../auto_augmentation}/training.py              | 17 +++++++++++++----
 3 files changed, 13 insertions(+), 4 deletions(-)
 rename {auto_augmentation => flask_mvp/auto_augmentation}/templates/intro.html (100%)
 rename {auto_augmentation => flask_mvp/auto_augmentation}/templates/training.html (100%)
 rename {auto_augmentation => flask_mvp/auto_augmentation}/training.py (80%)

diff --git a/auto_augmentation/templates/intro.html b/flask_mvp/auto_augmentation/templates/intro.html
similarity index 100%
rename from auto_augmentation/templates/intro.html
rename to flask_mvp/auto_augmentation/templates/intro.html
diff --git a/auto_augmentation/templates/training.html b/flask_mvp/auto_augmentation/templates/training.html
similarity index 100%
rename from auto_augmentation/templates/training.html
rename to flask_mvp/auto_augmentation/templates/training.html
diff --git a/auto_augmentation/training.py b/flask_mvp/auto_augmentation/training.py
similarity index 80%
rename from auto_augmentation/training.py
rename to flask_mvp/auto_augmentation/training.py
index 7c124cd9..5e695b58 100644
--- a/auto_augmentation/training.py
+++ b/flask_mvp/auto_augmentation/training.py
@@ -18,7 +18,6 @@ from tqdm import trange
 torch.manual_seed(0)
 # import agents and its functions
 
-from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
 import MetaAugment.autoaugment_learners as aal
 import MetaAugment.controller_networks as cont_n
 import MetaAugment.child_networks as cn
@@ -58,8 +57,8 @@ def response():
 
 
     if auto_aug_learner == 'UCB':
-        policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
-        q_values, best_q_values = UCB1_JC.run_UCB1(
+        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
+        q_values, best_q_values = aal.ucb_learner.run_UCB1(
                                                 policies, 
                                                 batch_size, 
                                                 learning_rate, 
@@ -76,7 +75,17 @@ def response():
     elif auto_aug_learner == 'Evolutionary Learner':
         network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
         child_network = cn.LeNet()
-        learner = aal.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
+        learner = aal.evo_learner(
+                                network=network, 
+                                fun_num=num_funcs, 
+                                p_bins=1, 
+                                mag_bins=1, 
+                                sub_num_pol=1, 
+                                ds = ds, 
+                                ds_name=ds_name, 
+                                exclude_method=exclude_method, 
+                                child_network=child_network
+                                )
         learner.run_instance()
     elif auto_aug_learner == 'Random Searcher':
         pass 
-- 
GitLab
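
Note: a minimal sketch of the import pattern this patch switches to in training.py. The direct
`from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC` import is dropped, and the
UCB helpers are reached through the already-imported `aal` module instead. The example values for
num_policies and num_sub_policies below are hypothetical placeholders; in the patched code they
come from the Flask form handled by response().

    import MetaAugment.autoaugment_learners as aal

    # placeholder values for illustration only
    num_policies, num_sub_policies = 5, 5

    # after this patch, generate_policies is accessed via the aal module
    # rather than a separately imported UCB1_JC alias
    policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)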