Skip to content
Snippets Groups Projects
Commit 5e35d5c3 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

move /autoaugmentation stuff into /flask_mvp/auto_autoaugmentation

parent 6ccb31f8
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,6 @@ from tqdm import trange
torch.manual_seed(0)
# import agents and its functions
from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
......@@ -58,8 +57,8 @@ def response():
if auto_aug_learner == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(
policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = aal.ucb_learner.run_UCB1(
policies,
batch_size,
learning_rate,
......@@ -76,7 +75,17 @@ def response():
elif auto_aug_learner == 'Evolutionary Learner':
network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = cn.LeNet()
learner = aal.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
learner = aal.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment