diff --git a/.DS_Store b/.DS_Store index 87b56ad1c0caa0cd8b0aa4497cbd4d095b75bc27..720cf3ab50cbd4bb4f33acbbc3cb3516e7778732 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/MetaAugment/Baseline_JC.ipynb b/MetaAugment/Baseline_JC.ipynb index d0ab8ea0710b9cf9a0cdb2629e80a8036b014d47..d979dc8a67b4c0232a21967b43e340f90b08a844 100644 --- a/MetaAugment/Baseline_JC.ipynb +++ b/MetaAugment/Baseline_JC.ipynb @@ -171,6 +171,183 @@ }, { "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KVhYheLfBP33", + "outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "9913344it [00:04, 2462502.04it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "29696it [00:00, 3785722.37it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1649664it [00:00, 3348476.95it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "5120it [00:00, 2935726.11it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "9913344it [00:04, 2338660.11it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "29696it [00:00, 33554432.00it/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "1649664it [00:00, 2786152.46it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "5120it [00:00, 4789214.20it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n", + "\n", + "0\tBest accuracy: 18.00%\n", + "10\tBest accuracy: 75.50%\n", + "20\tBest accuracy: 78.00%\n", + "30\tBest accuracy: 95.00%\n", + "40\tBest accuracy: 95.50%\n", + "50\tBest accuracy: 94.00%\n", + "60\tBest accuracy: 85.00%\n", + "70\tBest accuracy: 85.50%\n", + "80\tBest accuracy: 62.50%\n", + "90\tBest accuracy: 76.00%\n", + "Average best accuracy: 79.86%\n", + "\n", + "0\tAverage accuracy: 93.50%\n", + "10\tAverage accuracy: 93.45%\n", + "20\tAverage accuracy: 46.95%\n", + "30\tAverage accuracy: 71.41%\n", + "40\tAverage accuracy: 73.68%\n", + "50\tAverage accuracy: 64.50%\n", + "60\tAverage accuracy: 72.50%\n", + "70\tAverage accuracy: 94.36%\n", + "80\tAverage accuracy: 84.77%\n", + "90\tAverage accuracy: 92.14%\n", + "Average average accuracy: 80.92%\n", + "\n" + ] + } + ], "source": [ "batch_size = 32 # size of batch the inner NN is trained with\n", "toy_size = 0.02 # total propeortion of training and test set we use\n", @@ -198,47 +375,14 @@ " if baselines % 10 == 0:\n", " print(\"{}\\tAverage accuracy: {:.2f}%\".format(baselines, best_acc*100))\n", "print(\"Average average accuracy: {:.2f}%\\n\".format(np.mean(best_accuracies)*100))" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KVhYheLfBP33", - "outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "0\tBest accuracy: 49.00%\n", - "10\tBest accuracy: 86.50%\n", - "20\tBest accuracy: 95.00%\n", - "30\tBest accuracy: 54.00%\n", - "40\tBest accuracy: 94.00%\n", - "50\tBest accuracy: 93.50%\n", - "60\tBest accuracy: 66.50%\n", - "70\tBest accuracy: 94.50%\n", - "80\tBest accuracy: 74.50%\n", - "90\tBest accuracy: 74.00%\n", - "Average best accuracy: 79.58%\n", - "\n", - "0\tAverage accuracy: 68.95%\n", - "10\tAverage accuracy: 69.95%\n", - "20\tAverage accuracy: 85.00%\n", - "30\tAverage accuracy: 93.32%\n", - "40\tAverage accuracy: 68.00%\n", - "50\tAverage accuracy: 85.36%\n", - "60\tAverage accuracy: 92.36%\n", - "70\tAverage accuracy: 56.95%\n", - "80\tAverage accuracy: 93.59%\n", - "90\tAverage accuracy: 64.91%\n", - "Average average accuracy: 78.90%\n", - "\n" - ] - } ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -262,9 +406,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index c1e91a97eed2634c29e78325556bf20e52e05ca9..792e81e1f85932408755840fbcbc09612137d39e 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -1,3 +1,4 @@ +from cgi import test import numpy as np import torch torch.manual_seed(0) @@ -16,15 +17,24 @@ import copy from torchvision.transforms import functional as F, InterpolationMode from typing import List, Tuple, Optional, Dict import heapq +import math +import math +import torch + +from enum import Enum +from torch import Tensor +from typing import List, Tuple, Optional, Dict +from torchvision.transforms import functional as F, InterpolationMode -# from MetaAugment.main import * # import MetaAugment.child_networks as child_networks +# from main import * +# from autoaugment_learners.autoaugment import * -np.random.seed(0) -random.seed(0) +# np.random.seed(0) +# random.seed(0) augmentation_space = [ @@ -172,9 +182,10 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600 + class Evolutionary_learner(): - def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None): + def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None): self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol) self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) self.num_generations = num_generations @@ -211,30 +222,6 @@ class Evolutionary_learner(): return policies -# Every image has specific operation. Policy for every image (2 (trans., prob., mag) output) - - -# RNN -> change the end -/- leave for now, ask Javier - - -# Use mini-batch with current output, get mode transformation -> mean probability and magnitude -# Pass through each image in mini-batch to get one/two (transformation, prob., mag.) tuples -# Average softmax probability (get softmax of the outputs, then average them to get the probability) - - -# For every batch, store all outputs. Pick top operations -# Every image -> output 2 operation tuples e.g. 14 trans + 1 prob + 1 mag. 32 output total. -# 14 neuron output is then prob. of transformations (softmax + average across dim = 0) -# 1000x28 -# Problem 1: have 28, if we pick argmax top 2 - - # For each image have 28 dim output. Calculate covariance of 1000x28 using np.cov(28_dim_vector.T) - # Give 28x28 covariance matrix. Pick top k pairs (corresponds to largest covariance pairs) - # Once we have pairs, go back to 1000x32 output. Find cases where the largest cov. pairs are used and use those probs and mags - - -# Covariance matrix -> prob. of occurance (might be bad pairs) -# Pair criteria -> highest softmax prob and probaility of occurence def get_full_policy(self, x): """ @@ -257,9 +244,9 @@ class Evolutionary_learner(): full_policy.append(tuple(int_pol)) return full_policy - +# - def get_policy_cov(self, x): + def get_policy_cov(self, x, alpha = 0.5): """ Need p_bins = 1, num_sub_pol = 1, mag_bins = 1 """ @@ -268,48 +255,55 @@ class Evolutionary_learner(): y = self.auto_aug_agent.forward(x) # 1000 x 32 y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14 + y[:,:self.auto_aug_agent.fun_num] = y_1 y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1) + y[:,section:section+self.auto_aug_agent.fun_num] = y_2 concat = torch.cat((y_1, y_2), dim = 1) cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] shape_store = cov_mat.shape + counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0) + + + prob_mat = torch.zeros(shape_store) + for idx in range(y.shape[0]): + prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1 + prob_mat = prob_mat / torch.sum(prob_mat) + + cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat) + cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze() max_idx = torch.argmax(cov_mat) val = (max_idx//shape_store[0]) max_idx = (val, max_idx - (val * shape_store[0])) - counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0) - if self.augmentation_space[max_idx[0]]: + if not self.augmentation_space[max_idx[0]][1]: mag1 = None - if self.augmentation_space[max_idx[1]]: + if not self.augmentation_space[max_idx[1]][1]: mag2 = None - + for idx in range(y.shape[0]): - # print("torch.argmax(y_1[idx]): ", torch.argmax(y_1[idx])) - # print("torch.argmax(y_2[idx]): ", torch.argmax(y_2[idx])) - # print("max idx0: ", max_idx[0]) - # print("max idx1: ", max_idx[1]) - if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]): - prob1 += y[idx, self.auto_aug_agent.fun_num+1] - prob2 += y[idx, section+self.auto_aug_agent.fun_num+1] + prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item() + prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item() if mag1 is not None: - mag1 += y[idx, self.auto_aug_agent.fun_num+2] + mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8) if mag2 is not None: - mag2 += y[idx, section+self.auto_aug_agent.fun_num+2] + mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8) counter += 1 - + prob1 = prob1/counter if counter != 0 else 0 prob2 = prob2/counter if counter != 0 else 0 if mag1 is not None: mag1 = mag1/counter if mag2 is not None: - mag2 = mag2/counter + mag2 = mag2/counter + - return [(self.augmentation_space[max_idx[0]], prob1, mag1), (self.augmentation_space[max_idx[1]], prob2, mag2)] + return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)] @@ -342,6 +336,7 @@ class Evolutionary_learner(): """ Defines fitness function (accuracy of the model) """ + print("FITNESS HERE") model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution) @@ -349,14 +344,13 @@ class Evolutionary_learner(): self.auto_aug_agent.load_state_dict(model_weights_dict) for idx, (test_x, label_x) in enumerate(train_loader): - # full_policy = self.get_full_policy(test_x) full_policy = self.get_policy_cov(test_x) + print("FULL POLICY: ", full_policy) - print("full_policy: ", full_policy) - cop_mod = self.new_model() - fit_val = test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0] - cop_mod = 0 + fit_val = (test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) #+ test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2 + + print("DONE FITNESS") return fit_val @@ -372,6 +366,7 @@ class Evolutionary_learner(): self.ga_instance = pygad.GA(num_generations=self.num_generations, num_parents_mating=self.num_parents_mating, initial_population=self.initial_population, + mutation_percent_genes = 0.1, fitness_func=fitness_func, on_generation = on_generation) @@ -381,14 +376,566 @@ class Evolutionary_learner(): + + + + +# HEREHEREHERE0 + +def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100): + # shuffle and take first n_samples %age of training dataset + shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset)) + shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train) + + indices_train = torch.arange(int(n_samples*len(train_dataset))) + reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) + + # shuffle and take first n_samples %age of test dataset + shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset)) + shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) + + big = 4 # how much bigger is the test set + + indices_test = torch.arange(int(n_samples*len(test_dataset)*big)) + reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test) + + # push into DataLoader + train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) + test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) + + return train_loader, test_loader + + +def train_child_network(child_network, train_loader, test_loader, sgd, + cost, max_epochs=2000, early_stop_num = 5, logging=False, + print_every_epoch=True): + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + child_network = child_network.to(device=device) + + best_acc=0 + early_stop_cnt = 0 + + # logging accuracy for plotting + acc_log = [] + + # train child_network and check validation accuracy each epoch + for _epoch in range(max_epochs): + + # train child_network + child_network.train() + for idx, (train_x, train_label) in enumerate(train_loader): + # onto device + train_x = train_x.to(device=device, dtype=train_x.dtype) + train_label = train_label.to(device=device, dtype=train_label.dtype) + + # label_np = np.zeros((train_label.shape[0], 10)) + + sgd.zero_grad() + predict_y = child_network(train_x.float()) + loss = cost(predict_y, train_label.long()) + loss.backward() + sgd.step() + + # check validation accuracy on validation set + correct = 0 + _sum = 0 + child_network.eval() + with torch.no_grad(): + for idx, (test_x, test_label) in enumerate(test_loader): + # onto device + test_x = test_x.to(device=device, dtype=test_x.dtype) + test_label = test_label.to(device=device, dtype=test_label.dtype) + + predict_y = child_network(test_x.float()).detach() + predict_ys = torch.argmax(predict_y, axis=-1) + + # label_np = test_label.numpy() + + _ = predict_ys == test_label + correct += torch.sum(_, axis=-1) + # correct += torch.sum(_.numpy(), axis=-1) + _sum += _.shape[0] + + # update best validation accuracy if it was higher, otherwise increase early stop count + acc = correct / _sum + + if acc > best_acc : + best_acc = acc + early_stop_cnt = 0 + else: + early_stop_cnt += 1 + + # exit if validation gets worse over 10 runs + if early_stop_cnt >= early_stop_num: + print('main.train_child_network best accuracy: ', best_acc) + break + + # if print_every_epoch: + # print('main.train_child_network best accuracy: ', best_acc) + acc_log.append(acc) + + if logging: + return best_acc.item(), acc_log + return best_acc.item() + +def test_autoaugment_policy(subpolicies, train_dataset, test_dataset): + + aa_transform = AutoAugment() + aa_transform.subpolicies = subpolicies + + train_transform = transforms.Compose([ + aa_transform, + transforms.ToTensor() + ]) + + train_dataset.transform = train_transform + + # create toy dataset from above uploaded data + train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1) + + child_network = LeNet() + sgd = optim.SGD(child_network.parameters(), lr=1e-1) + cost = nn.CrossEntropyLoss() + + best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, + sgd, cost, max_epochs=100, logging=True) + + return best_acc, acc_log + + + +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] + + +def _apply_op(img: Tensor, op_name: str, magnitude: float, + interpolation: InterpolationMode, fill: Optional[List[float]]): + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + elif op_name == "Identity": + pass + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + return img + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + Available policies are IMAGENET, CIFAR10 and SVHN. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class +class AutoAugment(torch.nn.Module): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None + ) -> None: + super().__init__() + self.policy = policy + self.interpolation = interpolation + self.fill = fill + self.subpolicies = self._get_subpolicies(policy) + + def _get_subpolicies( + self, + policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError("The provided policy {} is not recognized.".format(policy)) + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } + + @staticmethod + def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ + policy_id = int(torch.randint(transform_num, (1,)).item()) + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def forward(self, img: Tensor, dis_mag = True) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + transform_id, probs, signs = self.get_params(len(self.subpolicies)) + # print("transform_id, probs, signs : ", transform_id, probs, signs ) + + # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies[transform_id]): + # for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies): + # print("op_name, p, magnitude_id: ", op_name, p, magnitude_id) + # if probs[i] <= p: + # op_meta = self._augmentation_space(10, F.get_image_size(img)) + # magnitudes, signed = op_meta[op_name] + # magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 + # if signed and signs[i] == 0: + # magnitude *= -1.0 + # img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + for i, (op_name, p, magnitude) in enumerate(self.subpolicies): + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + + return img + + def __repr__(self) -> str: + return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + + +class RandAugment(torch.nn.Module): + r"""RandAugment data augmentation method based on + `"RandAugment: Practical automated data augmentation with a reduced search space" + <https://arxiv.org/abs/1909.13719>`_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_ops (int): Number of augmentation transformations to apply sequentially. + magnitude (int): Magnitude for all the transformations. + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + for _ in range(self.num_ops): + op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_ops={num_ops}' + s += ', magnitude={magnitude}' + s += ', num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) + + +class TrivialAugmentWide(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = self._augmentation_space(self.num_magnitude_bins) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ + if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) + +# HEREHEREHEREHERE1 + + + + + + + + +train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, + transform=None) +test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, + transform=torchvision.transforms.ToTensor()) + + auto_aug_agent = Learner() -ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1) +ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1, train_dataset=train_dataset, test_dataset=test_dataset) ev_learner.run_instance() solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution() -print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness)) -print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx)) + +print(f"Best solution : {solution}") +print(f"Fitness value of the best solution = {solution_fitness}") +print(f"Index of the best solution : {solution_idx}") # Fetch the parameters of the best solution. best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent, weights_vector=solution) \ No newline at end of file diff --git a/MetaAugment/GA_results.png b/MetaAugment/GA_results.png new file mode 100644 index 0000000000000000000000000000000000000000..62449415b64500804927328ca677c4c023085436 Binary files /dev/null and b/MetaAugment/GA_results.png differ diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450 Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5 Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234 Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480 Binary files /dev/null and b/MetaAugment/MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000000000000000000000000000000000000..d1c3a970612bbd2df47a3c0697f82bd394abc450 Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000000000000000000000000000000000000..a7e141541c1d08d3f2ed01eae03e644f9e2fd0c5 Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000000000000000000000000000000000000..d6b4c5db3b52063d543fb397aede09aba0dc5234 Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000000000000000000000000000000000000..707a576bb523304d5b674de436c0779d77b7d480 Binary files /dev/null and b/MetaAugment/MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/MetaAugment/genetic_learner_results.py b/MetaAugment/genetic_learner_results.py new file mode 100644 index 0000000000000000000000000000000000000000..35d9de8df2e17748b34e6879d4a3ae75dca9d9fb --- /dev/null +++ b/MetaAugment/genetic_learner_results.py @@ -0,0 +1,109 @@ +import matplotlib.pyplot as plt +import numpy as np + + +# Fixed seed (same as benchmark) + +# Looking at last generation can make out general trends of which transformations lead to the largest accuracies + + +gen_1_acc = [0.1998, 0.1405, 0.1678, 0.9690, 0.9672, 0.9540, 0.9047, 0.9730, 0.2060, 0.9260, 0.8035, 0.9715, 0.9737, 0.14, 0.9645] + +gen_2_acc = [0.9218, 0.9753, 0.9758, 0.1088, 0.9710, 0.1655, 0.9735, 0.9655, 0.9740, 0.9377] + +gen_3_acc = [0.1445, 0.9740, 0.9643, 0.9750, 0.9492, 0.9693, 0.1262, 0.9660, 0.9760, 0.9697] + +gen_4_acc = [0.9697, 0.1238, 0.9613, 0.9737, 0.9603, 0.8620, 0.9712, 0.9617, 0.9737, 0.1855] + +gen_5_acc = [0.6445, 0.9705, 0.9668, 0.9765, 0.1142, 0.9780, 0.9700, 0.2120, 0.9555, 0.9732] + +gen_6_acc = [0.9710, 0.9665, 0.2077, 0.9535, 0.9765, 0.9712, 0.9697, 0.2145, 0.9523, 0.9718, 0.9718, 0.9718, 0.2180, 0.9622, 0.9785] + +gen_acc = [gen_1_acc, gen_2_acc, gen_3_acc, gen_4_acc, gen_5_acc, gen_6_acc] + +gen_acc_means = [] +gen_acc_stds = [] + +for val in gen_acc: + gen_acc_means.append(np.mean(val)) + gen_acc_stds.append(np.std(val)) + + + +# Vary seed + +gen_1_vary = [0.1998, 0.9707, 0.9715, 0.9657, 0.8347, 0.9655, 0.1870, 0.0983, 0.3750, 0.9765, 0.9712, 0.9705, 0.9635, 0.9718, 0.1170] + +gen_2_vary = [0.9758, 0.9607, 0.9597, 0.9753, 0.1165, 0.1503, 0.9747, 0.1725, 0.9645, 0.2290] + +gen_3_vary = [0.1357, 0.9725, 0.1708, 0.9607, 0.2132, 0.9730, 0.9743, 0.9690, 0.0850, 0.9755] + +gen_4_vary = [0.9722, 0.9760, 0.9697, 0.1155, 0.9715, 0.9688, 0.1785, 0.9745, 0.2362, 0.9765] + +gen_5_vary = [0.9705, 0.2280, 0.9745, 0.1875, 0.9735, 0.9735, 0.9720, 0.9678, 0.9770, 0.1155] + +gen_6_vary = [0.9685, 0.9730, 0.9735, 0.9760, 0.1495, 0.9707, 0.9700, 0.9747, 0.9750, 0.1155, 0.9732, 0.9745, 0.9758, 0.9768, 0.1155] + +gen_vary = [gen_1_vary, gen_2_vary, gen_3_vary, gen_4_vary, gen_5_vary, gen_6_vary] + +gen_vary_means = [] +gen_vary_stds = [] + +for val in gen_vary: + gen_vary_means.append(np.mean(val)) + gen_vary_stds.append(np.std(val)) + + + + + +# Multiple runs + +gen_1_mult = [0.1762, 0.9575, 0.1200, 0.9660, 0.9650, 0.9570, 0.9745, 0.9700, 0.15, 0.23, 0.16, 0.186, 0.9640, 0.9650] + +gen_2_mult = [0.17, 0.1515, 0.1700, 0.9625, 0.9630, 0.9732, 0.9680, 0.9633, 0.9530, 0.9640] + +gen_3_mult = [0.9750, 0.9720, 0.9655, 0.9530, 0.9623, 0.9730, 0.9748, 0.9625, 0.9716, 0.9672] + +gen_4_mult = [0.9724, 0.9755, 0.9657, 0.9718, 0.9690, 0.9735, 0.9715, 0.9300, 0.9725, 0.9695] + +gen_5_mult = [0.9560, 0.9750, 0.8750, 0.9717, 0.9731, 0.9741, 0.9747, 0.9726, 0.9729, 0.9727] + +gen_6_mult = [0.9730, 0.9740, 0.9715, 0.9755, 0.9761, 0.9700, 0.9755, 0.9750, 0.9726, 0.9748, 0.9705, 0.9745, 0.9752, 0.9740, 0.9744] + + + +gen_mult = [gen_1_mult, gen_2_mult, gen_3_mult, gen_4_mult, gen_5_mult, gen_6_mult] + +gen_mult_means = [] +gen_mult_stds = [] + +for val in gen_mult: + gen_mult_means.append(np.mean(val)) + gen_mult_stds.append(np.std(val)) + +num_gen = [i for i in range(len(gen_mult))] + + +# Baseline +baseline = [0.7990 for i in range(len(gen_mult))] + + + +# plt.errorbar(num_gen, gen_acc_means, yerr=gen_acc_stds, linestyle = 'dotted', label = 'Fixed seed GA') +# plt.errorbar(num_gen, gen_vary_means, linestyle = 'dotted', yerr=gen_vary_stds, label = 'Varying seed GA') +# plt.errorbar(num_gen, gen_mult_means, linestyle = 'dotted', yerr=gen_mult_stds, label = 'Varying seed GA 2') + +plt.plot(num_gen, gen_acc_means, linestyle = 'dotted', label = 'Fixed seed GA') +plt.plot(num_gen, gen_vary_means, linestyle = 'dotted', label = 'Varying seed GA') +plt.plot(num_gen, gen_mult_means, linestyle = 'dotted', label = 'Varying seed GA 2') + +plt.plot(num_gen, baseline, label = 'Fixed seed baseline') + + +plt.xlabel('Generation', fontsize = 16) +plt.ylabel('Validation Accuracy', fontsize = 16) + +plt.legend() + +plt.savefig('GA_results.png') \ No newline at end of file