diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 0eb38d59c1e3156dbc7a620ee1958fb7d1d032bb..e4460cbfdca799022cd2f1d1ff950cd780355fa1 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -46,7 +46,6 @@ class aa_learner:
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -57,6 +56,7 @@ class aa_learner:
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                exclude_method = [],
                 ):
         """
         Args:
@@ -84,11 +84,9 @@ class aa_learner:
         """
         # related to defining the search space
         self.sp_num = sp_num
-        self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
         self.discrete_p_m = discrete_p_m
-        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
 
         # related to training of the child_network
         self.batch_size = batch_size
@@ -101,6 +99,9 @@ class aa_learner:
         # TODO: We should probably use a different way to store results than self.history
         self.history = []
 
+        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+        self.fun_num = len(self.augmentation_space)
+        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2
 
 
     def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 8e10c74547f7230a0eeecf11356804413721f7c1..5a8ecbcf6f0b8c6212a8c034a70d61476f4870f6 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -238,6 +238,8 @@ class AutoAugment(torch.nn.Module):
             if probs[i] <= p:
                 op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
+                print("magnitude_id: ", magnitude_id)
+                print("magnitudes[magnitude_id]: ", magnitudes[magnitude_id])
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
                     magnitude *= -1.0
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 18ecf751e614585c7db86902eb3cce927dd696f5..8e1d5bc198548c3e24bb3d2bd5ac2d1f39650923 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -6,34 +6,32 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
+from MetaAugment.controller_networks.evo_controller import Evo_learner
+
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
-from .aa_learner import aa_learner, augmentation_space
 
 
 class evo_learner(aa_learner):
 
     def __init__(self,
                 sp_num=1,
-                num_solutions = 10,
-                num_parents_mating = 5,
+                num_solutions = 5,
+                num_parents_mating = 3,
                 learning_rate = 1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
-                train_loader = None,
-                child_network = None,
                 p_bins = 1,
                 m_bins = 1,
                 discrete_p_m=False,
                 batch_size=8,
                 toy_flag=False,
                 toy_size=0.1,
-                fun_num = 14,
                 exclude_method=[],
+                controller = None
                 ):
 
         super().__init__(sp_num,
-                        fun_num,
                        p_bins,
                        m_bins,
                        discrete_p_m=discrete_p_m,
@@ -42,27 +40,24 @@ class evo_learner(aa_learner):
                        batch_size=batch_size,
                        toy_flag=toy_flag,
                        toy_size=toy_size,
                        learning_rate=learning_rate,
                        max_epochs=max_epochs,
-                        early_stop_num=early_stop_num,)
+                        early_stop_num=early_stop_num,
+                        exclude_method=exclude_method)
 
         self.num_solutions = num_solutions
-        self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.controller = controller
+        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
 
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
         self.p_bins = p_bins
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
-        self.fun_num = fun_num
-        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
-
+        self.policy_dict = {}
+        self.policy_result = []
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
-
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -79,8 +74,8 @@ class evo_learner(aa_learner):
             Full policy consisting of tuples of subpolicies. Each subpolicy consisting of two
            transformations, with a probability and magnitude float for each
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
+        section = self.fun_num + self.p_bins + self.m_bins
+        y = self.controller.forward(x)
         full_policy = []
         for pol in range(self.sub_num_pol):
             int_pol = []
@@ -89,8 +84,22 @@ class evo_learner(aa_learner):
 
                 trans, need_mag = self.augmentation_space[idx_ret]
 
-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                if self.p_bins == 1:
+                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                else:
+                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)).item() * 0.1
+
+
+                if need_mag:
+                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
+                    if self.m_bins == 1:
+                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+                    else:
+                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)).item()
+                    mag = int(mag)
+                else:
+                    mag = None
                 int_pol.append((trans, p_ret, mag))
 
             full_policy.append(tuple(int_pol))
@@ -117,18 +126,18 @@ class evo_learner(aa_learner):
             Subpolicy consisting of two tuples of policies, each with a string associated
            to a transformation, a float for a probability, and a float for a magnittude
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        section = self.fun_num + self.p_bins + self.m_bins
 
-        y = self.auto_aug_agent.forward(x)
+        y = self.controller.forward(x)
 
-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1)
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        y_1 = torch.softmax(y[:,:self.fun_num], dim = 1)
+        y[:,:self.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+        y[:,section:section+self.fun_num] = y_2
 
         concat = torch.cat((y_1, y_2), dim = 1)
         cov_mat = torch.cov(concat.T)
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
 
         shape_store = cov_mat.shape
         counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -154,26 +163,29 @@ class evo_learner(aa_learner):
 
         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
                 if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    mag2 += 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item()
+
                 counter += 1
 
 
-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
+        prob1 = round(prob1/counter, 1) if counter != 0 else 0
+        prob2 = round(prob2/counter, 1) if counter != 0 else 0
         if mag1 is not None:
-            mag1 = mag1/counter
+            mag1 = int(mag1/counter)
         if mag2 is not None:
-            mag2 = mag2/counter
+            mag2 = int(mag2/counter)
 
-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
 
 
-    def learn(self, iterations = 15, return_weights = False):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
 
@@ -195,24 +207,52 @@ class evo_learner(aa_learner):
             Solution_idx -> Int
         """
 
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations)]
-        self.history_avg = [0 for i in range(iterations)]
+        self.history_best = [0 for i in range(iterations+1)]
+        print("iterations: ", iterations)
+
+        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0
 
-        self.set_up_instance()
+        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
-        self.history_avg = self.history_avg / self.num_solutions
+        self.history_avg = [x / self.num_solutions for x in self.history_avg]
+        print("-----------------------------------------------------------------------------------------------------")
 
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
        if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx
 
 
-    def set_up_instance(self, train_dataset, test_dataset):
+    def in_pol_dict(self, new_policy):
+        new_policy = new_policy[0]
+        trans1, trans2 = new_policy[0][0], new_policy[1][0]
+        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+        if trans1 in self.policy_dict:
+            if trans2 in self.policy_dict[trans1]:
+                for test_pol in self.policy_dict[trans1][trans2]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans1][trans2].append(new_set)
+                return False
+            else:
+                self.policy_dict[trans1][trans2] = [new_set]
+        if trans2 in self.policy_dict:
+            if trans1 in self.policy_dict[trans2]:
+                for test_pol in self.policy_dict[trans2][trans1]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans2][trans1].append(new_set)
+                return False
+            else:
+                self.policy_dict[trans2][trans1] = [new_set]
+
+
+    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
         """
         Initialises GA instance, as well as fitness and on_generation functions
@@ -233,24 +273,36 @@ class evo_learner(aa_learner):
 
                 fit_val -> float
             """
 
 
-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                                weights_vector=solution)
 
-            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.controller.load_state_dict(model_weights_dict)
 
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                if self.sp_num == 1:
-                    full_policy = self.get_single_policy_cov(test_x)
-                else:
-                    full_policy = self.get_full_policy(test_x)
+                # if self.sp_num == 1:
+                full_policy = self.get_single_policy_cov(test_x)
+
+
+                # else:
+                # full_policy = self.get_full_policy(test_x)
+                while self.in_pol_dict(full_policy):
+                    full_policy = self.get_single_policy_cov(test_x)[0]
 
-# Checkpoint -> save learner as a pickle
-                fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-                            + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+                fit_val = self.test_autoaugment_policy(full_policy, child_network_architecture, train_dataset, test_dataset) #) /
+                            # + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+
+                self.policy_result.append([full_policy, fit_val])
+
+                if len(self.policy_result) > self.sp_num:
+                    self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+                    self.policy_result = self.policy_result[:self.sp_num]
+                    print("Appended policy: ", self.policy_result)
+
                 if fit_val > self.history_best[self.gen_count]:
+                    print("Best policy: ", full_policy)
                     self.history_best[self.gen_count] = fit_val
                     self.best_model = model_weights_dict
@@ -284,6 +336,3 @@ class evo_learner(aa_learner):
                                            fitness_func=fitness_func,
                                            on_generation = on_generation)
 
-
-
-
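
For reference, a minimal usage sketch of the refactored interface introduced by this patch (not part of the diff itself). It assumes the renamed Evo_learner controller class accepts the keyword arguments of the removed internal evo_controller(...) call, that cn.LeNet stands in for whatever child-network class the repository actually provides, and that the exclude_method strings match names in augmentation_space; the MNIST datasets from torchvision are used purely for illustration.

# Hypothetical driver script; names not present in the diff above are assumptions.
import torchvision
import torchvision.transforms as transforms

import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.evo_learner import evo_learner
from MetaAugment.controller_networks.evo_controller import Evo_learner

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                            transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True,
                                           transform=transforms.ToTensor())

# The controller is now injected instead of being built inside evo_learner; these
# keyword arguments mirror the removed evo_controller(...) call and are assumptions.
controller = Evo_learner(fun_num=14, p_bins=1, m_bins=1, sub_num_pol=1)

learner = evo_learner(sp_num=1,
                      num_solutions=5,
                      num_parents_mating=3,
                      controller=controller,
                      exclude_method=['Invert', 'Solarize'])   # illustrative names only

# learn() now receives the datasets and the child-network architecture directly;
# cn.LeNet is a placeholder for a real class in MetaAugment.child_networks.
solution, solution_fitness, solution_idx = learner.learn(train_dataset,
                                                         test_dataset,
                                                         child_network_architecture=cn.LeNet,
                                                         iterations=15,
                                                         return_weights=False)

# policy_result keeps the top self.sp_num (policy, fitness) pairs found during the run.
print(learner.policy_result)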