Skip to content
Snippets Groups Projects
evo_learner.py 9.47 KiB
Newer Older
  • Learn to ignore specific revisions
  • import torch
    import torch.nn as nn
    import pygad
    import pygad.torchga as torchga
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    import torchvision
    
    import torch
    
    from MetaAugment.autoaugment_learners.aa_learner import aa_learner
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    import MetaAugment.controller_networks as cont_n
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    class evo_learner(aa_learner):
        """
        AutoAugment learner that searches for augmentation policies with a
        genetic algorithm (PyGAD) over the weights of a controller network.

        Each GA solution is a flattened weight vector for ``self.controller``.
        Its fitness is the score returned by ``_test_autoaugment_policy`` for
        the sub-policy the controller produces from training images.
        """

        def __init__(self, 

                    # search space settings
                    sp_num=5,
                    p_bins=11, 
                    # NOTE(review): this parameter was missing from this copy of the
                    # file although it is forwarded below; default 10 assumed to
                    # match aa_learner's default — confirm.
                    m_bins=10,
                    discrete_p_m=False,
                    exclude_method=None,

                    # child network settings
                    learning_rate=1e-1, 
                    max_epochs=float('inf'),
                    early_stop_num=20,
                    batch_size=8,
                    toy_size=1,

                    # evolutionary learner specific settings
                    num_solutions=5,
                    num_parents_mating=3,
                    controller=cont_n.evo_controller
                    ):
            """
            Parameters
            ------------
            sp_num -> Int
                Forwarded to aa_learner and to the controller as ``sub_num_pol``;
                also the number of best (policy, fitness) pairs kept in
                ``self.running_policy``.

            p_bins, m_bins, discrete_p_m, exclude_method -> search space settings
                Forwarded to aa_learner. ``exclude_method`` defaults to an empty list.

            learning_rate, max_epochs, early_stop_num, batch_size, toy_size
                Child network settings, forwarded to aa_learner.

            num_solutions -> Int
                GA population size (number of controller weight vectors).

            num_parents_mating -> Int
                Number of parents PyGAD selects per generation; must be smaller
                than ``num_solutions``.

            controller -> callable
                Called with ``fun_num``, ``p_bins``, ``m_bins`` and ``sub_num_pol``
                to build the controller network whose weights are evolved.

            Raises
            ------------
            ValueError
                If ``num_solutions <= num_parents_mating``.
            """
            # Validate before building any networks.
            if num_solutions <= num_parents_mating:
                raise ValueError('Number of solutions must be larger than the number of parents mating!')

            # Fresh list per instance (avoids a shared mutable default).
            if exclude_method is None:
                exclude_method = []

            super().__init__(
                        sp_num=sp_num, 
                        p_bins=p_bins, 
                        m_bins=m_bins, 
                        discrete_p_m=discrete_p_m, 
                        batch_size=batch_size, 
                        toy_size=toy_size, 
                        learning_rate=learning_rate,
                        max_epochs=max_epochs,
                        early_stop_num=early_stop_num,
                        exclude_method=exclude_method
                        )

            # evolutionary algorithm settings
            self.controller = controller(
                            fun_num=self.fun_num, 
                            p_bins=self.p_bins, 
                            m_bins=self.m_bins, 
                            sub_num_pol=self.sp_num
                            )

            self.num_solutions = num_solutions

            self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)

            self.num_parents_mating = num_parents_mating
            self.initial_population = self.torch_ga.population_weights

            # store our logs: best (sub_policy, fitness) pairs, at most sp_num of them
            self.running_policy = []

        def _get_single_policy_cov(self, x, alpha = 0.5):
            """
            Selects policy using population and covariance matrices. For this method 
            we require p_bins = 1, num_sub_pol = 1, m_bins = 1. 

            Parameters
            ------------
            x -> PyTorch Tensor
                Input data for the AutoAugment network 

            alpha -> Float
                Weight of the covariance matrix; (1 - alpha) weights the
                empirical (transform1, transform2) frequency matrix.

            Returns
            -----------
            Subpolicy -> [((String, float, float), (String, float, float))]
                A list holding one subpolicy: two tuples, each with the name of a
                transformation, a probability, and a magnitude. The magnitude is
                None when the transformation does not use one.
            """
            # Offset of the second transformation's block in the controller output.
            section = self.fun_num + self.p_bins + self.m_bins

            # The GA never back-propagates through the controller, so skip
            # building an autograd graph.
            with torch.no_grad():
                y = self.controller.forward(x)

            # Per-sample transformation scores for each half of the sub-policy.
            y_1 = torch.softmax(y[:, :self.fun_num], dim=1)
            y_2 = torch.softmax(y[:, section:section + self.fun_num], dim=1)

            # Cross-covariance block between first- and second-transform scores.
            concat = torch.cat((y_1, y_2), dim=1)
            cov_mat = torch.cov(concat.T)[:self.fun_num, self.fun_num:]
            # torch.cov is undefined (NaN) for a single-sample batch; treat it as
            # "no covariance information" so argmax below stays meaningful.
            cov_mat = torch.nan_to_num(cov_mat, nan=0.0)

            # Empirical frequency of each (argmax y_1, argmax y_2) pair.
            first_choice = torch.argmax(y_1, dim=1).tolist()
            second_choice = torch.argmax(y_2, dim=1).tolist()
            prob_mat = torch.zeros(cov_mat.shape)
            for i, j in zip(first_choice, second_choice):
                prob_mat[i][j] += 1
            prob_mat = prob_mat / torch.sum(prob_mat)

            scores = (alpha * cov_mat) + ((1 - alpha) * prob_mat)

            # Decode the row-major flat argmax into (row, column) = (trans1, trans2).
            flat_idx = int(torch.argmax(scores.reshape(-1)))
            trans1_idx, trans2_idx = divmod(flat_idx, scores.shape[1])

            counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)

            # Transformations without a magnitude get None instead of a number.
            if not self.augmentation_space[trans1_idx][1]:
                mag1 = None
            if not self.augmentation_space[trans2_idx][1]:
                mag2 = None

            # Average probability/magnitude outputs over samples that picked
            # exactly the chosen transformation pair.
            for idx, (c1, c2) in enumerate(zip(first_choice, second_choice)):
                if c1 == trans1_idx and c2 == trans2_idx:
                    counter += 1
                    prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
                    prob2 += torch.sigmoid(y[idx, section + self.fun_num]).item()

                    if mag1 is not None:
                        mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num + 1]).item())

                    if mag2 is not None:
                        # Second transform's magnitude lives in its own block.
                        mag2 += min(9, 10 * torch.sigmoid(y[idx, section + self.fun_num + 1]).item())

            prob1 = round(prob1 / counter, 1) if counter != 0 else 0
            prob2 = round(prob2 / counter, 1) if counter != 0 else 0

            # NOTE(review): the averaging lines were missing from this copy of the
            # file; int() is assumed to match the 0-9 magnitude range — confirm.
            if mag1 is not None:
                mag1 = int(mag1 / counter) if counter != 0 else 0

            if mag2 is not None:
                mag2 = int(mag2 / counter) if counter != 0 else 0

            return [((self.augmentation_space[trans1_idx][0], prob1, mag1),
                     (self.augmentation_space[trans2_idx][0], prob2, mag2))]

        def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
            """
            Runs the GA instance and returns the model weights as a dictionary

            Parameters
            ------------
            train_dataset, test_dataset
                Datasets used to derive policies and score them.

            child_network_architecture
                Passed through to ``_test_autoaugment_policy``.

            iterations -> Int
                Number of GA generations to run.

            return_weights -> Bool
                Determines if the weight of the GA network should be returned 
            
            Returns
            ------------
            If return_weights:
                Network weights -> Dictionary
            
            Else:
                Solution -> Best GA instance solution

                Solution fitness -> Float

                Solution_idx -> Int
            """
            self.num_generations = iterations

            # Best fitness seen so far, one entry per fitness evaluation.
            self.history_best = []

            self.best_model = 0

            self._set_up_instance(train_dataset, test_dataset, child_network_architecture)

            self.ga_instance.run()

            solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
            if return_weights:
                return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
            return solution, solution_fitness, solution_idx

        def _in_pol_dict(self, new_policy):
            """
            Check whether a sub-policy was already produced, recording it if not.

            Parameters
            ------------
            new_policy -> [((String, float, float), (String, float, float))]
                Output of ``_get_single_policy_cov``.

            Returns
            ------------
            Bool
                True if this exact (transforms, probabilities, magnitudes)
                combination was recorded before; otherwise it is recorded and
                False is returned.
            """
            if not hasattr(self, 'policy_dict'):
                self.policy_dict = {}

            (trans1, prob1, mag1), (trans2, prob2, mag2) = new_policy[0]
            # An ordered tuple, unlike a set, keeps (p1, m1, p2, m2) positions
            # distinct, e.g. (0.5, 2, 0.2, 5) vs (0.2, 5, 0.5, 2).
            key = (prob1, mag1, prob2, mag2)

            # Record under trans1 -> trans2 without clobbering sibling entries.
            seen = self.policy_dict.setdefault(trans1, {}).setdefault(trans2, [])
            if key in seen:
                return True
            seen.append(key)
            return False

        def _set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
            """
            Initialises GA instance, as well as fitness and on_generation functions
            
            """

            def fitness_func(solution, sol_idx):
                """
                Defines the fitness function for the parent selection

                Parameters
                --------------
                solution -> GA solution instance (parsed automatically)

                sol_idx -> GA solution index (parsed automatically)

                Returns 
                --------------
                fit_val -> float            

                Raises
                --------------
                ValueError
                    If ``train_dataset`` yields no batches.
                """
                model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                                   weights_vector=solution)

                self.controller.load_state_dict(model_weights_dict)

                # NOTE(review): this replaces the caller's dataset transform
                # in place and it is not restored afterwards.
                train_dataset.transform = torchvision.transforms.ToTensor()

                self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)

                # Use the first batch whose policy has not been seen before.
                # Re-querying the same batch with unchanged weights (as a `while`
                # loop would) returns the same policy again for a deterministic
                # controller and never terminates. If every batch yields a
                # known policy, the last one is used.
                sub_pol = None
                for test_x, _ in self.train_loader:
                    sub_pol = self._get_single_policy_cov(test_x)
                    if not self._in_pol_dict(sub_pol):
                        break

                if sub_pol is None:
                    raise ValueError('train_dataset yielded no batches; cannot derive a policy')

                fit_val = self._test_autoaugment_policy(sub_pol,child_network_architecture,train_dataset,test_dataset)

                # Keep only the sp_num best (policy, fitness) pairs.
                self.running_policy.append((sub_pol, fit_val))

                if len(self.running_policy) > self.sp_num:
                    self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
                    self.running_policy = self.running_policy[:self.sp_num]

                # Track the running best fitness and the weights that achieved it.
                if len(self.history_best) == 0:
                    self.history_best.append(fit_val)
                    self.best_model = model_weights_dict

                elif fit_val > self.history_best[-1]:
                    self.history_best.append(fit_val) 
                    self.best_model = model_weights_dict

                else:
                    self.history_best.append(self.history_best[-1])

                return fit_val

            def on_generation(ga_instance):
                """
                Prints information of generational fitness

                Parameters 
                -------------
                ga_instance -> GA instance

                Returns
                -------------
                None
                """
                print("Generation = {generation}".format(generation=ga_instance.generations_completed))
                print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
                return

            self.ga_instance = pygad.GA(num_generations=self.num_generations, 
                num_parents_mating=self.num_parents_mating, 
                initial_population=self.initial_population,
                mutation_percent_genes = 0.1,
                fitness_func=fitness_func,
                on_generation = on_generation)