diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index c85e9fa95087839d259a18123c2da4d8bd77087f..04c8ffcfe871f3fdc5c570c522e92e81e52228aa 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -46,24 +46,6 @@ class Learner(nn.Module): self.p_bins = p_bins self.m_bins = m_bins - self.augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] - super().__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.relu1 = nn.ReLU() @@ -95,24 +77,6 @@ class Learner(nn.Module): return y - def get_idx(self, x): - section = self.fun_num + self.p_bins + self.m_bins - y = self.forward(x) - full_policy = [] - for pol in range(5 * 2): - int_pol = [] - idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) - - trans, need_mag = self.augmentation_space[idx_ret] - - p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) - mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0 - int_pol.append((trans, p_ret, mag)) - if pol % 2 != 0: - full_policy.append(tuple(int_pol)) - - return full_policy - class LeNet(nn.Module): def __init__(self): @@ -204,7 +168,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600 class Evolutionary_learner(): - def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14): + def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None): self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10) self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) self.num_generations = num_generations @@ -215,6 +179,7 @@ class Evolutionary_learner(): self.p_bins = p_bins self.mag_bins = mag_bins self.fun_num = fun_num + self.augmentation_space = augmentation_space assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' @@ -222,6 +187,9 @@ class Evolutionary_learner(): def generate_policy(self, sp_num, ps, mags): + """ + + """ policies = [] for subpol in range(sp_num): sub = [] @@ -235,7 +203,33 @@ class Evolutionary_learner(): return policies + def get_full_policy(self, x): + """ + Generates the full policy (5 x 2 subpolicies) + """ + section = self.meta_rl_agent.fun_num + self.meta_rl_agent.p_bins + self.meta_rl_agent.m_bins + y = self.meta_rl_agent.forward(x) + full_policy = [] + for pol in range(5): + int_pol = [] + for _ in range(2): + idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) + + trans, need_mag = self.augmentation_space[idx_ret] + + p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) + mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0 + int_pol.append((trans, p_ret, mag)) + + full_policy.append(tuple(int_pol)) + + return full_policy + + def run_instance(self, return_weights = False): + """ + Runs the GA instance and returns the model weights as a dictionary + """ self.ga_instance.run() solution, solution_fitness, solution_idx = self.ga_instance.best_solution() if return_weights: @@ -245,6 +239,9 @@ class Evolutionary_learner(): def new_model(self): + """ + Simple function to create a copy of the secondary model (used for classification) + """ copy_model = copy.deepcopy(self.sec_model) return copy_model @@ -259,7 +256,7 @@ class Evolutionary_learner(): weights_vector=solution) self.meta_rl_agent.load_state_dict(model_weights_dict) for idx, (test_x, label_x) in enumerate(train_loader): - full_policy = self.meta_rl_agent.get_idx(test_x) + full_policy = self.meta_rl_agent.get_full_policy(test_x) cop_mod = self.new_model() fit_val = train_model(full_policy, cop_mod) cop_mod = 0