diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 34cc2d44555423475914a1ba2528cfddb71aad57..c1ba5ed47f54b86d2a11ce4dac887608733eda7c 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -1,12 +1,12 @@ import torch -torch.manual_seed(0) import torch.nn as nn import pygad import pygad.torchga as torchga -import copy +import torchvision import torch from MetaAugment.autoaugment_learners.aa_learner import aa_learner +import MetaAugment.controller_networks as cont_n class evo_learner(aa_learner): @@ -14,7 +14,7 @@ class evo_learner(aa_learner): def __init__(self, # search space settings sp_num=5, - p_bins=10, + p_bins=11, m_bins=10, discrete_p_m=False, exclude_method=[], @@ -27,7 +27,7 @@ class evo_learner(aa_learner): # evolutionary learner specific settings num_solutions=5, num_parents_mating=3, - controller=None + controller=cont_n.evo_controller ): super().__init__( @@ -43,14 +43,19 @@ class evo_learner(aa_learner): exclude_method=exclude_method ) + # evolutionary algorithm settings + self.controller = controller( + fun_num=self.fun_num, + p_bins=self.p_bins, + m_bins=self.m_bins, + sub_num_pol=self.sp_num + ) self.num_solutions = num_solutions - self.controller = controller self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions) self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights - self.p_bins = p_bins - self.sub_num_pol = sp_num - self.m_bins = m_bins + + # store our logs self.policy_dict = {} self.policy_result = [] @@ -77,7 +82,7 @@ class evo_learner(aa_learner): section = self.fun_num + self.p_bins + self.m_bins y = self.controller.forward(x) full_policy = [] - for pol in range(self.sub_num_pol): + for pol in range(self.sp_num): int_pol = [] for _ in range(2): idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) @@ -277,6 +282,7 @@ class evo_learner(aa_learner): weights_vector=solution) self.controller.load_state_dict(model_weights_dict) + train_dataset.transform = torchvision.transforms.ToTensor() self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size) for idx, (test_x, label_x) in enumerate(self.train_loader): diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py index cde572f7f1e0ba1590fb5685b7212f4d0b3b173a..bb10113311f77d32f3f0b996e76d71f16fa0d184 100644 --- a/temp_util/wapp_util.py +++ b/temp_util/wapp_util.py @@ -54,6 +54,7 @@ def parse_users_learner_spec( p_bins=11, m_bins=10, discrete_p_m=True, + exclude_method=exclude_method, # hyperparameters for when training the child_network batch_size=batch_size, toy_size=toy_size, @@ -63,52 +64,55 @@ def parse_users_learner_spec( # ucb_learner specific hyperparameter num_policies=num_policies ) - pprint(learner.policies) - - learner.learn( - train_dataset=train_dataset, - test_dataset=test_dataset, - child_network_architecture=child_archi, - iterations=5 - ) elif auto_aug_learner == 'Evolutionary Learner': - network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) - child_network = cn.LeNet() learner = aal.evo_learner( - network=network, - fun_num=num_funcs, - p_bins=1, - mag_bins=1, - sub_num_pol=1, - ds = ds, - ds_name=ds_name, - exclude_method=exclude_method, - child_network=child_network - ) + # parameters that define the search space + sp_num=num_sub_policies, + p_bins=11, + m_bins=10, + discrete_p_m=True, + exclude_method=exclude_method, + # hyperparameters for when training the child_network + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num, + ) learner.run_instance() elif auto_aug_learner == 'Random Searcher': agent = aal.randomsearch_learner( - sp_num=num_sub_policies, - batch_size=batch_size, - learning_rate=learning_rate, - toy_size=toy_size, - max_epochs=max_epochs, - early_stop_num=early_stop_num, - ) - agent.learn(train_dataset, - test_dataset, - child_network_architecture=child_archi, - iterations=iterations) + # parameters that define the search space + sp_num=num_sub_policies, + p_bins=11, + m_bins=10, + discrete_p_m=True, + exclude_method=exclude_method, + # hyperparameters for when training the child_network + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num, + ) elif auto_aug_learner == 'GRU Learner': agent = aal.gru_learner( - sp_num=num_sub_policies, - batch_size=batch_size, - learning_rate=learning_rate, - toy_size=toy_size, - max_epochs=max_epochs, - early_stop_num=early_stop_num, - ) - agent.learn(train_dataset, - test_dataset, - child_network_architecture=child_archi, - iterations=iterations) \ No newline at end of file + # parameters that define the search space + sp_num=num_sub_policies, + p_bins=11, + m_bins=10, + discrete_p_m=True, + exclude_method=exclude_method, + # hyperparameters for when training the child_network + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num, + ) + + + agent.learn(train_dataset, + test_dataset, + child_network_architecture=child_archi, + iterations=iterations) \ No newline at end of file