diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index 005676f234792310fa0a11a6a170cd1822fb6e2a..e928b7de2bd152aaa05b56c8fbb8321a8681edaa 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -12,9 +12,10 @@ import random import pygad import pygad.torchga as torchga import random +import copy -import MetaAugment.child_networks as child_networks -from MetaAugment.main import * +# import MetaAugment.child_networks as child_networks +# from MetaAugment.main import * np.random.seed(0) @@ -22,7 +23,7 @@ random.seed(0) class Learner(nn.Module): - def __init__(self): + def __init__(self, num_transforms = 3): super().__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.relu1 = nn.ReLU() @@ -51,16 +52,51 @@ class Learner(nn.Module): y = self.relu4(y) y = self.fc3(y) - # y = self.sig(y) - # print("y[3:, :] shape: ", y[:, 3:].shape) + return y + + def get_idx(self, x): + y = self.forward(x) idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0)) p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0)) return (idx_ret, p_ret) + # return (torch.argmax(y[0:3]), y[torch.argmax(y[3:])]) +class LeNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(2) + self.fc1 = nn.Linear(256, 120) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10) + self.relu5 = nn.ReLU() + + def forward(self, x): + y = self.conv1(x) + y = self.relu1(y) + y = self.pool1(y) + y = self.conv2(y) + y = self.relu2(y) + y = self.pool2(y) + y = y.view(y.shape[0], -1) + y = self.fc1(y) + y = self.relu3(y) + y = self.fc2(y) + y = self.relu4(y) + y = self.fc3(y) + return y + + # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py -def train_model(transform_idx, p): +def train_model(transform_idx, p, child_network): """ Takes in the specific 
transformation index and probability """ @@ -93,16 +129,11 @@ def train_model(transform_idx, p): train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train) test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) - # create toy dataset from above uploaded data train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) - # train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) - # test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) - # print("Size of training dataset:\t", len(reduced_train_dataset)) - # print("Size of testing dataset:\t", len(reduced_test_dataset), "\n") + # child_network = child_networks.lenet() - child_network = child_networks.lenet() sgd = optim.SGD(child_network.parameters(), lr=1e-1) cost = nn.CrossEntropyLoss() epoch = 20 @@ -114,29 +145,84 @@ def train_model(transform_idx, p): -def fitness_func(solution, sol_idx): - """ - Defines fitness function (accuracy of the model) - """ - global train_loader, meta_rl_agent - model_weights_dict = torchga.model_weights_as_dict(model=meta_rl_agent, - weights_vector=solution) - # Use the current solution as the model parameters. 
- meta_rl_agent.load_state_dict(model_weights_dict) - for idx, (test_x, label_x) in enumerate(train_loader): - trans_idx, p = meta_rl_agent(test_x) - fit_val = train_model(trans_idx, p) - return fit_val +def train_child_network(child_network, train_loader, test_loader, sgd, + cost, max_epochs=100, early_stop_num = 10, logging=False): + best_acc=0 + early_stop_cnt = 0 + + # logging accuracy for plotting + acc_log = [] + + # train child_network and check validation accuracy each epoch + for _epoch in range(max_epochs): + + # train child_network + child_network.train() + for idx, (train_x, train_label) in enumerate(train_loader): + label_np = np.zeros((train_label.shape[0], 10)) + sgd.zero_grad() + predict_y = child_network(train_x.float()) + loss = cost(predict_y, train_label.long()) + loss.backward() + sgd.step() + + # check validation accuracy on validation set + correct = 0 + _sum = 0 + child_network.eval() + with torch.no_grad(): + for idx, (test_x, test_label) in enumerate(test_loader): + predict_y = child_network(test_x.float()).detach() + predict_ys = np.argmax(predict_y, axis=-1) + label_np = test_label.numpy() + _ = predict_ys == test_label + correct += np.sum(_.numpy(), axis=-1) + _sum += _.shape[0] + + # update best validation accuracy if it was higher, otherwise increase early stop count + acc = correct / _sum + + + if acc > best_acc : + best_acc = acc + early_stop_cnt = 0 + else: + early_stop_cnt += 1 + + # exit if validation gets worse over 10 runs + if early_stop_cnt >= early_stop_num: + break + + # print('main.train_child_network best accuracy: ', best_acc) + acc_log.append(acc) + + if logging: + return best_acc, acc_log + return best_acc -def callback_generation(ga_instance): - """ - Just prints stuff while running - """ - print("Generation = {generation}".format(generation=ga_instance.generations_completed)) - print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) +def create_toy(train_dataset, test_dataset, batch_size, 
n_samples, seed=100): + # shuffle and take first n_samples %age of training dataset + shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset)) + shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train) + + indices_train = torch.arange(int(n_samples*len(train_dataset))) + reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) + + # shuffle and take first n_samples %age of test dataset + shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset)) + shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) + + indices_test = torch.arange(int(n_samples*len(test_dataset))) + reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test) + + # push into DataLoader + train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) + test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) + + return train_loader, test_loader # ORGANISING DATA @@ -158,32 +244,74 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600 -# GENERATING THE GA INSTANCE +class Evolutionary_learner(): -meta_rl_agent = Learner() -torch_ga = torchga.TorchGA(model=meta_rl_agent, - num_solutions=20) + def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None): + self.meta_rl_agent = network + self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) + self.num_generations = num_generations + self.num_parents_mating = num_parents_mating + self.initial_population = self.torch_ga.population_weights + self.train_loader = train_loader + self.backup_model = sec_model -# HYPERPARAMETER FOR THE GA + assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' -num_generations = 100 # Number of generations. 
-num_parents_mating = 20 # Number of solutions to be selected as parents in the mating pool. -initial_population = torch_ga.population_weights + self.set_up_instance() + -ga_instance = pygad.GA(num_generations=num_generations, - num_parents_mating=num_parents_mating, - initial_population=initial_population, - fitness_func=fitness_func, - on_generation=callback_generation) -ga_instance.run() + def run_instance(self, return_weights = False): + self.ga_instance.run() + solution, solution_fitness, solution_idx = self.ga_instance.best_solution() + if return_weights: + return torchga.model_weights_as_dict(model=self.meta_rl_agent, weights_vector=solution) + else: + return solution, solution_fitness, solution_idx -solution, solution_fitness, solution_idx = ga_instance.best_solution() -print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) -print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) -# Fetch the parameters of the best solution. 
-best_solution_weights = torchga.model_weights_as_dict(model=meta_rl_agent, - weights_vector=solution) + def new_model(self): + copy_model = copy.deepcopy(self.backup_model) + return copy_model + + + def set_up_instance(self): + def fitness_func(solution, sol_idx): + """ + Defines fitness function (accuracy of the model) + """ + model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent, + weights_vector=solution) + self.meta_rl_agent.load_state_dict(model_weights_dict) + for idx, (test_x, label_x) in enumerate(train_loader): + trans_idx, p = self.meta_rl_agent.get_idx(test_x) + cop_mod = self.new_model() + fit_val = train_model(trans_idx, p, cop_mod) + cop_mod = 0 + return fit_val + + def on_generation(ga_instance): + """ + Just prints stuff while running + """ + print("Generation = {generation}".format(generation=self.ga_instance.generations_completed)) + print("Fitness = {fitness}".format(fitness=self.ga_instance.best_solution()[1])) + return + self.ga_instance = pygad.GA(num_generations=self.num_generations, + num_parents_mating=self.num_parents_mating, + initial_population=self.initial_population, + fitness_func=fitness_func, + on_generation = on_generation) +meta_rl_agent = Learner() +ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet()) +ev_learner.run_instance() + + +solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution() +print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) +print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) +# Fetch the parameters of the best solution. 
+best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.meta_rl_agent, + weights_vector=solution) \ No newline at end of file diff --git a/MetaAugment/METALEANER.py b/MetaAugment/METALEANER.py new file mode 100644 index 0000000000000000000000000000000000000000..c94246d6898ccf2d316c1dae7644513bf113149e --- /dev/null +++ b/MetaAugment/METALEANER.py @@ -0,0 +1,7 @@ + + + +# Neural network +# Input the dataset (same batch size, have to check if the input sizes are correct, i.e. 28x28) +# Output the hyperparameters --> weights of network, kernel size, number of layers, number of kernels +# \ No newline at end of file