Skip to content
Snippets Groups Projects
Commit 92b2a3b6 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

edit evo_learner:

controller is init'ed inside the learner
train_dataset.transform = to.tensor()
parent f117547f
No related branches found
No related tags found
No related merge requests found
import torch
torch.manual_seed(0)
import torch.nn as nn
import pygad
import pygad.torchga as torchga
import copy
import torchvision
import torch
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
import MetaAugment.controller_networks as cont_n
class evo_learner(aa_learner):
......@@ -14,7 +14,7 @@ class evo_learner(aa_learner):
def __init__(self,
# search space settings
sp_num=5,
p_bins=10,
p_bins=11,
m_bins=10,
discrete_p_m=False,
exclude_method=[],
......@@ -27,7 +27,7 @@ class evo_learner(aa_learner):
# evolutionary learner specific settings
num_solutions=5,
num_parents_mating=3,
controller=None
controller=cont_n.evo_controller
):
super().__init__(
......@@ -43,14 +43,19 @@ class evo_learner(aa_learner):
exclude_method=exclude_method
)
# evolutionary algorithm settings
self.controller = controller(
fun_num=self.fun_num,
p_bins=self.p_bins,
m_bins=self.m_bins,
sub_num_pol=self.sp_num
)
self.num_solutions = num_solutions
self.controller = controller
self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights
self.p_bins = p_bins
self.sub_num_pol = sp_num
self.m_bins = m_bins
# store our logs
self.policy_dict = {}
self.policy_result = []
......@@ -77,7 +82,7 @@ class evo_learner(aa_learner):
section = self.fun_num + self.p_bins + self.m_bins
y = self.controller.forward(x)
full_policy = []
for pol in range(self.sub_num_pol):
for pol in range(self.sp_num):
int_pol = []
for _ in range(2):
idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
......@@ -277,6 +282,7 @@ class evo_learner(aa_learner):
weights_vector=solution)
self.controller.load_state_dict(model_weights_dict)
train_dataset.transform = torchvision.transforms.ToTensor()
self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
for idx, (test_x, label_x) in enumerate(self.train_loader):
......
......@@ -54,6 +54,7 @@ def parse_users_learner_spec(
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=exclude_method,
# hyperparameters for when training the child_network
batch_size=batch_size,
toy_size=toy_size,
......@@ -63,52 +64,55 @@ def parse_users_learner_spec(
# ucb_learner specific hyperparameter
num_policies=num_policies
)
pprint(learner.policies)
learner.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_archi,
iterations=5
)
elif auto_aug_learner == 'Evolutionary Learner':
network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = cn.LeNet()
learner = aal.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
# parameters that define the search space
sp_num=num_sub_policies,
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=exclude_method,
# hyperparameters for when training the child_network
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
agent = aal.randomsearch_learner(
sp_num=num_sub_policies,
batch_size=batch_size,
learning_rate=learning_rate,
toy_size=toy_size,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_archi,
iterations=iterations)
# parameters that define the search space
sp_num=num_sub_policies,
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=exclude_method,
# hyperparameters for when training the child_network
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
elif auto_aug_learner == 'GRU Learner':
agent = aal.gru_learner(
sp_num=num_sub_policies,
batch_size=batch_size,
learning_rate=learning_rate,
toy_size=toy_size,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_archi,
iterations=iterations)
\ No newline at end of file
# parameters that define the search space
sp_num=num_sub_policies,
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=exclude_method,
# hyperparameters for when training the child_network
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_archi,
iterations=iterations)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment