import numpy as np
from tqdm import trange
from ..child_networks import *
from .randomsearch_learner import randomsearch_learner
class ucb_learner(randomsearch_learner):
"""
Tests randomly sampled policies from the search space specified by the AutoAugment
paper. Acts as a baseline for other aa_learner's.
"""
def __init__(self,
# parameters that define the search space
sp_num=5,
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=[],
# hyperparameters for when training the child_network
batch_size=8,
toy_size=1,
learning_rate=1e-1,
max_epochs=float('inf'),
early_stop_num=30,
# ucb_learner specific hyperparameter
num_policies=100
):
super().__init__(
sp_num=sp_num,
p_bins=p_bins,
m_bins=m_bins,
discrete_p_m=discrete_p_m,
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
exclude_method=exclude_method,
)
# attributes used in the UCB1 algorithm
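        # self.policies are the bandit arms; self.avg_accs holds the running
        # mean accuracy (q_value) of each policy, or None if it has never been
        # evaluated; self.cnts counts how many times each policy has been
        # evaluated; self.q_plus_cnt holds each policy's UCB1 score; and
        # self.total_count is the total number of evaluations across all policies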
self.num_policies = num_policies
self.policies = [self._generate_new_policy() for _ in range(num_policies)]
self.avg_accs = [None]*self.num_policies
self.best_avg_accs = []
self.cnts = [0]*self.num_policies
self.q_plus_cnt = [0]*self.num_policies
self.total_count = 0
def make_more_policies(self, n):
"""generates n more random policies and adds it to self.policies
Args:
n (int): how many more policies to we want to randomly generate
and add to our list of policies
"""
self.policies += [self._generate_new_policy() for _ in range(n)]
        # the per-policy bookkeeping lists below must be lengthened to match;
        # new policies start untested (avg_accs entry None), so learn() will
        # evaluate each of them once before resuming UCB selection
self.avg_accs += [None for _ in range(n)]
self.cnts += [0 for _ in range(n)]
        self.q_plus_cnt += [0 for _ in range(n)]  # 0 to match the initial value used in __init__
self.num_policies += n
def learn(self,
train_dataset,
test_dataset,
child_network_architecture,
iterations=15,
print_every_epoch=False):
"""continue the UCB algorithm for `iterations` number of turns
"""
for this_iter in trange(iterations):
# choose which policy we want to test
            if None in self.avg_accs:
                # if there is a policy we haven't tested yet, we test that one
                this_policy_idx = self.avg_accs.index(None)
            else:
                # if we have tested every policy at least once, we test the one
                # with the best q_plus_cnt (upper confidence bound) value
                this_policy_idx = np.argmax(self.q_plus_cnt)
            this_policy = self.policies[this_policy_idx]

            acc = self._test_autoaugment_policy(
                                this_policy,
                                child_network_architecture,
                                train_dataset,
                                test_dataset,
                                logging=False,
                                print_every_epoch=print_every_epoch
                                )

            # update the q_value (running mean accuracy) of the tested policy
            if self.avg_accs[this_policy_idx] is None:
                self.avg_accs[this_policy_idx] = acc
            else:
                self.avg_accs[this_policy_idx] = (
                    (self.avg_accs[this_policy_idx] * self.cnts[this_policy_idx] + acc)
                    / (self.cnts[this_policy_idx] + 1)
                )
# logging the best avg acc up to now
best_avg_acc = max([x for x in self.avg_accs if x is not None])
self.best_avg_accs.append(best_avg_acc)
            # print progress for the user (only policies tested so far have a q_value)
            if (this_iter+1) % 5 == 0:
                tested_accs = [x for x in self.avg_accs if x is not None]
                print("Iteration: {},\tQ-Values: {},\tBest so far: {}".format(
                            this_iter+1,
                            list(np.around(np.array(tested_accs), 2)),
                            np.around(best_avg_acc, 2)
                            )
                        )
# update counts
self.cnts[this_policy_idx] += 1
self.total_count += 1
# update q_plus_cnt values every turn after the initial sweep through
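            # UCB1 score for policy i: Q_i + sqrt(2 * ln(N) / n_i), where Q_i is
            # its mean accuracy, n_i the number of times it has been evaluated,
            # and N the total number of evaluations so far; the second term is
            # the exploration bonus, which shrinks as a policy is tried more often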
for i in range(self.num_policies):
if self.avg_accs[i] is not None:
self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
        # show how many times each policy ended up being evaluated
        print(self.cnts)
def get_mega_policy(self, number_policies=5):
"""
Produces a mega policy, based on the n best subpolicies (evo learner)/policies
(other learners)
Args:
number_policies -> int: Number of (sub)policies to be included in the mega
policy
Returns:
megapolicy -> [subpolicy, subpolicy, ...]
"""
temp_avg_accs = [x if x is not None else 0 for x in self.avg_accs]
temp_history = list(zip(self.policies, temp_avg_accs))
        # we cannot take more policies than we have
        number_policies = min(number_policies, len(temp_history))
        inter_pol = sorted(temp_history, key=lambda x: x[1], reverse=True)[:number_policies]
megapol = []
for pol in inter_pol:
megapol += pol[0]
return megapol
if __name__ == '__main__':
batch_size = 32 # size of batch the inner NN is trained with
learning_rate = 1e-1 # fix learning rate
ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
    toy_size = 0.02 # total proportion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
early_stop_flag = True # implement early stopping or not
average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over
num_policies = 5 # fix number of policies
sp_num = 5 # fix number of sub-policies in a policy
iterations = 100 # total iterations, should be more than the number of policies
IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
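
    # A minimal usage sketch for the settings above (not part of the original
    # run script). It assumes the package exposes a LeNet child network via
    # ..child_networks and that torchvision's MNIST is used; both names, and
    # running this file as a module so the relative imports resolve, are
    # assumptions for illustration:
    #
    #   import torchvision.datasets as datasets
    #   import torchvision.transforms as transforms
    #
    #   train_dataset = datasets.MNIST(root='./data', train=True, download=True,
    #                                  transform=transforms.ToTensor())
    #   test_dataset = datasets.MNIST(root='./data', train=False, download=True,
    #                                 transform=transforms.ToTensor())
    #
    #   learner = ucb_learner(sp_num=sp_num,
    #                         batch_size=batch_size,
    #                         toy_size=toy_size,
    #                         learning_rate=learning_rate,
    #                         max_epochs=max_epochs,
    #                         early_stop_num=early_stop_num,
    #                         num_policies=num_policies)
    #   learner.learn(train_dataset,
    #                 test_dataset,
    #                 child_network_architecture=LeNet,
    #                 iterations=iterations)
    #   megapol = learner.get_mega_policy(number_policies=3)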