diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index cc452550bc4b79242b94f2ab193b4568fdba78fa..abccb0a916dc829570d3368be97352585556f2ed 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -103,6 +103,8 @@ class aa_learner:
 
         self.policy_record = {}
 
+
+
     def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
         """
         takes in a tensor representing an operation and returns an actual operation which
@@ -302,6 +304,7 @@ class aa_learner:
 
             self.history.append((policy, reward))
         """
+
 
 
     def _test_autoaugment_policy(self,
diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 6baa28cd51c6a9c2a584aea0f3c772a42e55f92c..4ad1c4ebe574c53d3b11e03b39d531efc440ae66 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -238,8 +238,6 @@ class AutoAugment(torch.nn.Module):
             if probs[i] <= p:
                 op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
-                print("magnitude_id: ", magnitude_id)
-                print("magnitudes[magnitude_id]: ", magnitudes[magnitude_id])
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
                     magnitude *= -1.0
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index c3dd315dbfc1d3236ef753435244669c4cb447e3..2c0eb855de2f5a5209c41aa305bd9855e9ff7a9f 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -57,60 +57,11 @@ class evo_learner(aa_learner):
 
         # store our logs
         self.policy_dict = {}
-        self.policy_result = []
-
-
-        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
+        self.running_policy = []
 
 
 
-    def get_full_policy(self, x):
-        """
-        Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
-        output size 5 * 2 * (self.fun_num + self.p_bins + self.m_bins)
-
-        Parameters
-        -----------
-        x -> PyTorch tensor
-            Input data for network
-
-        Returns
-        ----------
-        full_policy -> [((String, float, float), (String, float, float)), ...)
-            Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
-            two transformations, with a probability and magnitude float for each
-        """
-        section = self.fun_num + self.p_bins + self.m_bins
-        y = self.controller.forward(x)
-        full_policy = []
-        for pol in range(self.sp_num):
-            int_pol = []
-            for _ in range(2):
-                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
-
-                trans, need_mag = self.augmentation_space[idx_ret]
-
-                if self.p_bins == 1:
-                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
-                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                else:
-                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item()) * 0.1
-
-
-                if need_mag:
-                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
-                    if self.m_bins == 1:
-                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
-                    else:
-                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())
-                    mag = int(mag)
-                else:
-                    mag = None
-                int_pol.append((trans, p_ret, mag))
-
-            full_policy.append(tuple(int_pol))
-
-        return full_policy
+        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
 
     def get_single_policy_cov(self, x, alpha = 0.5):
@@ -172,10 +123,8 @@ class evo_learner(aa_learner):
                 prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
                 prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
                     mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                 if mag2 is not None:
-                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
                     mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                 counter += 1
 
@@ -213,19 +162,14 @@ class evo_learner(aa_learner):
         Solution_idx -> Int
         """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations+1)]
-        print("itations: ", iterations)
+        self.history_best = []
 
-        self.history_avg = [0 for i in range(iterations+1)]
        self.gen_count = 0
         self.best_model = 0
         self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
-        print("train_dataset: ", train_dataset)
 
         self.ga_instance.run()
 
-        self.history_avg = [x / self.num_solutions for x in self.history_avg]
-        print("-----------------------------------------------------------------------------------------------------")
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
 
         if return_weights:
@@ -278,33 +222,34 @@ class evo_learner(aa_learner):
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
 
             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                # if self.sp_num == 1:
                 full_policy = self.get_single_policy_cov(test_x)
-                # else:
-                #     full_policy = self.get_full_policy(test_x)
 
 
                 while self.in_pol_dict(full_policy):
                     full_policy = self.get_single_policy_cov(test_x)[0]
 
 
-                fit_val = self.test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) #) /
-                            # + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+                fit_val = self.test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset)
 
-                self.policy_result.append([full_policy, fit_val])
+                self.history.append((full_policy, fit_val))
+                self.running_policy.append((full_policy, fit_val))
 
-                if len(self.policy_result) > self.sp_num:
-                    self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
-                    self.policy_result = self.policy_result[:self.sp_num]
-                    print("appended policy: ", self.policy_result)
+                if len(self.running_policy) > self.sp_num:
+                    self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
+                    self.running_policy = self.running_policy[:self.sp_num]
+                    print("appended policy: ", self.running_policy)
 
-                if fit_val > self.history_best[self.gen_count]:
-                    print("best policy: ", full_policy)
-                    self.history_best[self.gen_count] = fit_val
+                if len(self.history_best) == 0:
+                    self.history_best.append((fit_val))
+                    self.best_model = model_weights_dict
+                elif fit_val > self.history_best[-1]:
+                    self.history_best.append(fit_val)
                     self.best_model = model_weights_dict
+                else:
+                    self.history_best.append(self.history_best[-1])
 
 
-                self.history_avg[self.gen_count] += fit_val
 
+            return fit_val
 
 
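Note on the bookkeeping introduced in evo_learner.py: the fitness function now records every evaluated (policy, fitness) pair in self.history and self.running_policy, truncates self.running_policy to the best self.sp_num entries, and grows self.history_best as a best-so-far curve (a new fitness is appended only when it beats the last entry; otherwise the previous best is repeated). The sketch below reproduces that logic in isolation so it can be checked independently of pygad and the child network; PolicyTracker and the toy policy names and fitness values are illustrative only, not part of the MetaAugment API.

# Minimal, self-contained sketch of the tracking logic added to fitness_func.
# All names here (PolicyTracker, pol_a, ...) are hypothetical.

class PolicyTracker:
    def __init__(self, sp_num):
        self.sp_num = sp_num          # keep only the top sp_num policies
        self.running_policy = []      # [(policy, fitness), ...], best-first once truncated
        self.history_best = []        # best fitness seen so far, one entry per evaluation

    def update(self, policy, fit_val):
        # Keep the top-sp_num policies seen so far (sort only when over capacity,
        # mirroring the diff).
        self.running_policy.append((policy, fit_val))
        if len(self.running_policy) > self.sp_num:
            self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
            self.running_policy = self.running_policy[:self.sp_num]

        # history_best is monotone: append the new value only if it improves on
        # the previous best, otherwise carry the previous best forward.
        if not self.history_best or fit_val > self.history_best[-1]:
            self.history_best.append(fit_val)
        else:
            self.history_best.append(self.history_best[-1])


if __name__ == "__main__":
    tracker = PolicyTracker(sp_num=2)
    for policy, fit in [("pol_a", 0.61), ("pol_b", 0.58), ("pol_c", 0.72)]:
        tracker.update(policy, fit)
    print(tracker.running_policy)   # [('pol_c', 0.72), ('pol_a', 0.61)]
    print(tracker.history_best)     # [0.61, 0.61, 0.72]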