Commit 839fb15f authored by Sun Jin Kim

FINISH REFACTORING UCB_LEARNER

parent d239f59a
@@ -309,7 +309,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and
@@ -384,7 +385,7 @@ class aa_learner:
                                 max_epochs = self.max_epochs,
                                 early_stop_num = self.early_stop_num,
                                 logging = logging,
-                                print_every_epoch=True)
+                                print_every_epoch=print_every_epoch)
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
...
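The hunk above threads a new print_every_epoch keyword through aa_learner.test_autoaugment_policy and forwards it to the child-network training call, so callers that launch many short training jobs (such as the UCB learner below) can silence the per-epoch printout. A minimal standalone sketch of that forwarding pattern follows; the stub functions only mimic the call chain suggested by the diff, since the inner training call itself is not shown in this hunk:

    def train_child_network_stub(max_epochs, early_stop_num, logging, print_every_epoch=True):
        # stand-in for the child-network training loop
        for epoch in range(2):
            if print_every_epoch:
                print(f"epoch {epoch}: training child network...")
        return 0.87  # dummy accuracy

    def test_autoaugment_policy_stub(logging=False, print_every_epoch=True):
        # before this commit the inner call hard-coded print_every_epoch=True;
        # after it, the caller's choice is forwarded unchanged
        return train_child_network_stub(max_epochs=10,
                                        early_stop_num=3,
                                        logging=logging,
                                        print_every_epoch=print_every_epoch)

    print(test_autoaugment_policy_stub(print_every_epoch=False))  # no per-epoch output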
-#!/usr/bin/env python
-# coding: utf-8
-# In[1]:
 import numpy as np
 import torch
 import torch.nn as nn
@@ -53,23 +47,24 @@ class ucb_learner(randomsearch_learner):
                          max_epochs=max_epochs,
                          early_stop_num=early_stop_num,)
-        self.num_policies = num_policies
-        # When this learner is initialized we generate `num_policies` number
-        # of random policies.
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = []
-        self.make_more_policies()
         # attributes used in the UCB1 algorithm
-        self.q_values = [0]*self.num_policies
-        self.best_q_values = []
+        self.num_policies = num_policies
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0
 
     def make_more_policies(self, n):
         """generates n more random policies and adds it to self.policies
@@ -78,50 +73,71 @@ class ucb_learner(randomsearch_learner):
         and add to our list of policies
         """
-        self.policies.append([self.generate_new_policy() for _ in n])
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the below need to be lengthened to store information for the
+        # new policies
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
 
     def learn(self,
             train_dataset,
             test_dataset,
             child_network_architecture,
-            iterations=15):
+            iterations=15,
+            print_every_epoch=False):
+        """continue the UCB algorithm for `iterations` number of turns
+        """
 
         for this_iter in trange(iterations):
 
-            # get the action to try (either initially in order or using best q_plus_cnt value)
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies:
-                this_policy_idx = np.argmax(self.q_plus_cnt)
-                this_policy = self.policies[this_policy_idx]
-            else:
-                this_policy = this_iter
-
-            best_acc = self.test_autoaugment_policy(
-                                this_policy,
-                                child_network_architecture,
-                                train_dataset,
-                                test_dataset,
-                                logging=False
-                                )
-            # update q_values
-            # TODO: change this if statemetn
-            if this_iter < self.num_policies:
-                self.q_values[this_policy_idx] += best_acc
-            else:
-                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
-
-            best_q_value = max(self.q_values)
-            self.best_q_values.append(best_q_value)
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = acc
+            else:
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
+
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
 
+            # print progress for user
             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                 this_iter+1,
-                                list(np.around(np.array(self.q_values),2)),
-                                max(list(np.around(np.array(self.q_values),2)))
+                                list(np.around(np.array(self.avg_accs),2)),
+                                max(list(np.around(np.array(self.avg_accs),2)))
                             )
                         )
@@ -130,10 +146,11 @@ class ucb_learner(randomsearch_learner):
             self.total_count += 1
 
             # update q_plus_cnt values every turn after the initial sweep through
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies - 1:
-                for i in range(self.num_policies):
-                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
-
-            print(self.cnts)
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
...
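After this refactor, learn() keeps one avg_accs entry per policy (None until the policy has been evaluated at least once), counts plays in cnts, and ranks already-tried policies by the UCB1 score avg_accs[i] + sqrt(2*ln(total_count)/cnts[i]). The standalone sketch below replays that selection and update logic with dummy accuracy numbers in place of test_autoaugment_policy(); it is an illustration of the rules in the hunks above, not project code:

    import numpy as np

    avg_accs = [0.62, None, 0.55]   # None marks a policy that has never been evaluated
    cnts = [3, 0, 2]                # how many times each policy has been evaluated
    q_plus_cnt = [0, 0, 0]
    total_count = sum(cnts)

    # selection: untested policies first, otherwise the best UCB1 score
    if None in avg_accs:
        idx = avg_accs.index(None)
    else:
        idx = int(np.argmax(q_plus_cnt))

    acc = 0.70                      # pretend accuracy returned for the chosen policy

    # update: first observation is stored directly, later ones as a running mean
    if avg_accs[idx] is None:
        avg_accs[idx] = acc
    else:
        avg_accs[idx] = (avg_accs[idx]*cnts[idx] + acc) / (cnts[idx] + 1)
    cnts[idx] += 1
    total_count += 1

    # recompute the UCB1 score for every policy that has been tried at least once
    for i in range(len(avg_accs)):
        if avg_accs[i] is not None:
            q_plus_cnt[i] = avg_accs[i] + np.sqrt(2*np.log(total_count)/cnts[i])

    print(idx, np.round(q_plus_cnt, 2))   # -> 1 [1.71 2.59 1.89]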
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
 import pickle
 
 def parse_users_learner_spec(
+            # aalearner type
             auto_aug_learner,
+            # search space settings
             ds,
             ds_name,
             exclude_method,
             num_funcs,
             num_policies,
             num_sub_policies,
+            # child network settings
             toy_size,
             IsLeNet,
             batch_size,
...
 import MetaAugment.autoaugment_learners as aal
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint
 
 def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
+
     learner = aal.ucb_learner(
         # parameters that define the search space
         sp_num=5,
@@ -10,15 +21,37 @@ def test_ucb_learner():
         discrete_p_m=True,
         # hyperparameters for when training the child_network
         batch_size=8,
-        toy_flag=False,
-        toy_size=0.1,
+        toy_flag=True,
+        toy_size=0.001,
         learning_rate=1e-1,
         max_epochs=float('inf'),
         early_stop_num=30,
         # ucb_learner specific hyperparameter
-        num_policies=100
+        num_policies=3
         )
-    print(learner.policies)
+    pprint(learner.policies)
+    assert len(learner.policies)==len(learner.avg_accs), \
+                (len(learner.policies), (len(learner.avg_accs)))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=5
+        )
+
+    # let's say we want to explore more policies:
+    # we generate more new policies
+    learner.make_more_policies(n=4)
+
+    # and let's explore how good those are as well
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=7
+        )
 
 if __name__=="__main__":
     test_ucb_learner()
\ No newline at end of file
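The updated test walks through the new incremental workflow: start with three policies, learn for five iterations, grow the pool with make_more_policies(n=4), then learn for seven more. The toy class below is a stand-in for illustration only, not the real ucb_learner; it shows the invariant the new assert depends on, namely that every per-policy list must grow in step with self.policies:

    class TinyUCBBookkeeping:
        # toy stand-in mirroring only the list bookkeeping of ucb_learner
        def __init__(self, num_policies):
            self.num_policies = num_policies
            self.policies = [object() for _ in range(num_policies)]  # dummy policies
            self.avg_accs = [None]*num_policies
            self.cnts = [0]*num_policies
            self.q_plus_cnt = [0]*num_policies

        def make_more_policies(self, n):
            # grow every per-policy list together so indices stay aligned
            self.policies += [object() for _ in range(n)]
            self.avg_accs += [None]*n
            self.cnts += [0]*n
            self.q_plus_cnt += [None]*n
            self.num_policies += n

    toy = TinyUCBBookkeeping(num_policies=3)
    toy.make_more_policies(n=4)
    assert len(toy.policies) == len(toy.avg_accs) == toy.num_policies == 7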