    # ucb_learner.py
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision
    
    from tqdm import trange
    
    
    from ..child_networks import *
    
    from ..main import train_child_network
    from .randomsearch_learner import randomsearch_learner
    from .aa_learner import augmentation_space
    
    
    class ucb_learner(randomsearch_learner):
        """
        Tests randomly sampled policies from the search space specified by the AutoAugment
        paper. Acts as a baseline for other aa_learners.
        """
        def __init__(self,
                    # parameters that define the search space
                    sp_num=5,
                    fun_num=14,
                    p_bins=11,
                    m_bins=10,
                    discrete_p_m=True,
                    # hyperparameters for when training the child_network
                    batch_size=8,
                    toy_flag=False,
                    toy_size=0.1,
                    learning_rate=1e-1,
                    max_epochs=float('inf'),
                    early_stop_num=30,
                    # ucb_learner specific hyperparameter
                    num_policies=100
                    ):
            
    
            super().__init__(sp_num=sp_num, 
                            fun_num=fun_num,
                            p_bins=p_bins, 
                            m_bins=m_bins, 
                            discrete_p_m=discrete_p_m,
                            batch_size=batch_size,
                            toy_flag=toy_flag,
                            toy_size=toy_size,
                            learning_rate=learning_rate,
                            max_epochs=max_epochs,
                            early_stop_num=early_stop_num,)
    
            # attributes used in the UCB1 algorithm
            self.num_policies = num_policies

            # the pool of candidate policies (the 'arms' of the bandit)
            self.policies = [self.generate_new_policy() for _ in range(num_policies)]

            # average accuracy (q-value) observed so far for each policy; None means untested
            self.avg_accs = [None]*self.num_policies
            # best average accuracy seen so far, logged once per learn() iteration
            self.best_avg_accs = []

            # how many times each policy has been evaluated
            self.cnts = [0]*self.num_policies
            # q-value plus the UCB1 exploration bonus for each policy
            self.q_plus_cnt = [0]*self.num_policies
            # total number of policy evaluations so far
            self.total_count = 0
    
        def make_more_policies(self, n):
            """generates n more random policies and adds it to self.policies
    
            Args:
                n (int): how many more policies we want to randomly generate
                        and add to our list of policies
            """
    
    
            self.policies += [self.generate_new_policy() for _ in range(n)]
    
            # all the below need to be lengthened to store information for the 
            # new policies
            self.avg_accs += [None for _ in range(n)]
            self.cnts += [0 for _ in range(n)]
            self.q_plus_cnt += [None for _ in range(n)]
            self.num_policies += n
    
    
        def learn(self, 
                train_dataset, 
                test_dataset, 
                child_network_architecture, 
                iterations=15,
                print_every_epoch=False):
            """continue the UCB algorithm for `iterations` number of turns
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    
    
            for this_iter in trange(iterations):
    
                # choose which policy we want to test
                if None in self.avg_accs:
                    # if there is a policy we haven't tested yet, we 
                    # test that one
                    this_policy_idx = self.avg_accs.index(None)
    
                    this_policy = self.policies[this_policy_idx]
    
                    acc = self.test_autoaugment_policy(
                                    this_policy,
                                    child_network_architecture,
                                    train_dataset,
                                    test_dataset,
                                    logging=False,
                                    print_every_epoch=print_every_epoch
                                    )
                    # update q_values (average accuracy)
                    self.avg_accs[this_policy_idx] = acc
    
                else:
                    # if we have tested all policies before, we test the
                    # one with the best q_plus_cnt value
                    this_policy_idx = np.argmax(self.q_plus_cnt)
                    this_policy = self.policies[this_policy_idx]
                    acc = self.test_autoaugment_policy(
                                    this_policy,
                                    child_network_architecture,
                                    train_dataset,
                                    test_dataset,
                                    logging=False,
                                    print_every_epoch=print_every_epoch
                                    )
                    # update q_values (average accuracy)
                    self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
        
                # logging the best avg acc up to now
                best_avg_acc = max([x for x in self.avg_accs if x is not None])
                self.best_avg_accs.append(best_avg_acc)
    
                # print progress for user
    
                if (this_iter+1) % 5 == 0:
                    print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                    this_iter+1, 
                                    list(np.around(np.array(self.avg_accs),2)), 
                                    max(list(np.around(np.array(self.avg_accs),2)))
                                    ))
                # update counts
    
                self.cnts[this_policy_idx] += 1
    
                self.total_count += 1
    
                # update q_plus_cnt values every turn after the initial sweep through
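                # UCB1 score = average accuracy (exploitation term) plus an exploration
                # bonus of sqrt(2*ln(total_count)/cnts[i]), so rarely evaluated policies
                # are still revisited from time to time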
    
                for i in range(self.num_policies):
                    if self.avg_accs[i] is not None:
                        self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
                
                print(self.cnts)
    
    if __name__=='__main__':
        batch_size = 32       # size of batch the inner NN is trained with
        learning_rate = 1e-1  # fix learning rate
        ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
        toy_size = 0.02       # total proportion of training and test set we use
        max_epochs = 100      # max number of epochs that is run if early stopping is not hit
        early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
    
        early_stop_flag = True        # implement early stopping or not
        average_validation = [15,25]  # if not implementing early stopping, which epochs we average over
    
        num_policies = 5      # fix number of policies
    
        sp_num = 5  # fix number of sub-policies in a policy
    
        iterations = 100      # total iterations, should be more than the number of policies
    
        IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet