Commit 5ecdfcec authored by Sun Jin Kim

continue refactoring ucb

parent e583f2e5
autoaugment_learners/__init__.py:

 from .aa_learner import *
 from .randomsearch_learner import *
 from .gru_learner import *
 from .evo_learner import *
+from .ucb_learner import *
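With ucb_learner re-exported from the package __init__, the class is reachable through the package namespace, which is how the updated test below imports it. A one-line sanity check (the package name is taken from that test file):

import MetaAugment.autoaugment_learners as aal

assert hasattr(aal, 'ucb_learner')  # exported via the new wildcard import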
autoaugment_learners/ucb_learner.py:

@@ -41,47 +41,60 @@ class ucb_learner(randomsearch_learner):
             num_policies=100
             ):
-        super().__init__(sp_num,
-                    fun_num,
-                    p_bins,
-                    m_bins,
+        super().__init__(sp_num=sp_num,
+                    fun_num=14,
+                    p_bins=p_bins,
+                    m_bins=m_bins,
                     discrete_p_m=discrete_p_m,
                     batch_size=batch_size,
                     toy_flag=toy_flag,
                     toy_size=toy_size,
                     learning_rate=learning_rate,
                     max_epochs=max_epochs,
                     early_stop_num=early_stop_num,)
         self.num_policies = num_policies

         # When this learner is initialized we generate `num_policies` random
         # policies.
         # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = [self.generate_new_policy() for _ in self.num_policies]
+        self.policies = []
+        self.make_more_policies(self.num_policies)

         # attributes used in the UCB1 algorithm
         self.q_values = [0]*self.num_policies
+        self.best_q_values = []
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0

+    def make_more_policies(self, n):
+        """Generates n more random policies and appends them to self.policies.
+
+        Args:
+            n (int): how many more policies we want to randomly generate
+                and add to our list of policies
+        """
+        self.policies.extend(self.generate_new_policy() for _ in range(n))
+
     def learn(self,
               train_dataset,
               test_dataset,
               child_network_architecture,
               iterations=15):
-        # Initialize vector weights, counts and regret
-        best_q_values = []

         for this_iter in trange(iterations):

             # get the action to try (either initially in order or using best q_plus_cnt value)
+            # TODO: change this if statement
             if this_iter >= self.num_policies:
-                this_policy = self.policies[np.argmax(self.q_plus_cnt)]
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
             else:
-                this_policy = this_iter
+                this_policy_idx = this_iter
+                this_policy = self.policies[this_policy_idx]
@@ -95,13 +108,14 @@ class ucb_learner(randomsearch_learner):
             )

             # update q_values
+            # TODO: change this if statement
             if this_iter < self.num_policies:
-                self.q_values[this_policy] += best_acc
+                self.q_values[this_policy_idx] += best_acc
             else:
-                self.q_values[this_policy] = (self.q_values[this_policy]*self.cnts[this_policy] + best_acc) / (self.cnts[this_policy] + 1)
+                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)

             best_q_value = max(self.q_values)
-            best_q_values.append(best_q_value)
+            self.best_q_values.append(best_q_value)

             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
@@ -112,41 +126,20 @@ class ucb_learner(randomsearch_learner):
                 )

             # update counts
-            self.cnts[this_policy] += 1
+            self.cnts[this_policy_idx] += 1
             self.total_count += 1

             # update q_plus_cnt values every turn after the initial sweep through
+            # TODO: change this if statement
             if this_iter >= self.num_policies - 1:
                 for i in range(self.num_policies):
                     self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
-
-        # yield q_values, best_q_values
-        return self.q_values, best_q_values
-
-
-def run_UCB1(
-        policies,
-        batch_size,
-        learning_rate,
-        ds,
-        toy_size,
-        max_epochs,
-        early_stop_num,
-        early_stop_flag,
-        average_validation,
-        iterations,
-        IsLeNet
-        ):
-    pass
-
-
-def generate_policies(
-        num_policies,
-        sp_num
-        ):
-    pass
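The q_plus_cnt update above is the standard UCB1 score: a policy's running mean accuracy q_values[i] plus the exploration bonus sqrt(2*ln(total_count)/cnts[i]), which grows for policies that have been tried less often. A minimal, self-contained sketch of the selection rule with made-up numbers (illustration only, not part of the commit):

import numpy as np

def ucb1_pick(q_values, cnts, total_count):
    # the bonus term favours under-explored arms (here: policies)
    bonuses = np.sqrt(2 * np.log(total_count) / np.array(cnts))
    return int(np.argmax(np.array(q_values) + bonuses))

q_values = [0.71, 0.69, 0.65]  # running mean validation accuracy per policy
cnts = [5, 3, 2]               # times each policy has been trained so far
print(ucb1_pick(q_values, cnts, sum(cnts)))
# prints 2: the least-tried policy wins despite its lowest mean accuracy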
test utility (parse_ds_cn_arch):

-from ..child_networks import *
-from ..main import create_toy, train_child_network
+from ..MetaAugment.child_networks import *
+from ..MetaAugment.main import create_toy, train_child_network
 import torch
 import torchvision.datasets as datasets
 import pickle

-def parse_ds_cn_arch(self, ds, ds_name, IsLeNet, transform):
+def parse_ds_cn_arch(ds, ds_name, IsLeNet, transform):
     # open data and apply these transformations
     if ds == "MNIST":
         train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
@@ -41,19 +41,14 @@ def parse_ds_cn_arch(self, ds, ds_name, IsLeNet, transform):
     num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()

-    # create model
-    if torch.cuda.is_available():
-        device='cuda'
-    else:
-        device='cpu'
-
     if IsLeNet == "LeNet":
-        model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = LeNet(img_height, img_width, num_labels, img_channels)
     elif IsLeNet == "EasyNet":
-        model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = EasyNet(img_height, img_width, num_labels, img_channels)
     elif IsLeNet == 'SimpleNet':
-        model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = SimpleNet(img_height, img_width, num_labels, img_channels)
     else:
         model = pickle.load(open(f'datasets/childnetwork', "rb"))

     return train_dataset, test_dataset, model
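For context, a hypothetical call to the refactored helper now that the self parameter is gone. The transform and the ds_name value are assumptions; judging from the signature, ds_name should only matter for datasets outside the built-in branches:

import torchvision.transforms as transforms

transform = transforms.ToTensor()
train_dataset, test_dataset, model = parse_ds_cn_arch(
    ds="MNIST",
    ds_name=None,      # assumed irrelevant for the built-in MNIST branch
    IsLeNet="LeNet",   # selects the LeNet child network
    transform=transform,
)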
test for ucb_learner:

 import MetaAugment.autoaugment_learners as aal
-import MetaAugment.child_networks as cn
-import torch
-import torchvision
-import torchvision.datasets as datasets
-import random

 def test_ucb_learner():
-    policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
-    q_values, best_q_values = UCB1_JC.run_UCB1(
-        policies,
-        batch_size,
-        learning_rate,
-        ds,
-        toy_size,
-        max_epochs,
-        early_stop_num,
-        iterations,
-        IsLeNet,
-        ds_name
-        )
-    best_q_values = np.array(best_q_values)
-    pass
+    learner = aal.ucb_learner(
+        # parameters that define the search space
+        sp_num=5,
+        p_bins=11,
+        m_bins=10,
+        discrete_p_m=True,
+        # hyperparameters for when training the child_network
+        batch_size=8,
+        toy_flag=False,
+        toy_size=0.1,
+        learning_rate=1e-1,
+        max_epochs=float('inf'),
+        early_stop_num=30,
+        # ucb_learner specific hyperparameter
+        num_policies=100
+        )
+    print(learner.policies)
+
+if __name__=="__main__":
+    test_ucb_learner()
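Putting the pieces together, a sketch of how the refactored learner might be driven end to end. The learn() signature and the best_q_values attribute come from the diff above; the dataset, the small num_policies, and passing the LeNet class as child_network_architecture are illustrative assumptions:

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import MetaAugment.child_networks as cn
import MetaAugment.autoaugment_learners as aal

train = datasets.MNIST(root='./datasets/mnist/train', train=True,
                       download=True, transform=transforms.ToTensor())
test = datasets.MNIST(root='./datasets/mnist/test', train=False,
                      download=True, transform=transforms.ToTensor())

# keep the policy pool small so the initial in-order sweep stays cheap
learner = aal.ucb_learner(sp_num=5, num_policies=10)
learner.learn(train, test,
              child_network_architecture=cn.LeNet,  # assumed: class, not an instance
              iterations=20)
print(learner.best_q_values)  # best running q_value recorded at each iteration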