diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py deleted file mode 100644 index 3e60d56adb23c23bd353e3c8d118c253f4c845f4..0000000000000000000000000000000000000000 --- a/MetaAugment/CP2_Max.py +++ /dev/null @@ -1,945 +0,0 @@ -from cgi import test -import numpy as np -import torch -torch.manual_seed(0) -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torchvision -import torchvision.datasets as datasets -from torchvision import transforms -import torchvision.transforms.autoaugment as autoaugment -import random -import pygad -import pygad.torchga as torchga -import random -import copy -from torchvision.transforms import functional as F, InterpolationMode -from typing import List, Tuple, Optional, Dict -import heapq -import math - -import math -import torch - -from enum import Enum -from torch import Tensor -from typing import List, Tuple, Optional, Dict - -from torchvision.transforms import functional as F, InterpolationMode - -# import MetaAugment.child_networks as child_networks -# from main import * -# from autoaugment_learners.autoaugment import * - - -# np.random.seed(0) -# random.seed(0) - - -augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] - -class Learner(nn.Module): - def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5): - self.fun_num = fun_num - self.p_bins = p_bins - self.m_bins = m_bins - self.sub_num_pol = sub_num_pol - - super().__init__() - self.conv1 = nn.Conv2d(1, 6, 5) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(2) - self.fc1 = nn.Linear(256, 120) - self.relu3 = nn.ReLU() - self.fc2 = nn.Linear(120, 84) - self.relu4 = nn.ReLU() - self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins)) - - def forward(self, x): - y = self.conv1(x) - y = self.relu1(y) - y = self.pool1(y) - y = self.conv2(y) - y = self.relu2(y) - y = self.pool2(y) - y = y.view(y.shape[0], -1) - y = self.fc1(y) - y = self.relu3(y) - y = self.fc2(y) - y = self.relu4(y) - y = self.fc3(y) - - return y - - -# class LeNet(nn.Module): -# def __init__(self): -# super().__init__() -# self.conv1 = nn.Conv2d(1, 6, 5) -# self.relu1 = nn.ReLU() -# self.pool1 = nn.MaxPool2d(2) -# self.conv2 = nn.Conv2d(6, 16, 5) -# self.relu2 = nn.ReLU() -# self.pool2 = nn.MaxPool2d(2) -# self.fc1 = nn.Linear(256, 120) -# self.relu3 = nn.ReLU() -# self.fc2 = nn.Linear(120, 84) -# self.relu4 = nn.ReLU() -# self.fc3 = nn.Linear(84, 10) -# self.relu5 = nn.ReLU() - -# def forward(self, x): -# y = self.conv1(x) -# y = self.relu1(y) -# y = self.pool1(y) -# y = self.conv2(y) -# y = self.relu2(y) -# y = self.pool2(y) -# y = y.view(y.shape[0], -1) -# y = self.fc1(y) -# y = self.relu3(y) -# y = self.fc2(y) -# y = self.relu4(y) -# y = self.fc3(y) -# return y - - -class LeNet(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(784, 2048) - self.relu1 = nn.ReLU() - self.fc2 = nn.Linear(2048, 10) - self.relu2 = nn.ReLU() - - def forward(self, x): - x = x.reshape((-1, 784)) - y = self.fc1(x) - y = self.relu1(y) - y = self.fc2(y) - y = self.relu2(y) - return y - - - -# ORGANISING DATA - -# transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation'] -train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor()) -test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor()) -n_samples = 0.02 -# shuffle and take first n_samples %age of training dataset -shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist()) -indices_train = torch.arange(int(n_samples*len(train_dataset))) -reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) -# shuffle and take first n_samples %age of test dataset -shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist()) -indices_test = torch.arange(int(n_samples*len(test_dataset))) -reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test) - -train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=60000) - - - - - -class Evolutionary_learner(): - - def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None): - self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol) - self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) - self.num_generations = num_generations - self.num_parents_mating = num_parents_mating - self.initial_population = self.torch_ga.population_weights - self.train_loader = train_loader - self.child_network = child_network - self.p_bins = p_bins - self.sub_num_pol = sub_num_pol - self.mag_bins = mag_bins - self.fun_num = fun_num - self.augmentation_space = augmentation_space - self.train_dataset = train_dataset - self.test_dataset = test_dataset - - assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' - - self.set_up_instance() - - - def get_full_policy(self, x): - """ - Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires - output size 5 * 2 * (self.fun_num + self.p_bins + self.mag_bins) - - Parameters - ----------- - x -> PyTorch tensor - Input data for network - - Returns - ---------- - full_policy -> [((String, float, float), (String, float, float)), ...) - Full policy consisting of tuples of subpolicies. Each subpolicy consisting of - two transformations, with a probability and magnitude float for each - """ - section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins - y = self.auto_aug_agent.forward(x) - full_policy = [] - for pol in range(self.sub_num_pol): - int_pol = [] - for _ in range(2): - idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) - - trans, need_mag = self.augmentation_space[idx_ret] - - p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) - mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None - int_pol.append((trans, p_ret, mag)) - - full_policy.append(tuple(int_pol)) - - return full_policy - - - def get_policy_cov(self, x, alpha = 0.5): - """ - Selects policy using population and covariance matrices. For this method - we require p_bins = 1, num_sub_pol = 1, mag_bins = 1. - - Parameters - ------------ - x -> PyTorch Tensor - Input data for the AutoAugment network - - alpha -> Float - Proportion for covariance and population matrices - - Returns - ----------- - Subpolicy -> [(String, float, float), (String, float, float)] - Subpolicy consisting of two tuples of policies, each with a string associated - to a transformation, a float for a probability, and a float for a magnittude - """ - section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins - - y = self.auto_aug_agent.forward(x) # 1000 x 32 - - y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14 - y[:,:self.auto_aug_agent.fun_num] = y_1 - y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1) - y[:,section:section+self.auto_aug_agent.fun_num] = y_2 - concat = torch.cat((y_1, y_2), dim = 1) - - cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] - cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] - shape_store = cov_mat.shape - - counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0) - - - prob_mat = torch.zeros(shape_store) - for idx in range(y.shape[0]): - prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1 - prob_mat = prob_mat / torch.sum(prob_mat) - - cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat) - - cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze() - max_idx = torch.argmax(cov_mat) - val = (max_idx//shape_store[0]) - max_idx = (val, max_idx - (val * shape_store[0])) - - - if not self.augmentation_space[max_idx[0]][1]: - mag1 = None - if not self.augmentation_space[max_idx[1]][1]: - mag2 = None - - for idx in range(y.shape[0]): - if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]): - prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item() - prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item() - if mag1 is not None: - mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8) - if mag2 is not None: - mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8) - counter += 1 - - prob1 = prob1/counter if counter != 0 else 0 - prob2 = prob2/counter if counter != 0 else 0 - if mag1 is not None: - mag1 = mag1/counter - if mag2 is not None: - mag2 = mag2/counter - - - return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)] - - - - - - - def run_instance(self, return_weights = False): - """ - Runs the GA instance and returns the model weights as a dictionary - - Parameters - ------------ - return_weights -> Bool - Determines if the weight of the GA network should be returned - - Returns - ------------ - If return_weights: - Network weights -> Dictionary - - Else: - Solution -> Best GA instance solution - - Solution fitness -> Float - - Solution_idx -> Int - """ - self.ga_instance.run() - solution, solution_fitness, solution_idx = self.ga_instance.best_solution() - if return_weights: - return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution) - else: - return solution, solution_fitness, solution_idx - - - def new_model(self): - """ - Simple function to create a copy of the secondary model (used for classification) - """ - copy_model = copy.deepcopy(self.child_network) - return copy_model - - - def set_up_instance(self): - """ - Initialises GA instance, as well as fitness and on_generation functions - - """ - - def fitness_func(solution, sol_idx): - """ - Defines the fitness function for the parent selection - - Parameters - -------------- - solution -> GA solution instance (parsed automatically) - - sol_idx -> GA solution index (parsed automatically) - - Returns - -------------- - fit_val -> float - """ - - model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent, - weights_vector=solution) - - self.auto_aug_agent.load_state_dict(model_weights_dict) - - for idx, (test_x, label_x) in enumerate(train_loader): - full_policy = self.get_policy_cov(test_x) - - fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0])/ - + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2 - - return fit_val - - def on_generation(ga_instance): - """ - Prints information of generational fitness - - Parameters - ------------- - ga_instance -> GA instance - - Returns - ------------- - None - """ - print("Generation = {generation}".format(generation=ga_instance.generations_completed)) - print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) - return - - - self.ga_instance = pygad.GA(num_generations=self.num_generations, - num_parents_mating=self.num_parents_mating, - initial_population=self.initial_population, - mutation_percent_genes = 0.1, - fitness_func=fitness_func, - on_generation = on_generation) - - - - -# HEREHEREHERE0 - -def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100): - # shuffle and take first n_samples %age of training dataset - shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset)) - shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train) - - indices_train = torch.arange(int(n_samples*len(train_dataset))) - reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) - - # shuffle and take first n_samples %age of test dataset - shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset)) - shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) - - big = 4 # how much bigger is the test set - - indices_test = torch.arange(int(n_samples*len(test_dataset)*big)) - reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test) - - # push into DataLoader - train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) - - return train_loader, test_loader - - -def train_child_network(child_network, train_loader, test_loader, sgd, - cost, max_epochs=2000, early_stop_num = 5, logging=False, - print_every_epoch=True): - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - child_network = child_network.to(device=device) - - best_acc=0 - early_stop_cnt = 0 - - # logging accuracy for plotting - acc_log = [] - - # train child_network and check validation accuracy each epoch - for _epoch in range(max_epochs): - - # train child_network - child_network.train() - for idx, (train_x, train_label) in enumerate(train_loader): - # onto device - train_x = train_x.to(device=device, dtype=train_x.dtype) - train_label = train_label.to(device=device, dtype=train_label.dtype) - - # label_np = np.zeros((train_label.shape[0], 10)) - - sgd.zero_grad() - predict_y = child_network(train_x.float()) - loss = cost(predict_y, train_label.long()) - loss.backward() - sgd.step() - - # check validation accuracy on validation set - correct = 0 - _sum = 0 - child_network.eval() - with torch.no_grad(): - for idx, (test_x, test_label) in enumerate(test_loader): - # onto device - test_x = test_x.to(device=device, dtype=test_x.dtype) - test_label = test_label.to(device=device, dtype=test_label.dtype) - - predict_y = child_network(test_x.float()).detach() - predict_ys = torch.argmax(predict_y, axis=-1) - - _ = predict_ys == test_label - correct += torch.sum(_, axis=-1) - - _sum += _.shape[0] - - # update best validation accuracy if it was higher, otherwise increase early stop count - acc = correct / _sum - - if acc > best_acc : - best_acc = acc - early_stop_cnt = 0 - else: - early_stop_cnt += 1 - - # exit if validation gets worse over 10 runs - if early_stop_cnt >= early_stop_num: - print('main.train_child_network best accuracy: ', best_acc) - break - - # if print_every_epoch: - # print('main.train_child_network best accuracy: ', best_acc) - acc_log.append(acc) - - if logging: - return best_acc.item(), acc_log - return best_acc.item() - -def test_autoaugment_policy(subpolicies, train_dataset, test_dataset): - - aa_transform = AutoAugment() - aa_transform.subpolicies = subpolicies - - train_transform = transforms.Compose([ - aa_transform, - transforms.ToTensor() - ]) - - train_dataset.transform = train_transform - - # create toy dataset from above uploaded data - train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1) - - child_network = LeNet() - sgd = optim.SGD(child_network.parameters(), lr=1e-1) - cost = nn.CrossEntropyLoss() - - best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, - sgd, cost, max_epochs=100, logging=True) - - return best_acc, acc_log - - - -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] - - -def _apply_op(img: Tensor, op_name: str, magnitude: float, - interpolation: InterpolationMode, fill: Optional[List[float]]): - if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, fill=fill) - elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, fill=fill) - elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "Rotate": - img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) - elif op_name == "Brightness": - img = F.adjust_brightness(img, 1.0 + magnitude) - elif op_name == "Color": - img = F.adjust_saturation(img, 1.0 + magnitude) - elif op_name == "Contrast": - img = F.adjust_contrast(img, 1.0 + magnitude) - elif op_name == "Sharpness": - img = F.adjust_sharpness(img, 1.0 + magnitude) - elif op_name == "Posterize": - img = F.posterize(img, int(magnitude)) - elif op_name == "Solarize": - img = F.solarize(img, magnitude) - elif op_name == "AutoContrast": - img = F.autocontrast(img) - elif op_name == "Equalize": - img = F.equalize(img) - elif op_name == "Invert": - img = F.invert(img) - elif op_name == "Identity": - pass - else: - raise ValueError("The provided operator {} is not recognized.".format(op_name)) - return img - - -class AutoAugmentPolicy(Enum): - """AutoAugment policies learned on different datasets. - Available policies are IMAGENET, CIFAR10 and SVHN. - """ - IMAGENET = "imagenet" - CIFAR10 = "cifar10" - SVHN = "svhn" - - -# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class -class AutoAugment(torch.nn.Module): - r"""AutoAugment data augmentation method based on - `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - policy (AutoAugmentPolicy): Desired policy enum defined by - :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__( - self, - policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None - ) -> None: - super().__init__() - self.policy = policy - self.interpolation = interpolation - self.fill = fill - self.subpolicies = self._get_subpolicies(policy) - - def _get_subpolicies( - self, - policy: AutoAugmentPolicy - ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: - if policy == AutoAugmentPolicy.IMAGENET: - return [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ] - elif policy == AutoAugmentPolicy.CIFAR10: - return [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ] - elif policy == AutoAugmentPolicy.SVHN: - return [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ] - else: - raise ValueError("The provided policy {} is not recognized.".format(policy)) - - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), - "Color": (torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - "Invert": (torch.tensor(0.0), False), - } - - @staticmethod - def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: - """Get parameters for autoaugment transformation - - Returns: - params required by the autoaugment transformation - """ - policy_id = int(torch.randint(transform_num, (1,)).item()) - probs = torch.rand((2,)) - signs = torch.randint(2, (2,)) - - return policy_id, probs, signs - - def forward(self, img: Tensor, dis_mag = True) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: AutoAugmented image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - transform_id, probs, signs = self.get_params(len(self.subpolicies)) - - for i, (op_name, p, magnitude) in enumerate(self.subpolicies): - img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - - return img - - def __repr__(self) -> str: - return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) - - -class RandAugment(torch.nn.Module): - r"""RandAugment data augmentation method based on - `"RandAugment: Practical automated data augmentation with a reduced search space" - <https://arxiv.org/abs/1909.13719>`_. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - num_ops (int): Number of augmentation transformations to apply sequentially. - magnitude (int): Magnitude for all the transformations. - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: - super().__init__() - self.num_ops = num_ops - self.magnitude = magnitude - self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill - - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "Identity": (torch.tensor(0.0), False), - "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), - "Color": (torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - } - - def forward(self, img: Tensor) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: Transformed image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - for _ in range(self.num_ops): - op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) - op_index = int(torch.randint(len(op_meta), (1,)).item()) - op_name = list(op_meta.keys())[op_index] - magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 - if signed and torch.randint(2, (1,)): - magnitude *= -1.0 - img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - return img - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_ops={num_ops}' - s += ', magnitude={magnitude}' - s += ', num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' - return s.format(**self.__dict__) - - -class TrivialAugmentWide(torch.nn.Module): - r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in - `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: - super().__init__() - self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill - - def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "Identity": (torch.tensor(0.0), False), - "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), - "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), - "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), - "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), - "Color": (torch.linspace(0.0, 0.99, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - } - - def forward(self, img: Tensor) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: Transformed image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - op_meta = self._augmentation_space(self.num_magnitude_bins) - op_index = int(torch.randint(len(op_meta), (1,)).item()) - op_name = list(op_meta.keys())[op_index] - magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ - if magnitudes.ndim > 0 else 0.0 - if signed and torch.randint(2, (1,)): - magnitude *= -1.0 - - return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' - return s.format(**self.__dict__) - -# HEREHEREHEREHERE1 - - - - - - - - -# train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, -# transform=None) -# test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, -# transform=torchvision.transforms.ToTensor()) - - -# auto_aug_agent = Learner() -# ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1, train_dataset=train_dataset, test_dataset=test_dataset) -# ev_learner.run_instance() - - -# solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution() - -# print(f"Best solution : {solution}") -# print(f"Fitness value of the best solution = {solution_fitness}") -# print(f"Index of the best solution : {solution_idx}") -# # Fetch the parameters of the best solution. -# best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent, -# weights_vector=solution) diff --git a/MetaAugment/Evo_learner.py b/MetaAugment/Evo_learner.py deleted file mode 100644 index d111a9a8b5e726734fe9fb971402ed5cc8f28eef..0000000000000000000000000000000000000000 --- a/MetaAugment/Evo_learner.py +++ /dev/null @@ -1,868 +0,0 @@ -from cgi import test -import numpy as np -import torch -torch.manual_seed(0) -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torchvision -import torchvision.datasets as datasets -from torchvision import transforms -import torchvision.transforms.autoaugment as autoaugment -import random -import pygad -import pygad.torchga as torchga -import random -import copy -from torchvision.transforms import functional as F, InterpolationMode -from typing import List, Tuple, Optional, Dict -import heapq -import math -import torch -from flask import current_app - -from enum import Enum -from torch import Tensor - - - - -class Learner(nn.Module): - def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5): - self.fun_num = fun_num - self.p_bins = p_bins - self.m_bins = m_bins - self.sub_num_pol = sub_num_pol - - super().__init__() - self.conv1 = nn.Conv2d(1, 6, 5) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(2) - self.fc1 = nn.Linear(256, 120) - self.relu3 = nn.ReLU() - self.fc2 = nn.Linear(120, 84) - self.relu4 = nn.ReLU() - self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins)) - - def forward(self, x): - x = x[:, 0:1, :, :] - y = self.conv1(x) - y = self.relu1(y) - y = self.pool1(y) - y = self.conv2(y) - y = self.relu2(y) - y = self.pool2(y) - y = y.view(y.shape[0], -1) - y = self.fc1(y) - y = self.relu3(y) - y = self.fc2(y) - y = self.relu4(y) - y = self.fc3(y) - - return y - -class LeNet(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(784, 2048) - self.relu1 = nn.ReLU() - self.fc2 = nn.Linear(2048, 10) - self.relu2 = nn.ReLU() - - def forward(self, x): - x = x.reshape((-1, 784)) - y = self.fc1(x) - y = self.relu1(y) - y = self.fc2(y) - y = self.relu2(y) - return y - - -class Evolutionary_learner(): - - def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, batch_size=32, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, exclude_method=[], augmentation_space = None, ds=None, ds_name=None): - self.auto_aug_agent = network - self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions) - - self.num_generations = num_generations - self.num_parents_mating = num_parents_mating - self.initial_population = self.torch_ga.population_weights - self.child_network = child_network - self.p_bins = p_bins - self.sub_num_pol = sub_num_pol - self.mag_bins = mag_bins - self.fun_num = fun_num - self.iter_count = 0 - - full_augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] - self.augmentation_space = [x for x in full_augmentation_space if x[0] not in exclude_method] - assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' - - - transform = torchvision.transforms.Compose([ - torchvision.transforms.CenterCrop(28), - torchvision.transforms.ToTensor()]) - - if ds == "MNIST": - self.train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=True, transform=transform) - self.test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False, download=True, transform=transform) - elif ds == "KMNIST": - self.train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=True, transform=transform) - self.test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False, download=True, transform=transform) - elif ds == "FashionMNIST": - self.train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - self.test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) - elif ds == "CIFAR10": - self.train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - self.test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) - elif ds == "CIFAR100": - self.train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - self.test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) - elif ds == 'Other': - dataset = datasets.ImageFolder('./MetaAugment/datasets/upload_dataset/'+ ds_name, transform=transform) - len_train = int(0.8*len(dataset)) - self.train_dataset, self.test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train]) - - self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size) - - - self.set_up_instance() - - - def get_full_policy(self, x): - """ - Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires - output size 5 * 2 * (self.fun_num + self.p_bins + self.mag_bins) - - Parameters - ----------- - x -> PyTorch tensor - Input data for network - - Returns - ---------- - full_policy -> [((String, float, float), (String, float, float)), ...) - Full policy consisting of tuples of subpolicies. Each subpolicy consisting of - two transformations, with a probability and magnitude float for each - """ - section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins - y = self.auto_aug_agent.forward(x) - full_policy = [] - for pol in range(self.sub_num_pol): - int_pol = [] - for _ in range(2): - idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) - - trans, need_mag = self.augmentation_space[idx_ret] - - p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) - mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None - int_pol.append((trans, p_ret, mag)) - - full_policy.append(tuple(int_pol)) - - return full_policy - - - def get_policy_cov(self, x, alpha = 0.5): - """ - Selects policy using population and covariance matrices. For this method - we require p_bins = 1, num_sub_pol = 1, mag_bins = 1. - - Parameters - ------------ - x -> PyTorch Tensor - Input data for the AutoAugment network - - alpha -> Float - Proportion for covariance and population matrices - - Returns - ----------- - Subpolicy -> [(String, float, float), (String, float, float)] - Subpolicy consisting of two tuples of policies, each with a string associated - to a transformation, a float for a probability, and a float for a magnittude - """ - section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins - - y = self.auto_aug_agent.forward(x) # 1000 x 32 - - y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14 - y[:,:self.auto_aug_agent.fun_num] = y_1 - y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1) - y[:,section:section+self.auto_aug_agent.fun_num] = y_2 - concat = torch.cat((y_1, y_2), dim = 1) - - cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] - cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:] - shape_store = cov_mat.shape - - counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0) - - - prob_mat = torch.zeros(shape_store) - for idx in range(y.shape[0]): - prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1 - prob_mat = prob_mat / torch.sum(prob_mat) - - cov_mat = (alpha * cov_mat) + ((1 - alpha)*prob_mat) - - cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze() - max_idx = torch.argmax(cov_mat) - val = (max_idx//shape_store[0]) - max_idx = (val, max_idx - (val * shape_store[0])) - - - if not self.augmentation_space[max_idx[0]][1]: - mag1 = None - if not self.augmentation_space[max_idx[1]][1]: - mag2 = None - - for idx in range(y.shape[0]): - if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]): - prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item() - prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item() - if mag1 is not None: - mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8) - if mag2 is not None: - mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8) - counter += 1 - - prob1 = prob1/counter if counter != 0 else 0 - prob2 = prob2/counter if counter != 0 else 0 - if mag1 is not None: - mag1 = mag1/counter - if mag2 is not None: - mag2 = mag2/counter - - - return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)] - - - def run_instance(self, return_weights = False): - """ - Runs the GA instance and returns the model weights as a dictionary - - Parameters - ------------ - return_weights -> Bool - Determines if the weight of the GA network should be returned - - Returns - ------------ - If return_weights: - Network weights -> Dictionary - - Else: - Solution -> Best GA instance solution - - Solution fitness -> Float - - Solution_idx -> Int - """ - self.ga_instance.run() - solution, solution_fitness, solution_idx = self.ga_instance.best_solution() - if return_weights: - return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution) - else: - return solution, solution_fitness, solution_idx - - - def set_up_instance(self): - """ - Initialises GA instance, as well as fitness and on_generation functions - - """ - - def fitness_func(solution, sol_idx): - """ - Defines the fitness function for the parent selection - - Parameters - -------------- - solution -> GA solution instance (parsed automatically) - - sol_idx -> GA solution index (parsed automatically) - - Returns - -------------- - fit_val -> float - """ - - model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent, - weights_vector=solution) - - self.auto_aug_agent.load_state_dict(model_weights_dict) - - for idx, (test_x, label_x) in enumerate(self.train_loader): - full_policy = self.get_policy_cov(test_x) - - fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0])/ - + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0]) / 2 - - self.iter_count += 1 - current_app.config['iteration'] = self.iter_count - - return fit_val - - def on_generation(ga_instance): - """ - Prints information of generational fitness - - Parameters - ------------- - ga_instance -> GA instance - - Returns - ------------- - None - """ - print("Generation = {generation}".format(generation=ga_instance.generations_completed)) - print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1])) - return - - - self.ga_instance = pygad.GA(num_generations=self.num_generations, - num_parents_mating=self.num_parents_mating, - initial_population=self.initial_population, - mutation_percent_genes = 0.1, - fitness_func=fitness_func, - on_generation = on_generation) - - - - -# HEREHEREHERE0 - -def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100): - # shuffle and take first n_samples %age of training dataset - shuffle_order_train = np.random.RandomState(seed=seed).permutation(len(train_dataset)) - shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train) - - indices_train = torch.arange(int(n_samples*len(train_dataset))) - reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) - - # shuffle and take first n_samples %age of test dataset - shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset)) - shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) - - big = 4 # how much bigger is the test set - - indices_test = torch.arange(int(n_samples*len(test_dataset)*big)) - reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test) - - # push into DataLoader - train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) - - return train_loader, test_loader - - -def train_child_network(child_network, train_loader, test_loader, sgd, - cost, max_epochs=2000, early_stop_num = 5, logging=False, - print_every_epoch=True): - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - child_network = child_network.to(device=device) - - best_acc=0 - early_stop_cnt = 0 - - acc_log = [] - - for _epoch in range(max_epochs): - - child_network.train() - for idx, (train_x, train_label) in enumerate(train_loader): - train_x = train_x.to(device=device, dtype=train_x.dtype) - train_label = train_label.to(device=device, dtype=train_label.dtype) - sgd.zero_grad() - predict_y = child_network(train_x.float()) - loss = cost(predict_y, train_label.long()) - loss.backward() - sgd.step() - - correct = 0 - _sum = 0 - child_network.eval() - print("here0") - with torch.no_grad(): - print("here1") - print("len test_loader: ", len(test_loader)) - for idx, (test_x, test_label) in enumerate(test_loader): - print("here2") - test_x = test_x.to(device=device, dtype=test_x.dtype) - test_label = test_label.to(device=device, dtype=test_label.dtype) - - predict_y = child_network(test_x.float()).detach() - predict_ys = torch.argmax(predict_y, axis=-1) - - _ = predict_ys == test_label - correct += torch.sum(_, axis=-1) - - _sum += _.shape[0] - print("SUM: ", _sum) - - acc = correct / _sum - - if acc > best_acc : - best_acc = acc - early_stop_cnt = 0 - else: - early_stop_cnt += 1 - - if early_stop_cnt >= early_stop_num: - print('main.train_child_network best accuracy: ', best_acc) - break - - acc_log.append(acc) - - if logging: - return best_acc.item(), acc_log - return best_acc.item() - - -def test_autoaugment_policy(subpolicies, train_dataset, test_dataset, train_network): - - aa_transform = AutoAugment() - aa_transform.subpolicies = subpolicies - - train_transform = transforms.Compose([ - aa_transform, - transforms.ToTensor() - ]) - - train_dataset.transform = train_transform - - # create toy dataset from above uploaded data - train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1) - - child_network = train_network - sgd = optim.SGD(child_network.parameters(), lr=1e-1) - cost = nn.CrossEntropyLoss() - - best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, - sgd, cost, max_epochs=100, logging=True) - - return best_acc, acc_log - - - -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] - - -def _apply_op(img: Tensor, op_name: str, magnitude: float, - interpolation: InterpolationMode, fill: Optional[List[float]]): - if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, fill=fill) - elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, fill=fill) - elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "Rotate": - img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) - elif op_name == "Brightness": - img = F.adjust_brightness(img, 1.0 + magnitude) - elif op_name == "Color": - img = F.adjust_saturation(img, 1.0 + magnitude) - elif op_name == "Contrast": - img = F.adjust_contrast(img, 1.0 + magnitude) - elif op_name == "Sharpness": - img = F.adjust_sharpness(img, 1.0 + magnitude) - elif op_name == "Posterize": - img = F.posterize(img, int(magnitude)) - elif op_name == "Solarize": - img = F.solarize(img, magnitude) - elif op_name == "AutoContrast": - img = F.autocontrast(img) - elif op_name == "Equalize": - img = F.equalize(img) - elif op_name == "Invert": - img = F.invert(img) - elif op_name == "Identity": - pass - else: - raise ValueError("The provided operator {} is not recognized.".format(op_name)) - return img - - -class AutoAugmentPolicy(Enum): - """AutoAugment policies learned on different datasets. - Available policies are IMAGENET, CIFAR10 and SVHN. - """ - IMAGENET = "imagenet" - CIFAR10 = "cifar10" - SVHN = "svhn" - - -# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class -class AutoAugment(torch.nn.Module): - r"""AutoAugment data augmentation method based on - `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - policy (AutoAugmentPolicy): Desired policy enum defined by - :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__( - self, - policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None - ) -> None: - super().__init__() - self.policy = policy - self.interpolation = interpolation - self.fill = fill - self.subpolicies = self._get_subpolicies(policy) - - def _get_subpolicies( - self, - policy: AutoAugmentPolicy - ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: - if policy == AutoAugmentPolicy.IMAGENET: - return [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ] - elif policy == AutoAugmentPolicy.CIFAR10: - return [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ] - elif policy == AutoAugmentPolicy.SVHN: - return [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ] - else: - raise ValueError("The provided policy {} is not recognized.".format(policy)) - - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), - "Color": (torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - "Invert": (torch.tensor(0.0), False), - } - - @staticmethod - def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: - """Get parameters for autoaugment transformation - - Returns: - params required by the autoaugment transformation - """ - policy_id = int(torch.randint(transform_num, (1,)).item()) - probs = torch.rand((2,)) - signs = torch.randint(2, (2,)) - - return policy_id, probs, signs - - def forward(self, img: Tensor, dis_mag = True) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: AutoAugmented image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - transform_id, probs, signs = self.get_params(len(self.subpolicies)) - - for i, (op_name, p, magnitude) in enumerate(self.subpolicies): - img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - return img - - def __repr__(self) -> str: - return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) - - -class RandAugment(torch.nn.Module): - r"""RandAugment data augmentation method based on - `"RandAugment: Practical automated data augmentation with a reduced search space" - <https://arxiv.org/abs/1909.13719>`_. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - num_ops (int): Number of augmentation transformations to apply sequentially. - magnitude (int): Magnitude for all the transformations. - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: - super().__init__() - self.num_ops = num_ops - self.magnitude = magnitude - self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill - - def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "Identity": (torch.tensor(0.0), False), - "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), - "Color": (torch.linspace(0.0, 0.9, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - } - - def forward(self, img: Tensor) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: Transformed image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - for _ in range(self.num_ops): - op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) - op_index = int(torch.randint(len(op_meta), (1,)).item()) - op_name = list(op_meta.keys())[op_index] - magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 - if signed and torch.randint(2, (1,)): - magnitude *= -1.0 - img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - return img - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_ops={num_ops}' - s += ', magnitude={magnitude}' - s += ', num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' - return s.format(**self.__dict__) - - -class TrivialAugmentWide(torch.nn.Module): - r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in - `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`. - If the image is torch Tensor, it should be of type torch.uint8, and it is expected - to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. - If img is PIL Image, it is expected to be in mode "L" or "RGB". - - Args: - num_magnitude_bins (int): The number of different magnitude values. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - fill (sequence or number, optional): Pixel fill value for the area outside the transformed - image. If given a number, the value is used for all bands respectively. - """ - - def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: - super().__init__() - self.num_magnitude_bins = num_magnitude_bins - self.interpolation = interpolation - self.fill = fill - - def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: - return { - # op_name: (magnitudes, signed) - "Identity": (torch.tensor(0.0), False), - "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), - "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), - "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), - "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), - "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), - "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), - "Color": (torch.linspace(0.0, 0.99, num_bins), True), - "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), - "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), - "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), - "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (torch.tensor(0.0), False), - "Equalize": (torch.tensor(0.0), False), - } - - def forward(self, img: Tensor) -> Tensor: - """ - img (PIL Image or Tensor): Image to be transformed. - - Returns: - PIL Image or Tensor: Transformed image. - """ - fill = self.fill - if isinstance(img, Tensor): - if isinstance(fill, (int, float)): - fill = [float(fill)] * F.get_image_num_channels(img) - elif fill is not None: - fill = [float(f) for f in fill] - - op_meta = self._augmentation_space(self.num_magnitude_bins) - op_index = int(torch.randint(len(op_meta), (1,)).item()) - op_name = list(op_meta.keys())[op_index] - magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ - if magnitudes.ndim > 0 else 0.0 - if signed and torch.randint(2, (1,)): - magnitude *= -1.0 - - return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) - - def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' - return s.format(**self.__dict__) - - - - - - - diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 92347c767098d033eb0afeecd2f640a1f016e142..1ff576b28a4d39f367550bb2fa15168e0a9b99c8 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -6,7 +6,7 @@ import pygad import pygad.torchga as torchga import copy import torch -from MetaAugment.controller_networks.evo_controller import Evo_learner +from MetaAugment.controller_networks.evo_controller import evo_controller from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space import MetaAugment.child_networks as cn @@ -46,7 +46,7 @@ class evo_learner(): early_stop_num=early_stop_num,) self.num_solutions = num_solutions - self.auto_aug_agent = Evo_learner(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num) + self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num) self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions) self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights @@ -246,6 +246,7 @@ class evo_learner(): else: full_policy = self.get_full_policy(test_x) +# Checkpoint -> save learner as a pickle fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2 diff --git a/MetaAugment/controller_networks/evo_controller.py b/MetaAugment/controller_networks/evo_controller.py index aa7a1da2b6bf0bd3789118825c5c798b3aa32387..55dafc05d76819f148fb0e158d1be36bcffe311d 100644 --- a/MetaAugment/controller_networks/evo_controller.py +++ b/MetaAugment/controller_networks/evo_controller.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import math -class Evo_learner(nn.Module): +class evo_controller(nn.Module): def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5): self.fun_num = fun_num self.p_bins = p_bins diff --git a/backend_react/react_app.py b/backend_react/react_app.py index ab6be331273b4664629d6842a5f32e2c30de0f27..0cb3d89fc767984580cba0fe2ef47cc43daf3aaa 100644 --- a/backend_react/react_app.py +++ b/backend_react/react_app.py @@ -24,8 +24,10 @@ import sys sys.path.insert(0, os.path.abspath('..')) # import agents and its functions -from MetaAugment import UCB1_JC_py as UCB1_JC -from MetaAugment import Evo_learner as Evo +from ..MetaAugment import UCB1_JC_py as UCB1_JC +from ..MetaAugment.autoaugment_learners import evo_learner +import MetaAugment.controller_networks as cn +import MetaAugment.autoaugment_learners as aal print('@@@ import successful') # import agents and its functions @@ -181,9 +183,9 @@ def training(): best_q_values = np.array(best_q_values) elif auto_aug_learner == 'Evolutionary Learner': - network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) - child_network = Evo.LeNet() - learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) + network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) + child_network = aal.evo.LeNet() + learner = aal.evo.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) learner.run_instance() elif auto_aug_learner == 'Random Searcher': pass diff --git a/benchmark/pickles/04_22_cf_ln_gru.pkl b/benchmark/pickles/04_22_cf_ln_gru.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ad21c6df0caf681f54acac11cfd51e58a70acab1 Binary files /dev/null and b/benchmark/pickles/04_22_cf_ln_gru.pkl differ diff --git a/benchmark/pickles/04_22_cf_ln_rs.pkl b/benchmark/pickles/04_22_cf_ln_rs.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c25ff45ad7003d1fe589fc22525cb16b7df1b644 Binary files /dev/null and b/benchmark/pickles/04_22_cf_ln_rs.pkl differ diff --git a/benchmark/pickles/04_22_fm_sn_gru.pkl b/benchmark/pickles/04_22_fm_sn_gru.pkl new file mode 100644 index 0000000000000000000000000000000000000000..d64164edb915c0b611f402d56ce482d01fa82c9b Binary files /dev/null and b/benchmark/pickles/04_22_fm_sn_gru.pkl differ diff --git a/benchmark/pickles/04_22_fm_sn_rs.pkl b/benchmark/pickles/04_22_fm_sn_rs.pkl new file mode 100644 index 0000000000000000000000000000000000000000..452e2fb5458f750f6cc475e2f20c40ff138ad742 Binary files /dev/null and b/benchmark/pickles/04_22_fm_sn_rs.pkl differ diff --git a/benchmark/scripts/04_22_ci_gru.py b/benchmark/scripts/04_22_ci_gru.py index 5c4db6bd5ce8f713f8dbc6ea829176454e4ce28a..1a5b0fadca16fbaabe583b3a58abf0ae9d87c4db 100644 --- a/benchmark/scripts/04_22_ci_gru.py +++ b/benchmark/scripts/04_22_ci_gru.py @@ -35,12 +35,23 @@ child_network_architecture = cn.LeNet( ) -# gru +save_dir='./benchmark/pickles/04_22_cf_ln_gru' + +# rs run_benchmark( - save_file='./benchmark/pickles/04_22_cf_ln_gru', + save_file=save_dir+'.pkl', train_dataset=train_dataset, test_dataset=test_dataset, child_network_architecture=child_network_architecture, agent_arch=aal.gru_learner, config=config, + ) + +rerun_best_policy( + agent_pickle=save_dir+'.pkl', + accs_txt=save_dir+'.txt', + train_dataset=train_dataset, + test_dataset=test_dataset, + child_network_architecture=child_network_architecture, + repeat_num=5 ) \ No newline at end of file diff --git a/benchmark/scripts/04_22_ci_rs.py b/benchmark/scripts/04_22_ci_rs.py index b98c25fb7826918cfa4f1e6cdb5dc1484cd9c662..21f3a9a3e65eb8e7126ec32641c48726b2f4172c 100644 --- a/benchmark/scripts/04_22_ci_rs.py +++ b/benchmark/scripts/04_22_ci_rs.py @@ -34,13 +34,23 @@ child_network_architecture = cn.LeNet( img_channels=3 ) +save_dir='./benchmark/pickles/04_22_cf_ln_rs' # rs run_benchmark( - save_file='./benchmark/pickles/04_22_cf_ln_rs', + save_file=save_dir+'.pkl', train_dataset=train_dataset, test_dataset=test_dataset, child_network_architecture=child_network_architecture, agent_arch=aal.randomsearch_learner, config=config, ) + +rerun_best_policy( + agent_pickle=save_dir+'.pkl', + accs_txt=save_dir+'.txt', + train_dataset=train_dataset, + test_dataset=test_dataset, + child_network_architecture=child_network_architecture, + repeat_num=5 + ) \ No newline at end of file diff --git a/benchmark/scripts/04_22_fm_gru.py b/benchmark/scripts/04_22_fm_gru.py index 227918517fef504b8e5b27aaab354c3c2366e1c8..b3a951c0afd3eeb0cd8911a30143f08a61c6e5e4 100644 --- a/benchmark/scripts/04_22_fm_gru.py +++ b/benchmark/scripts/04_22_fm_gru.py @@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', child_network_architecture = cn.SimpleNet -# gru +save_dir='./benchmark/pickles/04_22_fm_sn_gru' + +# rs run_benchmark( - save_file='./benchmark/pickles/04_22_fm_sn_gru.pkl', + save_file=save_dir+'.pkl', train_dataset=train_dataset, test_dataset=test_dataset, child_network_architecture=child_network_architecture, agent_arch=aal.gru_learner, config=config, + ) + +rerun_best_policy( + agent_pickle=save_dir+'.pkl', + accs_txt=save_dir+'.txt', + train_dataset=train_dataset, + test_dataset=test_dataset, + child_network_architecture=child_network_architecture, + repeat_num=5 ) \ No newline at end of file diff --git a/benchmark/scripts/04_22_fm_rs.py b/benchmark/scripts/04_22_fm_rs.py index 33a4b7b26cd4211bdc7b241c0370fc5b5f1abe59..0589630f3906fedfeca72326b1d77fdf9332d5b9 100644 --- a/benchmark/scripts/04_22_fm_rs.py +++ b/benchmark/scripts/04_22_fm_rs.py @@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', child_network_architecture = cn.SimpleNet +save_dir='./benchmark/pickles/04_22_fm_sn_rs' + # rs run_benchmark( - save_file='./benchmark/pickles/04_22_fm_sn_rs.pkl', + save_file=save_dir+'.pkl', train_dataset=train_dataset, test_dataset=test_dataset, child_network_architecture=child_network_architecture, agent_arch=aal.randomsearch_learner, config=config, + ) + +rerun_best_policy( + agent_pickle=save_dir+'.pkl', + accs_txt=save_dir+'.txt', + train_dataset=train_dataset, + test_dataset=test_dataset, + child_network_architecture=child_network_architecture, + repeat_num=5 ) \ No newline at end of file diff --git a/benchmark/scripts/util_04_22.py b/benchmark/scripts/util_04_22.py index 8e39fdd69cef06f52e48d92136e8f617b85dfaf8..86b033ef65efa96782e809136f2793ebaad6b044 100644 --- a/benchmark/scripts/util_04_22.py +++ b/benchmark/scripts/util_04_22.py @@ -1,3 +1,4 @@ +from matplotlib.pyplot import get import torchvision.datasets as datasets import torchvision import torch @@ -5,7 +6,7 @@ import torch import MetaAugment.child_networks as cn import MetaAugment.autoaugment_learners as aal - +from pprint import pprint """ testing gru_learner and randomsearch_learner on @@ -56,4 +57,61 @@ def run_benchmark( with open(save_file, 'wb+') as f: torch.save(agent, f) - print('run_benchmark closing') \ No newline at end of file + print('run_benchmark closing') + + +def get_mega_policy(history, n): + """ + we get the best n policies from an agent's history, + concatenate them to form our best mega policy + + Args: + history (list[tuple]) + n (int) + + Returns: + list[float]: validation accuracies + """ + assert len(history) >= n + + # agent.history is a list of (policy(list), val_accuracy(float)) tuples + sorted_history = sorted(history, key=lambda x:x[1]) # sort wrt acc + + best_history = sorted_history[:n] + + megapolicy = [] + for policy,acc in best_history: + for subpolicy in policy: + megapolicy.append(subpolicy) + + return megapolicy + + +def rerun_best_policy( + agent_pickle, + accs_txt, + train_dataset, + test_dataset, + child_network_architecture, + repeat_num + ): + + with open(agent_pickle, 'rb') as f: + agent = torch.load(f, map_location=device) + + megapol = get_mega_policy(agent.history) + print('mega policy to be tested:') + pprint(megapol) + + accs=[] + for _ in range(repeat_num): + print(f'{_}/{repeat_num}') + accs.append( + agent.test_autoaugment_policy(megapol, + child_network_architecture, + train_dataset, + test_dataset, + logging=False) + ) + with open(accs_txt, 'w') as f: + f.write(str(accs)) diff --git a/setupProxy.js b/setupProxy.js new file mode 100644 index 0000000000000000000000000000000000000000..0b021257aca377b78503c40a9ccfb0a95197b16b --- /dev/null +++ b/setupProxy.js @@ -0,0 +1,11 @@ +const { createProxyMiddleware } = require('http-proxy-middleware'); + +module.exports = function(app) { + app.use( + '/api', + createProxyMiddleware({ + target: 'http://localhost:3000', + changeOrigin: true, + }) + ); +}; \ No newline at end of file