    import torch
    import torch.nn as nn
    import torch.optim as optim
    from MetaAugment.main import train_child_network, create_toy
    from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
    
    import torchvision.transforms as transforms
    
    
    from pprint import pprint
    
    import matplotlib.pyplot as plt
    
    
    
    # We will use this augmentation_space temporarily. Later on, we will need to
    # make sure users can add other image functions if they want.
    augmentation_space = [
                # (function_name, do_we_need_to_specify_magnitude)
                ("ShearX", True),
                ("ShearY", True),
                ("TranslateX", True),
                ("TranslateY", True),
                ("Rotate", True),
                ("Brightness", True),
                ("Color", True),
                ("Contrast", True),
                ("Sharpness", True),
                ("Posterize", True),
                ("Solarize", True),
                ("AutoContrast", False),
                ("Equalize", False),
                ("Invert", False),
            ]
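
    # For reference, an "operation" drawn from this search space is a tuple of the form
    # (function_name, probability, magnitude), e.g. ("Rotate", 0.7, 2) or, for functions
    # which take no magnitude (such as "Invert"), ("Invert", 0.8, None). A minimal,
    # illustrative lookup (not part of the class) might look like:
    #
    #   fun_idx = 4                                     # index into augmentation_space
    #   name, needs_magnitude = augmentation_space[fun_idx]
    #   operation = (name, 0.7, 2 if needs_magnitude else None)   # ("Rotate", 0.7, 2)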
    
    
    
    class aa_learner:
    
        """
        The parent class for all aa_learner subclasses.
        """
    
        def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
            """
            Args:
                sp_num (int): number of subpolicies per policy
                fun_num (int): number of image functions in our search space
                p_bins (int): number of bins into which we divide the probability interval [0, 1]
                m_bins (int): number of bins into which we divide the magnitude space

                discrete_p_m (boolean): whether the agent should represent the probability and
                                        magnitude as discrete variables in the output of the
                                        controller (a controller can be a neural network, a
                                        genetic algorithm, etc.)
            """

            self.sp_num = sp_num
    
            self.fun_num = fun_num
            self.p_bins = p_bins
            self.m_bins = m_bins
    
            self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
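
            # e.g. with the default arguments (fun_num=14, p_bins=11, m_bins=10):
            #   discrete_p_m=True  -> op_tensor_length = 14 + 11 + 10 = 35
            #   discrete_p_m=False -> op_tensor_length = 14 + 2       = 16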
    
    
            # whether to represent the probability and magnitude as discrete variables
            self.discrete_p_m = discrete_p_m
    
    
    
            # TODO: We should probably use a different way to store results than self.history
            self.history = []
    
    
        def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
            """
            Takes in a tensor representing an operation and returns an actual operation which
            is in the form of:
                ("Invert", 0.8, None)
                or
                ("Contrast", 0.2, 6)
    
            Args:
    
                operation_tensor (tensor): 
                                    We expect this tensor to already have been softmaxed.
                                    Furthermore,
                                    - If self.discrete_p_m is True, we expect to take in a tensor with
                                      dimension (self.fun_num + self.p_bins + self.m_bins)
                                    - If self.discrete_p_m is False, we expect to take in a tensor with
                                      dimension (self.fun_num + 1 + 1)
    
                return_log_prob (boolean): 
                                    When this is on, we return which indices (of fun, prob, mag) were
                                    chosen (either randomly or deterministically, depending on argmax).
                                    This is used, for example, in the gru_learner to calculate the
                                    probability that the actions were chosen, which is then logged and
                                    differentiated.
    
    
                argmax (boolean): 
                                Whether we are taking the argmax of the softmaxed tensors. 
                                If this is False, we treat the softmaxed outputs as multinomial pdf's.
    
    
            Returns:
                operation (tuple):
                                    An operation in the format that can be directly put into an
                                    AutoAugment object.
    
                log_prob (float):
                                Used in reinforcement learning updates, such as proximal policy update
                                in the gru_learner.
                                Can only be used when self.discrete_p_m is True.
                                We add the logs of the softmaxed values at the chosen indices of the
                                image function, probability, and magnitude.
                                This corresponds to multiplying the non-logged values and then
                                taking the log.
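
        Example:
            A minimal, illustrative call (assuming a learner constructed with
            discrete_p_m=True and the other default arguments, so the tensor has length
            14 + 11 + 10 = 35; `learner` is an assumed instance name):

            >>> op_tensor = torch.zeros(35)
            >>> op_tensor[4] = 1.0             # one-hot choice of "Rotate"
            >>> op_tensor[14 + 7] = 1.0        # probability bin 7  -> prob = 0.7
            >>> op_tensor[14 + 11 + 2] = 1.0   # magnitude bin 2    -> mag = 2
            >>> learner.translate_operation_tensor(op_tensor, argmax=True)
            ('Rotate', 0.7, 2)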
            """
    
    
            if (not self.discrete_p_m) and return_log_prob:
                raise ValueError("You are not supposed to use return_log_prob=True when the agent's \
                                self.discrete_p_m is False!")
    
    
            # make sure shape is correct
            assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
    
    
            # if probability and magnitude are represented as discrete variables
            if self.discrete_p_m:
    
                fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])
    
    
    
                # make sure they are of right size
                assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
                assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
                assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
    
    
                if argmax==True:
                    fun_idx = torch.argmax(fun_t).item()
                    prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
    
                    mag = torch.argmax(mag_t).item() # 0 <= m <= 9
    
                elif argmax==False:
                    # we need these to add up to 1 to be valid pdf's of multinomials
    
                    assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                    assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                    assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
    
    
                    fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
                    prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
    
                    mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
    
                function = augmentation_space[fun_idx][0]
                prob = prob_idx/10
    
                indices = (fun_idx, prob_idx, mag)
    
                # log probability is the sum of the log of the softmax values of the indices 
                # (of fun_t, prob_t, mag_t) that we have chosen
                log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])
    
    
    
            # if probability and magnitude are represented as continuous variables
            else:
    
                fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
    
                prob = prob.item()
    
                mag = mag.item()
    
                # make sure the shape is correct
                assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
                
                if argmax==True:
                    fun_idx = torch.argmax(fun_t).item()
                elif argmax==False:
                    # fun_t needs to add up to 1 to be a valid pdf of a multinomial
                    assert torch.sum(fun_t).isclose(torch.ones(1))
                    fun_idx = torch.multinomial(fun_t, 1).item()

                prob = round(prob, 1) # round to nearest first decimal digit
                mag = round(mag) # round to nearest integer
    
            function = augmentation_space[fun_idx][0]
    
            # if the image function does not require a magnitude, we set the magnitude to None
    
            if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
                operation = (function, prob, mag)
            else:
                operation = (function, prob, None)
            
            if return_log_prob:
                return operation, log_prob
            else:
                return operation
            
    
    
        def generate_new_policy(self):
    
            """
            Generate a new policy which can be fed into an AutoAugment object 
            by calling:
                AutoAugment.subpolicies = policy
            
            Args:
                none
            
            Returns:
                new_policy (list[tuple]):
                            A new policy generated by the controller. It
                            has the form of:
                                [
                                (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
                                (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
                                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                                (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
                                ]
                            This object can be fed into an AutoAugment object
                            by calling: AutoAugment.subpolicies = policy
            """
    
    
            raise NotImplementedError('generate_new_policy not implemented in aa_learner')
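
        # A concrete subclass is expected to override generate_new_policy(). As a rough,
        # illustrative sketch only (not the implementation of any particular learner in
        # this package), a random-search-style subclass using continuous mode
        # (discrete_p_m=False, the default) might do:
        #
        #   def generate_new_policy(self):
        #       new_policy = []
        #       for _ in range(self.sp_num):
        #           ops = []
        #           for _ in range(2):    # two operations per subpolicy
        #               fun_t = torch.softmax(torch.randn(self.fun_num), dim=0)
        #               prob = torch.rand(1)                      # probability in [0, 1]
        #               mag = torch.rand(1) * (self.m_bins - 1)   # magnitude in [0, m_bins-1]
        #               op_tensor = torch.cat([fun_t, prob, mag])
        #               ops.append(self.translate_operation_tensor(op_tensor))
        #           new_policy.append(tuple(ops))
        #       return new_policy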
    
    
    
        def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
    
            """
            Runs the main loop (of finding a good policy for the given child network,
            training dataset, and test(validation) dataset)
    
        Does the loop shown in Figure 1 of the AutoAugment paper, which is:
    
                1. <generate a random policy>
                2. <see how good that policy is>
                3. <save how good the policy is in a list/dictionary>
    
            
            Args:
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
            child_network_architecture (type): NOTE THAT THIS VARIABLE IS NOT
                                    an nn.Module object. Therefore, this needs
                                    to be, say, `models.LeNet` instead of 
                                    `models.LeNet()`.
                toy_flag (boolean): whether we want to obtain a toy version of 
                                train_dataset and test_dataset and use those.
    
            Returns:
                none
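
        Example:
            (illustrative only; `MyLearner` is a hypothetical concrete subclass of
            aa_learner, and `models.LeNet` is an assumed child network class)

            >>> learner = MyLearner()
            >>> learner.learn(train_dataset, test_dataset,
            ...               child_network_architecture=models.LeNet,
            ...               toy_flag=True)
            >>> best_policy, best_acc = max(learner.history, key=lambda x: x[1])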
            """
    
    
    
            # This is dummy code
    
            # test out 15 random policies
            for _ in range(15):
                policy = self.generate_new_policy()
    
    
    
                pprint(policy)
                child_network = child_network_architecture()
                reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
                                                    test_dataset, toy_flag)
    
    
    
                self.history.append((policy, reward))
        
    
        def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, 
                                    toy_flag, logging=False):
            """
            Given a policy (using AutoAugment paper terminology), we train a child network
    
            using the policy and return the accuracy (how good the policy is for the dataset and 
            child network).
    
    
            Args: 
                policy (list[tuple]): A list of tuples representing a policy.
            child_network (nn.Module)
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
                toy_flag (boolean): Whether we want to obtain a toy version of 
                                train_dataset and test_dataset and use those.
                logging (boolean): Whether we want to save logs
            
            Returns:
            accuracy (float): best accuracy reached in any epoch of training. If `logging`
                            is True, a tuple (accuracy, accuracy_log) is returned instead.
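
        Example:
            (illustrative only; `learner` is an assumed aa_learner instance and
            `models.LeNet` is an assumed child network class)

            >>> policy = [(("Invert", 0.8, None), ("Contrast", 0.2, 6))]
            >>> acc = learner.test_autoaugment_policy(policy, models.LeNet(),
            ...                                       train_dataset, test_dataset,
            ...                                       toy_flag=True)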
            """
    
    
            # We need to define an object aa_transform which takes in the image and 
            # transforms it with the policy (specified in its .policies attribute)
            # in its forward pass
            aa_transform = AutoAugment()
    
            aa_transform.subpolicies = policy
    
            train_transform = transforms.Compose([
                                                    aa_transform,
                                                    transforms.ToTensor()
                                                ])
    
            
    
            # We feed the transformation into the Dataset object
            train_dataset.transform = train_transform
    
            # create Dataloader objects out of the Dataset objects
            # (NOTE: the remaining create_toy arguments were cut off in this revision of the
            #  file; the keyword values below are assumed placeholders, not the originals)
            train_loader, test_loader = create_toy(train_dataset,
                                                    test_dataset,
                                                    batch_size=32,
                                                    n_samples=0.01,
                                                    seed=100)
    
            # train the child network with the dataloaders equipped with our specific policy
            accuracy = train_child_network(child_network, 
                                        train_loader, 
                                        test_loader, 
                                        sgd = optim.SGD(child_network.parameters(), lr=3e-1),
                                        # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                        cost = nn.CrossEntropyLoss(),
                                        # max_epochs is set very high; training is expected to
                                        # end via early stopping (early_stop_num)
                                        max_epochs = 3000000, 
                                        early_stop_num = 15, 
                                        logging = logging,
                                        print_every_epoch=True)
            
            # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
    
            return accuracy
        
    
        def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=5):
            """
            I made this to plot a couple of accuracy graphs to help manually tune my gradient 
            optimizer hyperparameters.
    
            Saves a plot of `n` training accuracy graphs overlapped.
            """
            
            acc_lists = []
    
            # This is dummy code
            # test out `n` random policies
            for _ in range(n):
                policy = self.generate_new_policy()
    
                pprint(policy)
                child_network = child_network_architecture()
                reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
                                                    test_dataset, toy_flag, logging=True)
    
                self.history.append((policy, reward))
                acc_lists.append(acc_list)
    
            for acc_list in acc_lists:
                plt.plot(acc_list)
            plt.title(f'I ran {n} random policies to see if there is any sign of '
                        'catastrophic failure during training. If any lines reach '
                        'significantly lower (>10%) accuracies, you might want to '
                        'tune the hyperparameters.')
            plt.xlabel('epoch')
            plt.ylabel('accuracy')
            plt.savefig('training_graphs_without_policies')
            plt.show()