Commit 5ecdfcec authored by Sun Jin Kim

continue refactoring ucb

parent e583f2e5
(first changed file — the autoaugment_learners package __init__; filename not shown in this view)

 from .aa_learner import *
 from .randomsearch_learner import *
 from .gru_learner import *
-from .evo_learner import *
\ No newline at end of file
+from .evo_learner import *
+from .ucb_learner import *
\ No newline at end of file
(next file — defines class ucb_learner; filename not shown in this view)

@@ -41,47 +41,60 @@ class ucb_learner(randomsearch_learner):
              num_policies=100
              ):
-        super().__init__(sp_num,
-                        fun_num,
-                        p_bins,
-                        m_bins,
-                        discrete_p_m=discrete_p_m,
-                        batch_size=batch_size,
-                        toy_flag=toy_flag,
-                        toy_size=toy_size,
-                        learning_rate=learning_rate,
-                        max_epochs=max_epochs,
-                        early_stop_num=early_stop_num,)
+        super().__init__(sp_num=sp_num,
+                        fun_num=14,
+                        p_bins=p_bins,
+                        m_bins=m_bins,
+                        discrete_p_m=discrete_p_m,
+                        batch_size=batch_size,
+                        toy_flag=toy_flag,
+                        toy_size=toy_size,
+                        learning_rate=learning_rate,
+                        max_epochs=max_epochs,
+                        early_stop_num=early_stop_num,)
         self.num_policies = num_policies

-        # When this learner is initialized we generate `num_policies` number
-        # of random policies.
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = [self.generate_new_policy() for _ in self.num_policies]
+        self.policies = []
+        self.make_more_policies(num_policies)

+        # attributes used in the UCB1 algorithm
+        self.q_values = [0]*self.num_policies
+        self.best_q_values = []
+        self.cnts = [0]*self.num_policies
+        self.q_plus_cnt = [0]*self.num_policies
+        self.total_count = 0
+    def make_more_policies(self, n):
+        """generates n more random policies and adds them to self.policies
+
+        Args:
+            n (int): how many more policies we want to randomly generate
+                    and add to our list of policies
+        """
+        self.policies.extend([self.generate_new_policy() for _ in range(n)])
     def learn(self,
             train_dataset,
             test_dataset,
             child_network_architecture,
             iterations=15):
-        # Initialize vector weights, counts and regret
-        best_q_values = []
-
         for this_iter in trange(iterations):

             # get the action to try (either initially in order or using best q_plus_cnt value)
+            # TODO: change this if statement
             if this_iter >= self.num_policies:
-                this_policy = self.policies[np.argmax(self.q_plus_cnt)]
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
             else:
-                this_policy = this_iter
+                this_policy_idx = this_iter
+                this_policy = self.policies[this_iter]
@@ -95,13 +108,14 @@ class ucb_learner(randomsearch_learner):
                 )

             # update q_values
+            # TODO: change this if statement
             if this_iter < self.num_policies:
-                self.q_values[this_policy] += best_acc
+                self.q_values[this_policy_idx] += best_acc
             else:
-                self.q_values[this_policy] = (self.q_values[this_policy]*self.cnts[this_policy] + best_acc) / (self.cnts[this_policy] + 1)
+                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
             best_q_value = max(self.q_values)
-            best_q_values.append(best_q_value)
+            self.best_q_values.append(best_q_value)

             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
@@ -112,41 +126,20 @@
                 )

             # update counts
-            self.cnts[this_policy] += 1
+            self.cnts[this_policy_idx] += 1
             self.total_count += 1

             # update q_plus_cnt values every turn after the initial sweep through
+            # TODO: change this if statement
             if this_iter >= self.num_policies - 1:
                 for i in range(self.num_policies):
                     self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])

-            # yield q_values, best_q_values
-        return self.q_values, best_q_values
-    def run_UCB1(
-            policies,
-            batch_size,
-            learning_rate,
-            ds,
-            toy_size,
-            max_epochs,
-            early_stop_num,
-            early_stop_flag,
-            average_validation,
-            iterations,
-            IsLeNet
-            ):
-        pass
-
-    def generate_policies(
-            num_policies,
-            self.sp_num
-            ):
-        pass
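The exploration term in the loop above is the standard UCB1 bonus: each policy's score is its running mean reward plus sqrt(2*ln(total_count)/cnts[i]). Below is a minimal, self-contained sketch of the same select-and-update cycle; the policy count and the noisy stand-in reward are illustrative, not taken from the repo, where the reward would be the trained child network's validation accuracy.

import numpy as np

rng = np.random.default_rng(0)
num_policies = 5                                 # illustrative; the learner defaults to 100
true_acc = rng.uniform(0.5, 1.0, num_policies)   # stand-in for each policy's real accuracy
q_values = np.zeros(num_policies)                # running mean reward per policy
cnts = np.zeros(num_policies)                    # times each policy has been tried
total_count = 0

for this_iter in range(100):
    if this_iter < num_policies:
        idx = this_iter                          # initial sweep: try every policy once
    else:
        # UCB1 score: mean reward plus exploration bonus
        q_plus_cnt = q_values + np.sqrt(2 * np.log(total_count) / cnts)
        idx = int(np.argmax(q_plus_cnt))
    best_acc = true_acc[idx] + rng.normal(0, 0.05)   # noisy reward
    # incremental mean update, as in ucb_learner.learn
    q_values[idx] = (q_values[idx] * cnts[idx] + best_acc) / (cnts[idx] + 1)
    cnts[idx] += 1
    total_count += 1

print(int(np.argmax(q_values)), int(np.argmax(true_acc)))  # these usually agree

The initial in-order sweep matters: the bonus divides by cnts[i], so every arm has to be pulled once before the UCB score is even defined.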
(next file — defines parse_ds_cn_arch; filename not shown in this view)
-from ..child_networks import *
-from ..main import create_toy, train_child_network
+from ..MetaAugment.child_networks import *
+from ..MetaAugment.main import create_toy, train_child_network
 import torch
 import torchvision.datasets as datasets
 import pickle

-def parse_ds_cn_arch(self, ds, ds_name, IsLeNet, transform):
+def parse_ds_cn_arch(ds, ds_name, IsLeNet, transform):
     # open data and apply these transformations
     if ds == "MNIST":
         train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)
@@ -41,19 +41,14 @@ def parse_ds_cn_arch(self, ds, ds_name, IsLeNet, transform):
     num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()

     # create model
-    if torch.cuda.is_available():
-        device='cuda'
-    else:
-        device='cpu'
-
     if IsLeNet == "LeNet":
-        model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = LeNet(img_height, img_width, num_labels, img_channels)
     elif IsLeNet == "EasyNet":
-        model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = EasyNet(img_height, img_width, num_labels, img_channels)
     elif IsLeNet == 'SimpleNet':
-        model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
+        model = SimpleNet(img_height, img_width, num_labels, img_channels)
     else:
         model = pickle.load(open(f'datasets/childnetwork', "rb"))

-    return train_dataset, test_dataset, model
\ No newline at end of file
+    return train_dataset, test_dataset, model
\ No newline at end of file
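Dropping the .to(device) calls (and the now-dead device computation) means parse_ds_cn_arch returns a CPU model and leaves device placement to the caller. A minimal sketch of that idiom; the nn.Linear stand-in is illustrative, not one of the repo's child networks:

import torch
import torch.nn as nn

model = nn.Linear(28 * 28, 10)   # stand-in for LeNet / EasyNet / SimpleNet
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)         # the caller decides placement, once, at the call site

This keeps the helper free of hardware assumptions and lets callers place the model on whatever device their training loop actually uses.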
(next file — a test for ucb_learner; filename not shown in this view)

 import MetaAugment.autoaugment_learners as aal
 import MetaAugment.child_networks as cn
 import torch
 import torchvision
 import torchvision.datasets as datasets
 import random
 def test_ucb_learner():
-    policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
-    q_values, best_q_values = UCB1_JC.run_UCB1(
-                                        policies,
-                                        batch_size,
-                                        learning_rate,
-                                        ds,
-                                        toy_size,
-                                        max_epochs,
-                                        early_stop_num,
-                                        iterations,
-                                        IsLeNet,
-                                        ds_name
-                                        )
-    best_q_values = np.array(best_q_values)
-    pass
\ No newline at end of file
+    learner = aal.ucb_learner(
+                # parameters that define the search space
+                sp_num=5,
+                p_bins=11,
+                m_bins=10,
+                discrete_p_m=True,
+                # hyperparameters for when training the child_network
+                batch_size=8,
+                toy_flag=False,
+                toy_size=0.1,
+                learning_rate=1e-1,
+                max_epochs=float('inf'),
+                early_stop_num=30,
+                # ucb_learner specific hyperparameter
+                num_policies=100
+                )
+    print(learner.policies)
+
+if __name__=="__main__":
+    test_ucb_learner()
\ No newline at end of file
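As written, the test only smoke-tests construction and eyeballs learner.policies via print. A few assertions on the invariants the constructor establishes would make it self-checking; an illustrative addition (not part of this commit), to go at the end of test_ucb_learner:

    assert len(learner.policies) == 100          # make_more_policies ran at init
    assert learner.q_values == [0] * learner.num_policies
    assert learner.cnts == [0] * learner.num_policies
    assert learner.total_count == 0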