Skip to content
Snippets Groups Projects
Commit 9c4e0173 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

fix imports

parent 98be9246
No related branches found
No related tags found
No related merge requests found
......@@ -24,10 +24,9 @@ import sys
sys.path.insert(0, os.path.abspath('..'))
# # import agents and its functions
from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
from MetaAugment.autoaugment_learners import evo_learner
import MetaAugment.controller_networks as cn
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
print('@@@ import successful')
# import agents and its functions
......@@ -194,8 +193,8 @@ def training():
if auto_aug_learner == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(
policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = aal.ucb_learner.run_UCB1(
policies,
batch_size,
learning_rate,
......@@ -210,21 +209,19 @@ def training():
best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner':
network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = aal.evo.LeNet()
learner = aal.evo.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = cn.LeNet()
learner = aal.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
pass
......
......@@ -3,7 +3,8 @@
# app.run(host='0.0.0.0',port=port)
from numpy import broadcast
from auto_augmentation import home, progress,result, training
from auto_augmentation import home, progress,result
from flask_mvp.auto_augmentation import training
from flask_socketio import SocketIO, send
from flask import Flask, flash, request, redirect, url_for
......
......@@ -2,7 +2,8 @@ import os
from flask import Flask, render_template, request, flash
from auto_augmentation import home, progress,result, training
from auto_augmentation import home, progress,result
from flask_mvp.auto_augmentation import training
def create_app(test_config=None):
# create and configure the app
......
......@@ -94,9 +94,28 @@ def response():
if auto_aug_learner == 'UCB':
policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
q_values, best_q_values = ucb_learner.run_UCB1(
policies,
batch_size,
learning_rate,
ds,
toy_size,
max_epochs,
early_stop_num,
iterations,
IsLeNet,
ds_name
)
best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner':
learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
learner = Evo.Evolutionary_learner(
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds_name=ds_name,
exclude_method=exclude_method
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
# As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment