Commit 839fb15f authored by Sun Jin Kim

FINISH REFACTORING UCB_LEARNER

parent d239f59a
@@ -309,7 +309,8 @@ class aa_learner:
child_network_architecture,
train_dataset,
test_dataset,
logging=False):
logging=False,
print_every_epoch=True):
"""
Given a policy (using AutoAugment paper terminology), we train a child network
using the policy and return the accuracy (how good the policy is for the dataset and
@@ -384,7 +385,7 @@ class aa_learner:
max_epochs = self.max_epochs,
early_stop_num = self.early_stop_num,
logging = logging,
print_every_epoch=True)
print_every_epoch=print_every_epoch)
# if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
return accuracy
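# A usage sketch (illustrative, not part of the diff), assuming a learner and
# datasets constructed as in the test file further down: the new
# `print_every_epoch` keyword is forwarded to the inner training call, so a
# caller can silence the per-epoch printout when evaluating a single policy.
#
#   acc = learner.test_autoaugment_policy(policy,
#                                         child_network_architecture,
#                                         train_dataset,
#                                         test_dataset,
#                                         logging=False,
#                                         print_every_epoch=False)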
......
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import torch
import torch.nn as nn
@@ -53,23 +47,24 @@ class ucb_learner(randomsearch_learner):
max_epochs=max_epochs,
early_stop_num=early_stop_num,)
self.num_policies = num_policies
# When this learner is initialized we generate `num_policies` number
# of random policies.
# generate_new_policy is inherited from the randomsearch_learner class
self.policies = []
self.make_more_policies()
# attributes used in the UCB1 algorithm
self.q_values = [0]*self.num_policies
self.best_q_values = []
self.num_policies = num_policies
self.policies = [self.generate_new_policy() for _ in range(num_policies)]
self.avg_accs = [None]*self.num_policies
self.best_avg_accs = []
self.cnts = [0]*self.num_policies
self.q_plus_cnt = [0]*self.num_policies
self.total_count = 0
def make_more_policies(self, n):
"""generates n more random policies and adds it to self.policies
......@@ -78,50 +73,71 @@ class ucb_learner(randomsearch_learner):
and add to our list of policies
"""
self.policies.append([self.generate_new_policy() for _ in n])
self.policies += [self.generate_new_policy() for _ in range(n)]
# all of the per-policy lists below need to be lengthened to store
# information for the new policies
self.avg_accs += [None for _ in range(n)]
self.cnts += [0 for _ in range(n)]
self.q_plus_cnt += [None for _ in range(n)]
self.num_policies += n
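# An illustration (not part of the diff) of the bookkeeping invariant that
# make_more_policies(n) maintains: self.policies, self.avg_accs, self.cnts and
# self.q_plus_cnt all grow by n together, so per-policy indexing stays consistent.
#
#   learner.make_more_policies(n=4)   # e.g. starting from 3 policies
#   assert len(learner.policies) == len(learner.avg_accs) \
#          == len(learner.cnts) == len(learner.q_plus_cnt)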
def learn(self,
train_dataset,
test_dataset,
child_network_architecture,
iterations=15):
iterations=15,
print_every_epoch=False):
"""continue the UCB algorithm for `iterations` number of turns
"""
for this_iter in trange(iterations):
# get the action to try (either initially in order or using best q_plus_cnt value)
# TODO: change this if statement
if this_iter >= self.num_policies:
this_policy_idx = np.argmax(self.q_plus_cnt)
# choose which policy we want to test
if None in self.avg_accs:
# if there is a policy we haven't tested yet, we
# test that one
this_policy_idx = self.avg_accs.index(None)
this_policy = self.policies[this_policy_idx]
else:
this_policy = this_iter
best_acc = self.test_autoaugment_policy(
acc = self.test_autoaugment_policy(
this_policy,
child_network_architecture,
train_dataset,
test_dataset,
logging=False
logging=False,
print_every_epoch=print_every_epoch
)
# update q_values
# TODO: change this if statement
if this_iter < self.num_policies:
self.q_values[this_policy_idx] += best_acc
# update q_values (average accuracy)
self.avg_accs[this_policy_idx] = acc
else:
self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
best_q_value = max(self.q_values)
self.best_q_values.append(best_q_value)
# if we have tested all policies before, we test the
# one with the best q_plus_cnt value
this_policy_idx = np.argmax(self.q_plus_cnt)
this_policy = self.policies[this_policy_idx]
acc = self.test_autoaugment_policy(
this_policy,
child_network_architecture,
train_dataset,
test_dataset,
logging=False,
print_every_epoch=print_every_epoch
)
# update q_values (average accuracy)
self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
# logging the best avg acc up to now
best_avg_acc = max([x for x in self.avg_accs if x is not None])
self.best_avg_accs.append(best_avg_acc)
# print progress for user
if (this_iter+1) % 5 == 0:
print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
this_iter+1,
list(np.around(np.array(self.q_values),2)),
max(list(np.around(np.array(self.q_values),2)))
list(np.around(np.array(self.avg_accs),2)),
max(list(np.around(np.array(self.avg_accs),2)))
)
)
@@ -130,10 +146,11 @@ class ucb_learner(randomsearch_learner):
self.total_count += 1
# update q_plus_cnt values every turn after the initial sweep through
# TODO: change this if statement
if this_iter >= self.num_policies - 1:
for i in range(self.num_policies):
self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
for i in range(self.num_policies):
if self.avg_accs[i] is not None:
self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
print(self.cnts)
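# A minimal, self-contained sketch (illustrative, not part of the diff) of the
# UCB1 scoring rule this refactor implements: each policy's score is its average
# accuracy plus an exploration bonus, and the highest-scoring policy is the one
# evaluated on the next iteration. The toy numbers are made up.
import numpy as np

avg_accs = [0.61, 0.58, 0.64]   # running mean accuracy per policy
cnts = [4, 3, 5]                # number of times each policy has been evaluated
total_count = sum(cnts)

q_plus_cnt = [avg_accs[i] + np.sqrt(2 * np.log(total_count) / cnts[i])
              for i in range(len(avg_accs))]
next_policy_idx = int(np.argmax(q_plus_cnt))
print(np.round(q_plus_cnt, 3), next_policy_idx)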
......
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
import pickle
def parse_users_learner_spec(
# aalearner type
auto_aug_learner,
# search space settings
ds,
ds_name,
exclude_method,
num_funcs,
num_policies,
num_sub_policies,
# child network settings
toy_size,
IsLeNet,
batch_size,
......
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torchvision
import torchvision.datasets as datasets
from pprint import pprint
def test_ucb_learner():
child_network_architecture = cn.SimpleNet
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
learner = aal.ucb_learner(
# parameters that define the search space
sp_num=5,
@@ -10,15 +21,37 @@ def test_ucb_learner():
discrete_p_m=True,
# hyperparameters for when training the child_network
batch_size=8,
toy_flag=False,
toy_size=0.1,
toy_flag=True,
toy_size=0.001,
learning_rate=1e-1,
max_epochs=float('inf'),
early_stop_num=30,
# ucb_learner specific hyperparameter
num_policies=100
num_policies=3
)
print(learner.policies)
pprint(learner.policies)
assert len(learner.policies)==len(learner.avg_accs), \
(len(learner.policies), (len(learner.avg_accs)))
# learn on the 3 policies we generated
learner.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
iterations=5
)
# let's say we want to explore more policies:
# we generate more new policies
learner.make_more_policies(n=4)
# and let's explore how good those are as well
learner.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
iterations=7
)
if __name__=="__main__":
test_ucb_learner()
\ No newline at end of file
test_ucb_learner()