ucb_learner.py

#!/usr/bin/env python
    # coding: utf-8
    
    # In[1]:
    
    
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    
    from tqdm import trange
    
    
    from ..child_networks import *
    
    from ..main import train_child_network
    from .randomsearch_learner import randomsearch_learner
    from .aa_learner import augmentation_space
    
    
    class ucb_learner(randomsearch_learner):
        """
        Tests randomly sampled policies from the search space specified by the AutoAugment
        paper. Acts as a baseline for other aa_learner's.
        """
        def __init__(self,
                    # parameters that define the search space
                    sp_num=5,
                    fun_num=14,
                    p_bins=11,
                    m_bins=10,
                    discrete_p_m=True,
                # hyperparameters for training the child_network
                    batch_size=8,
                    toy_flag=False,
                    toy_size=0.1,
                    learning_rate=1e-1,
                    max_epochs=float('inf'),
                    early_stop_num=30,
                    # ucb_learner specific hyperparameter
                    num_policies=100
                    ):
            
    
            super().__init__(sp_num=sp_num, 
                        fun_num=fun_num,
                            p_bins=p_bins, 
                            m_bins=m_bins, 
                            discrete_p_m=discrete_p_m,
                            batch_size=batch_size,
                            toy_flag=toy_flag,
                            toy_size=toy_size,
                            learning_rate=learning_rate,
                            max_epochs=max_epochs,
                            early_stop_num=early_stop_num,)
    
            
            self.num_policies = num_policies
    
        # When this learner is initialized, we generate `num_policies`
        # random policies.
        # generate_new_policy is inherited from the randomsearch_learner class
    
            self.policies = []
        self.make_more_policies(n=self.num_policies)
    
        # attributes used in the UCB1 algorithm
        self.q_values = [0]*self.num_policies      # running mean accuracy of each policy
        self.best_q_values = []                    # best q value seen at each iteration
        self.cnts = [0]*self.num_policies          # number of times each policy has been tried
        self.q_plus_cnt = [0]*self.num_policies    # q value plus the UCB1 exploration bonus
        self.total_count = 0                       # total number of policy evaluations so far
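
        # UCB1 scores each policy as its running mean accuracy plus an
        # exploration bonus that shrinks with repeated tries:
        #     q_plus_cnt[i] = q_values[i] + sqrt(2 * ln(total_count) / cnts[i])
        # so rarely-tried policies keep being revisited even when their
        # current mean accuracy is modest.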
    
    
    
    def make_more_policies(self, n):
        """generates n more random policies and adds them to self.policies

        Args:
            n (int): how many more policies we want to randomly generate
                    and add to our list of policies
        """

        self.policies.extend([self.generate_new_policy() for _ in range(n)])
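
        # Note that this extends only self.policies; self.q_values,
        # self.cnts, and self.q_plus_cnt are sized to num_policies in
        # __init__, so they would need extending too if more policies
        # were added mid-search.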
    
    
    
        def learn(self, 
                train_dataset, 
                test_dataset, 
                child_network_architecture, 
                iterations=15):
    
            for this_iter in trange(iterations):
    
            # get the action to try: sweep through every policy once in
            # order, then pick the one with the best q_plus_cnt value
            if this_iter >= self.num_policies:
                this_policy_idx = np.argmax(self.q_plus_cnt)
            else:
                this_policy_idx = this_iter
            this_policy = self.policies[this_policy_idx]
    
                best_acc = self.test_autoaugment_policy(
                                    this_policy,
                                    child_network_architecture,
                                    train_dataset,
                                    test_dataset,
                                    logging=False
                                    )
    
            # update q_values: during the initial sweep each policy is tried
            # once, so just record its accuracy; afterwards keep a running
            # mean over every try of the policy
            if this_iter < self.num_policies:
                self.q_values[this_policy_idx] += best_acc
            else:
                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
    
                best_q_value = max(self.q_values)
    
                self.best_q_values.append(best_q_value)
    
                if (this_iter+1) % 5 == 0:
                    print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                    this_iter+1, 
                                    list(np.around(np.array(self.q_values),2)), 
                                    max(list(np.around(np.array(self.q_values),2)))
                                    )
                        )
    
                # update counts
    
                self.cnts[this_policy_idx] += 1
    
                self.total_count += 1
    
                # update q_plus_cnt values every turn after the initial sweep through
    
            # TODO: change this if statement
    
                if this_iter >= self.num_policies - 1:
                    for i in range(self.num_policies):
                        self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
    
    
    if __name__=='__main__':
        batch_size = 32       # size of batch the inner NN is trained with
        learning_rate = 1e-1  # fix learning rate
        ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
    toy_size = 0.02       # total proportion of training and test set we use
        max_epochs = 100      # max number of epochs that is run if early stopping is not hit
        early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
    
    early_stop_flag = True        # whether to use early stopping
    average_validation = [15,25]  # if not using early stopping, which epochs to average over
    
        num_policies = 5      # fix number of policies
    
        sp_num = 5  # fix number of sub-policies in a policy
    
        iterations = 100      # total iterations, should be more than the number of policies
    
        IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet