import torch import torch.nn as nn import torch.optim as optim from MetaAugment.main import train_child_network, create_toy from MetaAugment.autoaugment_learners.autoaugment import AutoAugment import torchvision.transforms as transforms from pprint import pprint import matplotlib.pyplot as plt # We will use this augmentation_space temporarily. Later on we will need to # make sure we are able to add other image functions if the users want. augmentation_space = [ # (function_name, do_we_need_to_specify_magnitude) ("ShearX", True), ("ShearY", True), ("TranslateX", True), ("TranslateY", True), ("Rotate", True), ("Brightness", True), ("Color", True), ("Contrast", True), ("Sharpness", True), ("Posterize", True), ("Solarize", True), ("AutoContrast", False), ("Equalize", False), ("Invert", False), ] class aa_learner: """ The parent class for all aa_learner's """ def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False): """ Args: spdim (int): number of subpolicies per policy fun_num (int): number of image functions in our search space p_bins (int): number of bins we divide the interval [0,1] for probabilities m_bins (int): number of bins we divide the magnitude space discrete_p_m (boolean): Whether or not the agent should represent probability and magnitude as discrete variables as the out put of the controller (A controller can be a neural network, genetic algorithm, etc.) """ self.sp_num = sp_num self.fun_num = fun_num self.p_bins = p_bins self.m_bins = m_bins self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2 # should we repre self.discrete_p_m = discrete_p_m # TODO: We should probably use a different way to store results than self.history self.history = [] def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False): """ takes in a tensor representing an operation and returns an actual operation which is in the form of: ("Invert", 0.8, None) or ("Contrast", 0.2, 6) Args: operation_tensor (tensor): We expect this tensor to already have been softmaxed. Furthermore, - If self.discrete_p_m is True, we expect to take in a tensor with dimension (self.fun_num + self.p_bins + self.m_bins) - If self.discrete_p_m is False, we expect to take in a tensor with dimension (self.fun_num + 1 + 1) return_log_prob (boolesn): When this is on, we return which indices (of fun, prob, mag) were chosen (either randomly or deterministically, depending on argmax). This is used, for example, in the gru_learner to calculate the probability of the actions were chosen, which is then logged, then differentiated. argmax (boolean): Whether we are taking the argmax of the softmaxed tensors. If this is False, we treat the softmaxed outputs as multinomial pdf's. Returns: operation (list of tuples): An operation in the format that can be directly put into an AutoAugment object. log_prob (float): Used in reinforcement learning updates, such as proximal policy update in the gru_learner. Can only be used when self.discrete_p_m. We add the logged values of the indices of the image_function, probability, and magnitude chosen. This corresponds to multiplying the non-logged values, then logging it. """ if (not self.discrete_p_m) and return_log_prob: raise ValueError("You are not supposed to use return_log_prob=True when the agent's \ self.discrete_p_m is False!") # make sure shape is correct assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape # if probability and magnitude are represented as discrete variables if self.discrete_p_m: fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins]) # make sure they are of right size assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}' assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}' assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}' if argmax==True: fun_idx = torch.argmax(fun_t).item() prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10 mag = torch.argmax(mag_t).item() # 0 <= m <= 9 elif argmax==False: # we need these to add up to 1 to be valid pdf's of multinomials assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t) assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t) assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t) fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1 prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9 function = augmentation_space[fun_idx][0] prob = prob_idx/10 indices = (fun_idx, prob_idx, mag) # log probability is the sum of the log of the softmax values of the indices # (of fun_t, prob_t, mag_t) that we have chosen log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag]) # if probability and magnitude are represented as continuous variables else: fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1]) prob = prob.item() # 0 =< prob =< 1 mag = mag.item() # 0 =< mag =< 9 # make sure the shape is correct assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}' if argmax==True: fun_idx = torch.argmax(fun_t) elif argmax==False: assert torch.sum(fun_t).isclose(torch.ones(1)) fun_idx = torch.multinomial(fun_t, 1).item() prob = round(prob, 1) # round to nearest first decimal digit mag = round(mag) # round to nearest integer function = augmentation_space[fun_idx][0] assert 0 <= prob <= 1 assert 0 <= mag <= self.m_bins-1 # if the image function does not require a magnitude, we set the magnitude to None if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude operation = (function, prob, mag) else: operation = (function, prob, None) if return_log_prob: return operation, log_prob else: return operation def generate_new_policy(self): """ Generate a new policy which can be fed into an AutoAugment object by calling: AutoAugment.subpolicies = policy Args: none Returns: new_policy (list[tuple]): A new policy generated by the controller. It has the form of: [ (("Invert", 0.8, None), ("Contrast", 0.2, 6)), (("Rotate", 0.7, 2), ("Invert", 0.8, None)), (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("Invert", 0.7, None)), ] This object can be fed into an AutoAUgment object by calling: AutoAugment.subpolicies = policy """ raise NotImplementedError('generate_new_policy not implemented in aa_learner') def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag): """ Runs the main loop (of finding a good policy for the given child network, training dataset, and test(validation) dataset) Does the loop which is seen in Figure 1 in the AutoAugment paper which is: 1. <generate a random policy> 2. <see how good that policy is> 3. <save how good the policy is in a list/dictionary and (if applicable,) update the controller (e.g. RL agent)> Args: train_dataset (torchvision.dataset.vision.VisionDataset) test_dataset (torchvision.dataset.vision.VisionDataset) child_network_architecture (type): NOTE THAT THIS VARIABLE IS NOT A nn.module object. Therefore, this needs to be, say, `models.LeNet` instead of `models.LeNet()`. toy_flag (boolean): whether we want to obtain a toy version of train_dataset and test_dataset and use those. Returns: none """ # This is dummy code # test out 15 random policies # for _ in range(15): # policy = self.generate_new_policy() # pprint(policy) # child_network = child_network_architecture() # reward = self.test_autoaugment_policy(policy, child_network, train_dataset, # test_dataset, toy_flag) # self.history.append((policy, reward)) def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, toy_flag, logging=False): """ Given a policy (using AutoAugment paper terminology), we train a child network using the policy and return the accuracy (how good the policy is for the dataset and child network). Args: policy (list[tuple]): A list of tuples representing a policy. child_network (nn.module) train_dataset (torchvision.dataset.vision.VisionDataset) test_dataset (torchvision.dataset.vision.VisionDataset) toy_flag (boolean): Whether we want to obtain a toy version of train_dataset and test_dataset and use those. logging (boolean): Whether we want to save logs Returns: accuracy (float): best accuracy reached in any """ # We need to define an object aa_transform which takes in the image and # transforms it with the policy (specified in its .policies attribute) # in its forward pass aa_transform = AutoAugment() aa_transform.subpolicies = policy train_transform = transforms.Compose([ aa_transform, transforms.ToTensor() ]) # We feed the transformation into the Dataset object train_dataset.transform = train_transform # create Dataloader objects out of the Dataset objects train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.5, seed=100) # train the child network with the dataloaders equipped with our specific policy accuracy = train_child_network(child_network, train_loader, test_loader, sgd = optim.SGD(child_network.parameters(), lr=3e-1), # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2), cost = nn.CrossEntropyLoss(), max_epochs = 3000000, early_stop_num = 15, logging = logging, print_every_epoch=True) # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log) return accuracy def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=5): """ I made this to plot a couple of accuracy graphs to help manually tune my gradient optimizer hyperparameters. Saves a plot of `n` training accuracy graphs overlapped. """ acc_lists = [] # This is dummy code # test out `n` random policies for _ in range(n): policy = self.generate_new_policy() pprint(policy) child_network = child_network_architecture() reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset, test_dataset, toy_flag, logging=True) self.history.append((policy, reward)) acc_lists.append(acc_list) for acc_list in acc_lists: plt.plot(acc_list) plt.title('I ran 5 random policies to see if there is any sign of \ catastrophic failure during training. If there are \ any lines which reach significantly lower (>10%) \ accuracies, you might want to tune the hyperparameters') plt.xlabel('epoch') plt.ylabel('accuracy') plt.show() plt.savefig('training_graphs_without_policies')