Skip to content
Snippets Groups Projects
Commit ddca7ae2 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

fix training.py evo_learner call

parent 8b36ecef
No related branches found
No related tags found
No related merge requests found
...@@ -19,7 +19,9 @@ torch.manual_seed(0) ...@@ -19,7 +19,9 @@ torch.manual_seed(0)
# import agents and its functions # import agents and its functions
from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
from MetaAugment import Evo_learner as Evo import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
...@@ -57,13 +59,24 @@ def response(): ...@@ -57,13 +59,24 @@ def response():
if auto_aug_learner == 'UCB': if auto_aug_learner == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) q_values, best_q_values = UCB1_JC.run_UCB1(
policies,
batch_size,
learning_rate,
ds,
toy_size,
max_epochs,
early_stop_num,
iterations,
IsLeNet,
ds_name
)
best_q_values = np.array(best_q_values) best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner': elif auto_aug_learner == 'Evolutionary Learner':
network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = Evo.LeNet() child_network = cn.LeNet()
learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) learner = aal.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
learner.run_instance() learner.run_instance()
elif auto_aug_learner == 'Random Searcher': elif auto_aug_learner == 'Random Searcher':
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment