import numpy as np

from tqdm import trange

from ..child_networks import *
from .randomsearch_learner import randomsearch_learner


class ucb_learner(randomsearch_learner):
    """
    Tests randomly sampled policies from the search space specified by the AutoAugment
    paper. Acts as a baseline for other aa_learner's.
    """
    def __init__(self,
                # parameters that define the search space
                sp_num=5,
                p_bins=11,
                m_bins=10,
                discrete_p_m=True,
                exclude_method=[],
                # hyperparameters for when training the child_network
                batch_size=8,
                toy_size=1,
                learning_rate=1e-1,
                max_epochs=float('inf'),
                early_stop_num=30,
                # ucb_learner specific hyperparameter
                num_policies=100
                ):
        
        super().__init__(
                        sp_num=sp_num, 
                        p_bins=p_bins, 
                        m_bins=m_bins, 
                        discrete_p_m=discrete_p_m,
                        batch_size=batch_size,
                        toy_size=toy_size,
                        learning_rate=learning_rate,
                        max_epochs=max_epochs,
                        early_stop_num=early_stop_num,
                        exclude_method=exclude_method,
                        )

        # attributes used in the UCB1 algorithm
        self.num_policies = num_policies

        self.policies = [self._generate_new_policy() for _ in range(num_policies)]

        # running average accuracy of each policy (None until a policy is first tested)
        self.avg_accs = [None]*self.num_policies
        # best average accuracy seen so far, logged once per iteration of learn()
        self.best_avg_accs = []

        # number of times each policy has been tested
        self.cnts = [0]*self.num_policies
        # UCB1 score of each policy (average accuracy plus exploration bonus)
        self.q_plus_cnt = [0]*self.num_policies
        # total number of policy evaluations so far
        self.total_count = 0

    def make_more_policies(self, n):
        """generates n more random policies and adds it to self.policies

        Args:
            n (int): how many more policies to we want to randomly generate
                    and add to our list of policies
        """

        self.policies += [self._generate_new_policy() for _ in range(n)]

        # all of the bookkeeping lists below need to be extended to hold
        # information for the new policies
        self.avg_accs += [None for _ in range(n)]
        self.cnts += [0 for _ in range(n)]
        self.q_plus_cnt += [0 for _ in range(n)]
        self.num_policies += n

    def learn(self, 
            train_dataset, 
            test_dataset, 
            child_network_architecture, 
            iterations=15,
            print_every_epoch=False):
        """continue the UCB algorithm for `iterations` number of turns

        """

        for this_iter in trange(iterations):

            # choose which policy we want to test
            if None in self.avg_accs:
                # if there is a policy we haven't tested yet, we 
                # test that one
                this_policy_idx = self.avg_accs.index(None)
                this_policy = self.policies[this_policy_idx]
                acc = self._test_autoaugment_policy(
                                this_policy,
                                child_network_architecture,
                                train_dataset,
                                test_dataset,
                                logging=False,
                                print_every_epoch=print_every_epoch
                                )
                # update q_values (average accuracy)
                self.avg_accs[this_policy_idx] = acc
            else:
                # if we have tested all policies before, we test the
                # one with the best q_plus_cnt value
                this_policy_idx = np.argmax(self.q_plus_cnt)
                this_policy = self.policies[this_policy_idx]
                acc = self._test_autoaugment_policy(
                                this_policy,
                                child_network_architecture,
                                train_dataset,
                                test_dataset,
                                logging=False,
                                print_every_epoch=print_every_epoch
                                )
                # update q_values (average accuracy)
                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
    
            # logging the best avg acc up to now
            best_avg_acc = max([x for x in self.avg_accs if x is not None])
            self.best_avg_accs.append(best_avg_acc)

            # print progress for user
            if (this_iter+1) % 5 == 0:
                print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                this_iter+1, 
                                list(np.around(np.array(self.avg_accs),2)), 
                                max(list(np.around(np.array(self.avg_accs),2)))
                                )
                    )

            # update counts
            self.cnts[this_policy_idx] += 1
            self.total_count += 1

            # update q_plus_cnt values every turn after the initial sweep through
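            # UCB1 score = average accuracy (exploitation) plus an exploration bonus
            # sqrt(2*ln(total evaluations) / evaluations of this policy), so policies
            # that have been tested less often receive a larger bonus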
            for i in range(self.num_policies):
                if self.avg_accs[i] is not None:
                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
            
            print(self.cnts)

            
    def get_mega_policy(self, number_policies=5):
        """
        Produces a mega policy, based on the n best subpolicies (evo learner)/policies
        (other learners)

        
        Args: 
            number_policies -> int: Number of (sub)policies to be included in the mega
            policy

        Returns:
            megapolicy -> [subpolicy, subpolicy, ...]
        """

        # policies that have never been evaluated count as having 0 accuracy
        temp_avg_accs = [x if x is not None else 0 for x in self.avg_accs]

        temp_history = list(zip(self.policies, temp_avg_accs))

        # we can include at most as many policies as we have
        number_policies = min(number_policies, len(temp_history))

        inter_pol = sorted(temp_history, key=lambda x: x[1], reverse=True)[:number_policies]

        megapol = []
        for pol in inter_pol:
            megapol += pol[0]

        return megapol

       




if __name__=='__main__':
    batch_size = 32       # size of batch the inner NN is trained with
    learning_rate = 1e-1  # fix learning rate
    ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
    toy_size = 0.02       # total proportion of the training and test sets we use
    max_epochs = 100      # max number of epochs that is run if early stopping is not hit
    early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
    early_stop_flag = True        # implement early stopping or not
    average_validation = [15,25]  # if not implementing early stopping, which epochs we average over
    num_policies = 5      # fix number of policies
    sp_num = 5  # fix number of sub-policies in a policy
    iterations = 100      # total iterations, should be more than the number of policies
    IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
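
    # A minimal sketch of how these hyperparameters might be wired into the learner.
    # Assumptions not defined in this file: the MNIST datasets are loaded with
    # torchvision, and `LeNet` stands in for whichever child network architecture
    # the ..child_networks module actually provides.
    import torchvision
    import torchvision.transforms as transforms

    train_dataset = torchvision.datasets.MNIST(root='./datasets/mnist/train',
                                               train=True, download=True,
                                               transform=transforms.ToTensor())
    test_dataset = torchvision.datasets.MNIST(root='./datasets/mnist/test',
                                              train=False, download=True,
                                              transform=transforms.ToTensor())

    learner = ucb_learner(
                    sp_num=sp_num,
                    batch_size=batch_size,
                    toy_size=toy_size,
                    learning_rate=learning_rate,
                    max_epochs=max_epochs,
                    early_stop_num=early_stop_num,
                    num_policies=num_policies,
                    )
    learner.learn(train_dataset,
                  test_dataset,
                  child_network_architecture=LeNet,  # assumed child network class
                  iterations=iterations)
    print(learner.get_mega_policy(number_policies=sp_num))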