diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 34cc2d44555423475914a1ba2528cfddb71aad57..a1f21a1a8e1a6b6827ff16c40da53e0f4f6a504b 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -1,3 +1,4 @@ +from cgi import test import torch torch.manual_seed(0) import torch.nn as nn @@ -5,43 +6,39 @@ import pygad import pygad.torchga as torchga import copy import torch +from meta_augment.controller_networks.evo_controller import evo_controller -from MetaAugment.autoaugment_learners.aa_learner import aa_learner +from meta_augment.autoaugment_learners.aa_learner import aa_learner +import meta_augment.child_networks as cn class evo_learner(aa_learner): def __init__(self, - # search space settings - sp_num=5, - p_bins=10, - m_bins=10, - discrete_p_m=False, - exclude_method=[], - # child network settings - learning_rate=1e-1, + sp_num=1, + num_solutions = 5, + num_parents_mating = 3, + learning_rate = 1e-1, max_epochs=float('inf'), early_stop_num=20, + p_bins = 1, + m_bins = 1, batch_size=8, - toy_size=1, - # evolutionary learner specific settings - num_solutions=5, - num_parents_mating=3, - controller=None + toy_size=0.1, + fun_num = 14, + exclude_method=[], + controller = None ): - super().__init__( - sp_num=sp_num, - p_bins=p_bins, - m_bins=m_bins, - discrete_p_m=discrete_p_m, - batch_size=batch_size, - toy_size=toy_size, - learning_rate=learning_rate, - max_epochs=max_epochs, - early_stop_num=early_stop_num, - exclude_method=exclude_method - ) + super().__init__(sp_num, + fun_num, + p_bins, + m_bins, + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num,) self.num_solutions = num_solutions self.controller = controller @@ -51,6 +48,8 @@ class evo_learner(aa_learner): self.p_bins = p_bins self.sub_num_pol = sp_num self.m_bins = m_bins + self.fun_num = fun_num + self.augmentation_space = [x for x in self.augmentation_space if x[0] not in exclude_method] self.policy_dict = {} self.policy_result = [] @@ -58,6 +57,7 @@ class evo_learner(aa_learner): assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' + def get_full_policy(self, x): """ Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires @@ -167,10 +167,10 @@ class evo_learner(aa_learner): prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item() if mag1 is not None: # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8) - mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item() + mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()) if mag2 is not None: # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8) - mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item() + mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()) counter += 1 @@ -240,7 +240,7 @@ class evo_learner(aa_learner): self.policy_dict[trans1][trans2].append(new_set) return False else: - self.policy_dict[trans1][trans2] = [new_set] + self.policy_dict[trans1] = {trans2: [new_set]} if trans2 in self.policy_dict: if trans1 in self.policy_dict[trans2]: for test_pol in self.policy_dict[trans2][trans1]: @@ -249,7 +249,7 @@ class evo_learner(aa_learner): self.policy_dict[trans2][trans1].append(new_set) return False else: - self.policy_dict[trans2][trans1] = [new_set] + self.policy_dict[trans2] = {trans1: [new_set]} def set_up_instance(self, train_dataset, test_dataset, child_network_architecture): @@ -298,11 +298,11 @@ class evo_learner(aa_learner): if len(self.policy_result) > self.sp_num: self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True) self.policy_result = self.policy_result[:self.sp_num] - print("Appended policy: ", self.policy_result) + print("appended policy: ", self.policy_result) if fit_val > self.history_best[self.gen_count]: - print("Best policy: ", full_policy) + print("best policy: ", full_policy) self.history_best[self.gen_count] = fit_val self.best_model = model_weights_dict @@ -335,4 +335,3 @@ class evo_learner(aa_learner): mutation_percent_genes = 0.1, fitness_func=fitness_func, on_generation = on_generation) -