diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index c1ba5ed47f54b86d2a11ce4dac887608733eda7c..270d959862a19f674375329cd7d18c07f352b0a8 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -1,3 +1,4 @@
+from cgi import test
 import torch
 import torch.nn as nn
 import pygad
@@ -30,18 +31,15 @@ class evo_learner(aa_learner):
                  controller=cont_n.evo_controller
                  ):
 
-        super().__init__(
-            sp_num=sp_num,
-            p_bins=p_bins,
-            m_bins=m_bins,
-            discrete_p_m=discrete_p_m,
-            batch_size=batch_size,
-            toy_size=toy_size,
-            learning_rate=learning_rate,
-            max_epochs=max_epochs,
-            early_stop_num=early_stop_num,
-            exclude_method=exclude_method
-        )
+        super().__init__(sp_num,
+                    fun_num,
+                    p_bins,
+                    m_bins,
+                    batch_size=batch_size,
+                    toy_size=toy_size,
+                    learning_rate=learning_rate,
+                    max_epochs=max_epochs,
+                    early_stop_num=early_stop_num,)
 
         # evolutionary algorithm settings
         self.controller = controller(
@@ -63,6 +61,7 @@ class evo_learner(aa_learner):
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
+
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -172,10 +171,10 @@ class evo_learner(aa_learner):
                         prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                     if mag1 is not None:
                         # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
-                        mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
+                        mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                     if mag2 is not None:
                         # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
-                        mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
+                        mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
 
                     counter += 1
 
@@ -245,7 +244,7 @@ class evo_learner(aa_learner):
                     self.policy_dict[trans1][trans2].append(new_set)
                     return False
             else:
-                self.policy_dict[trans1][trans2] = [new_set]
+                self.policy_dict[trans1] = {trans2: [new_set]}
         if trans2 in self.policy_dict:
             if trans1 in self.policy_dict[trans2]:
                 for test_pol in self.policy_dict[trans2][trans1]:
@@ -254,7 +253,7 @@ class evo_learner(aa_learner):
                     self.policy_dict[trans2][trans1].append(new_set)
                     return False
             else:
-                self.policy_dict[trans2][trans1] = [new_set]
+                self.policy_dict[trans2] = {trans1: [new_set]}
 
 
     def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
@@ -304,11 +303,11 @@ class evo_learner(aa_learner):
             if len(self.policy_result) > self.sp_num:
                 self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
                 self.policy_result = self.policy_result[:self.sp_num]
-                print("Appended policy: ", self.policy_result)
+                print("appended policy: ", self.policy_result)
 
             if fit_val > self.history_best[self.gen_count]:
-                print("Best policy: ", full_policy)
+                print("best policy: ", full_policy)
                 self.history_best[self.gen_count] = fit_val
                 self.best_model = model_weights_dict
 
 
@@ -341,4 +340,3 @@ class evo_learner(aa_learner):
             mutation_percent_genes = 0.1,
             fitness_func=fitness_func,
             on_generation = on_generation)
-