From 5f05e0d721a488977a940b16869c289b53cc2487 Mon Sep 17 00:00:00 2001 From: Max Ramsay King <maxramsayking@gmail.com> Date: Mon, 4 Apr 2022 10:31:35 -0700 Subject: [PATCH] Updated the ES learner to be more in line with the random search functions --- MetaAugment/CP2_Max.py | 146 +++++++++++++++++++++++++---------------- 1 file changed, 89 insertions(+), 57 deletions(-) diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index 61cccbec..c85e9fa9 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -14,36 +14,56 @@ import pygad.torchga as torchga import random import copy -from MetaAugment.main import * - -# import MetaAugment.child_networks as child_networks # from MetaAugment.main import * +# import MetaAugment.child_networks as child_networks np.random.seed(0) random.seed(0) -# augmentation_space = [ -# # (function_name, do_we_need_to_specify_magnitude) -# ("ShearX", True), -# ("ShearY", True), -# ("TranslateX", True), -# ("TranslateY", True), -# ("Rotate", True), -# ("Brightness", True), -# ("Color", True), -# ("Contrast", True), -# ("Sharpness", True), -# ("Posterize", True), -# ("Solarize", True), -# ("AutoContrast", False), -# ("Equalize", False), -# ("Invert", False), -# ] +augmentation_space = [ + # (function_name, do_we_need_to_specify_magnitude) + ("ShearX", True), + ("ShearY", True), + ("TranslateX", True), + ("TranslateY", True), + ("Rotate", True), + ("Brightness", True), + ("Color", True), + ("Contrast", True), + ("Sharpness", True), + ("Posterize", True), + ("Solarize", True), + ("AutoContrast", False), + ("Equalize", False), + ("Invert", False), + ] class Learner(nn.Module): - def __init__(self, num_transforms = 3): + def __init__(self, fun_num=14, p_bins=11, m_bins=10): + self.fun_num = fun_num + self.p_bins = p_bins + self.m_bins = m_bins + + self.augmentation_space = [ + # (function_name, do_we_need_to_specify_magnitude) + ("ShearX", True), + ("ShearY", True), + ("TranslateX", True), + ("TranslateY", True), + ("Rotate", 
True), + ("Brightness", True), + ("Color", True), + ("Contrast", True), + ("Sharpness", True), + ("Posterize", True), + ("Solarize", True), + ("AutoContrast", False), + ("Equalize", False), + ("Invert", False), + ] + super().__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.relu1 = nn.ReLU() @@ -55,11 +75,9 @@ class Learner(nn.Module): self.relu3 = nn.ReLU() self.fc2 = nn.Linear(120, 84) self.relu4 = nn.ReLU() - self.fc3 = nn.Linear(84, num_transforms + 21) - # self.sig = nn.Sigmoid() -# Currently using discrete outputs for the probabilities - + self.fc3 = nn.Linear(84, 5 * 2 * (self.fun_num + self.p_bins + self.m_bins)) +# Currently using discrete outputs for the probabilities def forward(self, x): y = self.conv1(x) @@ -78,10 +96,22 @@ class Learner(nn.Module): return y def get_idx(self, x): + section = self.fun_num + self.p_bins + self.m_bins y = self.forward(x) - idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0)) - p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0)) - return (idx_ret, p_ret) + full_policy = [] + for pol in range(5 * 2): + int_pol = [] + idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) + + trans, need_mag = self.augmentation_space[idx_ret] + + p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) + mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0 + int_pol.append((trans, p_ret, mag)) + if pol % 2 != 0: + full_policy.append(tuple(int_pol)) + + return full_policy class LeNet(nn.Module): @@ -118,44 +148,27 @@ class LeNet(nn.Module): # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py -def train_model(transform_idx, p, child_network): +def train_model(full_policy, child_network): """ Takes in the specific transformation index and probability """ - if transform_idx == 0: - transform_train = torchvision.transforms.Compose( - [ - 
torchvision.transforms.RandomVerticalFlip(p), - torchvision.transforms.ToTensor(), - ] - ) - elif transform_idx == 1: - transform_train = torchvision.transforms.Compose( - [ - torchvision.transforms.RandomHorizontalFlip(p), - torchvision.transforms.ToTensor(), - ] - ) - else: - transform_train = torchvision.transforms.Compose( - [ - torchvision.transforms.RandomGrayscale(p), - torchvision.transforms.ToTensor(), - ] - ) + # transformation = generate_policy(5, ps, mags) + + train_transform = transforms.Compose([ + full_policy, + transforms.ToTensor() + ]) batch_size = 32 n_samples = 0.005 - train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train) + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform) test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) - # child_network = child_networks.lenet() - sgd = optim.SGD(child_network.parameters(), lr=1e-1) cost = nn.CrossEntropyLoss() epoch = 20 @@ -191,20 +204,37 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600 class Evolutionary_learner(): - def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None): - self.meta_rl_agent = network + def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14): + self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10) self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) self.num_generations = num_generations self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights self.train_loader = train_loader self.sec_model 
= sec_model + self.p_bins = p_bins + self.mag_bins = mag_bins + self.fun_num = fun_num assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' self.set_up_instance() + def generate_policy(self, sp_num, ps, mags): + policies = [] + for subpol in range(sp_num): + sub = [] + for idx in range(2): + transformation = augmentation_space[(2*subpol) + idx] + p = ps[(2*subpol) + idx] + mag = mags[(2*subpol) + idx] + sub.append((transformation, p, mag)) + policies.append(tuple(sub)) + + return policies + + def run_instance(self, return_weights = False): self.ga_instance.run() solution, solution_fitness, solution_idx = self.ga_instance.best_solution() @@ -213,12 +243,14 @@ class Evolutionary_learner(): else: return solution, solution_fitness, solution_idx + def new_model(self): copy_model = copy.deepcopy(self.sec_model) return copy_model def set_up_instance(self): + def fitness_func(solution, sol_idx): """ Defines fitness function (accuracy of the model) @@ -227,9 +259,9 @@ class Evolutionary_learner(): weights_vector=solution) self.meta_rl_agent.load_state_dict(model_weights_dict) for idx, (test_x, label_x) in enumerate(train_loader): - trans_idx, p = self.meta_rl_agent.get_idx(test_x) + full_policy = self.meta_rl_agent.get_idx(test_x) cop_mod = self.new_model() - fit_val = train_model(trans_idx, p, cop_mod) + fit_val = train_model(full_policy, cop_mod) cop_mod = 0 return fit_val -- GitLab