diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48c05b95d5bef8de9c0405e949f1e6663e66e8ae..abccb0a916dc829570d3368be97352585556f2ed 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -99,6 +99,10 @@ class aa_learner:
         self.fun_num = len(self.augmentation_space)
         self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2
 
+        self.num_pols_tested = 0
+        self.policy_record = {}
+
+
     def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
 
 
@@ -300,6 +304,7 @@ class aa_learner:
 
             self.history.append((policy, reward))
         """
+
 
 
     def _test_autoaugment_policy(self,
@@ -329,6 +334,8 @@ class aa_learner:
             accuracy (float): best accuracy reached in any
         """
 
+
+
         # we create an instance of the child network that we're going
         # to train. The method of creation depends on the type of
         # input we got for child_network_architecture
@@ -378,8 +385,24 @@ class aa_learner:
                                         early_stop_num = self.early_stop_num,
                                         logging = logging,
                                         print_every_epoch=print_every_epoch)
+
+        curr_pol = f'pol{self.num_pols_tested}'
+        pol_dict = {}
+        for subpol in policy:
+            subpol = subpol[0]
+            first_trans, first_prob, first_mag = subpol[0]
+            second_trans, second_prob, second_mag = subpol[1]
+            components = (first_prob, first_mag, second_prob, second_mag)
+            if first_trans in pol_dict and second_trans in pol_dict[first_trans]:
+                pol_dict[first_trans][second_trans].append(components)
+            else:
+                pol_dict[first_trans] = {second_trans: [components]}
+        self.policy_record[curr_pol] = (pol_dict, accuracy)
+
+        self.num_pols_tested += 1
 
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
+
         return accuracy
 
 
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index d5a076b5bb9bad66a76b5346a8f79e5e36c48e2f..9a6b291bac4f07a817295dcf63edbe0a06a204e3 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -57,60 +57,11 @@ class evo_learner(aa_learner):
 
         # store our logs
         self.policy_dict = {}
-        self.policy_result = []
+        self.running_policy = []
 
 
 
-        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
-
-
-
-    def get_full_policy(self, x):
-        """
-        Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
-        output size 5 * 2 * (self.fun_num + self.p_bins + self.m_bins)
-
-        Parameters
-        -----------
-        x -> PyTorch tensor
-            Input data for network
-
-        Returns
-        ----------
-        full_policy -> [((String, float, float), (String, float, float)), ...)
-            Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
-            two transformations, with a probability and magnitude float for each
-        """
-        section = self.fun_num + self.p_bins + self.m_bins
-        y = self.controller.forward(x)
-        full_policy = []
-        for pol in range(self.sp_num):
-            int_pol = []
-            for _ in range(2):
-                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
-                trans, need_mag = self.augmentation_space[idx_ret]
-
-                if self.p_bins == 1:
-                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
-                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                else:
-                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item()) * 0.1
-
-                if need_mag:
-                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
-                    if self.m_bins == 1:
-                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
-                    else:
-                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())
-                    mag = int(mag)
-                else:
-                    mag = None
-                int_pol.append((trans, p_ret, mag))
-
-            full_policy.append(tuple(int_pol))
-
-        return full_policy
+        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
 
     def get_single_policy_cov(self, x, alpha = 0.5):
@@ -172,10 +123,8 @@ class evo_learner(aa_learner):
             prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
             prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
             if mag1 is not None:
-                # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
                 mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
             if mag2 is not None:
-                # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
                 mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
             counter += 1
 
@@ -213,19 +162,14 @@ class evo_learner(aa_learner):
            Solution_idx -> Int
        """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations+1)]
-        print("itations: ", iterations)
+        self.history_best = []
 
-        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0
 
         self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
-        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
-        self.history_avg = [x / self.num_solutions for x in self.history_avg]
-        print("-----------------------------------------------------------------------------------------------------")
 
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
@@ -244,18 +188,9 @@ class evo_learner(aa_learner):
                     if new_set == test_pol:
                         return True
                 self.policy_dict[trans1][trans2].append(new_set)
-                return False
             else:
                 self.policy_dict[trans1] = {trans2: [new_set]}
-        if trans2 in self.policy_dict:
-            if trans1 in self.policy_dict[trans2]:
-                for test_pol in self.policy_dict[trans2][trans1]:
-                    if new_set == test_pol:
-                        return True
-                self.policy_dict[trans2][trans1].append(new_set)
-                return False
-            else:
-                self.policy_dict[trans2] = {trans1: [new_set]}
+        return False
 
 
     def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
@@ -287,33 +222,34 @@ class evo_learner(aa_learner):
            self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
            for idx, (test_x, label_x) in enumerate(self.train_loader):
-                # if self.sp_num == 1:
                full_policy = self.get_single_policy_cov(test_x)
-                # else:
-                #     full_policy = self.get_full_policy(test_x)
 
                while self.in_pol_dict(full_policy):
                    full_policy = self.get_single_policy_cov(test_x)[0]
 
 
-                fit_val = self._test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) #) /
-                            # + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+                fit_val = self._test_autoaugment_policy(full_policy, child_network_architecture, train_dataset, test_dataset)
 
-                self.policy_result.append([full_policy, fit_val])
+                self.history.append((full_policy, fit_val))
+                self.running_policy.append((full_policy, fit_val))
 
 
-                if len(self.policy_result) > self.sp_num:
-                    self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
-                    self.policy_result = self.policy_result[:self.sp_num]
-                    print("appended policy: ", self.policy_result)
+                if len(self.running_policy) > self.sp_num:
+                    self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
+                    self.running_policy = self.running_policy[:self.sp_num]
+                    print("appended policy: ", self.running_policy)
 
-                if fit_val > self.history_best[self.gen_count]:
-                    print("best policy: ", full_policy)
-                    self.history_best[self.gen_count] = fit_val
+                if len(self.history_best) == 0:
+                    self.history_best.append(fit_val)
                    self.best_model = model_weights_dict
+                elif fit_val > self.history_best[-1]:
+                    self.history_best.append(fit_val)
+                    self.best_model = model_weights_dict
+                else:
+                    self.history_best.append(self.history_best[-1])
 
 
-                self.history_avg[self.gen_count] += fit_val
+
                return fit_val
 
 
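A minimal sketch of the record shape produced by the bookkeeping added to _test_autoaugment_policy above: each tested
policy gets an entry 'polN' in self.policy_record, mapping the first transform to the second transform to the list of
(prob1, mag1, prob2, mag2) tuples, paired with the accuracy it reached. The transform names, probabilities, magnitudes,
accuracy, and the exact nesting of the example policy below are illustrative assumptions, not taken from the repository.

    # Hypothetical example data -- mirrors the loop in the patch, not part of it.
    policy = [((("ShearX", 0.5, 3), ("Invert", 0.8, None)),)]  # one subpolicy, wrapped so subpol[0] unwraps it
    accuracy = 0.87

    pol_dict = {}
    for subpol in policy:
        subpol = subpol[0]
        first_trans, first_prob, first_mag = subpol[0]
        second_trans, second_prob, second_mag = subpol[1]
        components = (first_prob, first_mag, second_prob, second_mag)
        if first_trans in pol_dict and second_trans in pol_dict[first_trans]:
            pol_dict[first_trans][second_trans].append(components)
        else:
            pol_dict[first_trans] = {second_trans: [components]}

    policy_record = {"pol0": (pol_dict, accuracy)}
    print(policy_record)
    # {'pol0': ({'ShearX': {'Invert': [(0.5, 3, 0.8, None)]}}, 0.87)}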
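A similarly hedged sketch of how the reworked history_best tracking in evo_learner's fitness evaluation behaves: the
list now grows by one entry per evaluated policy and is monotone non-decreasing, while best_model keeps the weights of
the best policy seen so far. The fitness values and weight placeholders are invented for the example.

    # Hypothetical fitness values and model-weight placeholders.
    history_best, best_model = [], None
    for fit_val, model_weights_dict in [(0.61, "w0"), (0.58, "w1"), (0.73, "w2")]:
        if len(history_best) == 0 or fit_val > history_best[-1]:
            history_best.append(fit_val)
            best_model = model_weights_dict
        else:
            history_best.append(history_best[-1])

    print(history_best)  # [0.61, 0.61, 0.73]
    print(best_model)    # w2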