Commit cda1f7fb authored by Sun Jin Kim

refactor further UCB1_JC_py

parent d646941f
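This commit continues the refactor of UCB1_JC.py: the inline child-network training loop in run_UCB1 is replaced by a call to the shared train_child_network helper in MetaAugment.main, the example driver block at the bottom of the script is enabled, and train_child_network's return path is simplified.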
UCB1_JC.py:

@@ -21,7 +21,7 @@ from numpy import save, load
 from tqdm import trange
 from MetaAugment.child_networks import *
-from MetaAugment.main import create_toy
+from MetaAugment.main import create_toy, train_child_network

 # In[6]:

@@ -184,46 +184,9 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):
     sgd = optim.SGD(model.parameters(), lr=1e-1)
     cost = nn.CrossEntropyLoss()

-    # set variables for best validation accuracy and early stop count
-    best_acc = 0
-    early_stop_cnt = 0
-
-    # train model and check validation accuracy each epoch
-    for _epoch in range(max_epochs):
-        # train model
-        model.train()
-        for idx, (train_x, train_label) in enumerate(train_loader):
-            label_np = np.zeros((train_label.shape[0], num_labels))
-            sgd.zero_grad()
-            predict_y = model(train_x.float())
-            loss = cost(predict_y, train_label.long())
-            loss.backward()
-            sgd.step()
-
-        # check validation accuracy on validation set
-        correct = 0
-        _sum = 0
-        model.eval()
-        for idx, (test_x, test_label) in enumerate(test_loader):
-            predict_y = model(test_x.float()).detach()
-            predict_ys = np.argmax(predict_y, axis=-1)
-            label_np = test_label.numpy()
-            _ = predict_ys == test_label
-            correct += np.sum(_.numpy(), axis=-1)
-            _sum += _.shape[0]
-
-        # update best validation accuracy if it was higher, otherwise increase early stop count
-        acc = correct / _sum
-        if acc > best_acc:
-            best_acc = acc
-            early_stop_cnt = 0
-        else:
-            early_stop_cnt += 1
-
-        # exit if validation accuracy gets worse over early_stop_num epochs
-        if early_stop_cnt >= early_stop_num:
-            break
+    best_acc = train_child_network(model, train_loader, test_loader, sgd,
+                                   cost, max_epochs, early_stop_num, logging=False,
+                                   print_every_epoch=False)

     # update q_values
     if policy < num_policies:
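The deleted loop above is exactly what the new helper factors out. For reference, here is a minimal sketch of what train_child_network in MetaAugment.main plausibly looks like, reconstructed from the removed code and the new call site; the keyword defaults and the exact logging behaviour are assumptions, and main.py remains authoritative.

import torch

def train_child_network(child_network, train_loader, test_loader, sgd, cost,
                        max_epochs, early_stop_num, logging=False,
                        print_every_epoch=True):
    best_acc = torch.tensor(0.0)
    early_stop_cnt = 0
    acc_log = []
    for _epoch in range(max_epochs):
        # train for one epoch
        child_network.train()
        for train_x, train_label in train_loader:
            sgd.zero_grad()
            loss = cost(child_network(train_x.float()), train_label.long())
            loss.backward()
            sgd.step()

        # measure accuracy on the validation set
        child_network.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for test_x, test_label in test_loader:
                predict_ys = child_network(test_x.float()).argmax(dim=-1)
                correct += (predict_ys == test_label).sum()
                total += test_label.shape[0]
        acc = correct / total
        acc_log.append(acc.item())
        if print_every_epoch:
            print('epoch {}: validation accuracy {:.4f}'.format(_epoch, acc.item()))

        # keep the best accuracy; count epochs without improvement
        if acc > best_acc:
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1
        # stop early after early_stop_num epochs without improvement
        if early_stop_cnt >= early_stop_num:
            break

    if logging:
        return best_acc.item(), acc_log
    return best_acc.item()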
Also in UCB1_JC.py, the example driver block is uncommented so the script runs the search directly:

@@ -253,25 +216,25 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):
 # # In[9]:
-# batch_size = 32       # size of batch the inner NN is trained with
-# learning_rate = 1e-1  # fix learning rate
-# ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
-# toy_size = 0.02       # total proportion of training and test set we use
-# max_epochs = 100      # max number of epochs that is run if early stopping is not hit
-# early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
-# num_policies = 5      # fix number of policies
-# num_sub_policies = 5  # fix number of sub-policies in a policy
-# iterations = 100      # total iterations, should be more than the number of policies
-# IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
-# # generate random policies at start
-# policies = generate_policies(num_policies, num_sub_policies)
-# q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
-# plt.plot(best_q_values)
-# best_q_values = np.array(best_q_values)
-# save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
-# #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
+batch_size = 32       # size of batch the inner NN is trained with
+learning_rate = 1e-1  # fix learning rate
+ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
+toy_size = 0.02       # total proportion of training and test set we use
+max_epochs = 100      # max number of epochs that is run if early stopping is not hit
+early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+num_policies = 5      # fix number of policies
+num_sub_policies = 5  # fix number of sub-policies in a policy
+iterations = 100      # total iterations, should be more than the number of policies
+IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
+# generate random policies at start
+policies = generate_policies(num_policies, num_sub_policies)
+q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
+plt.plot(best_q_values)
+best_q_values = np.array(best_q_values)
+save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
+#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
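For context, run_UCB1 drives the policy search with the UCB1 bandit rule that gives the script its name. The selection step sits outside the hunks shown above; the textbook rule it implements looks like the following sketch, where q_values (mean reward per policy) and cnts (visit counts) are assumed names.

import numpy as np

def ucb1_select(q_values, cnts):
    # try every policy once before applying the confidence bound
    if np.any(cnts == 0):
        return int(np.argmax(cnts == 0))
    # exploitation term plus an exploration bonus that shrinks with visits
    total = np.sum(cnts)
    ucb = q_values + np.sqrt(2 * np.log(total) / cnts)
    return int(np.argmax(ucb))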
MetaAugment/child_networks/__init__.py:

from .lenet import *
from .bad_lenet import *
Child-network definitions (re-exported by MetaAugment.child_networks):

import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # each 5x5 conv (no padding) shrinks height/width by 4; each pool halves them
        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, num_labels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
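For the default 28x28 MNIST-sized input, the fc1 size formula above works out to ((28-4)/2-4)/2 = 4 per spatial dimension, so fc1 receives 4*4*16 = 256 features. A quick shape check with a dummy input:

import torch

model = LeNet()
dummy = torch.zeros(1, 1, 28, 28)  # batch of one 28x28 single-channel image
print(model(dummy).shape)          # torch.Size([1, 10])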
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
super().__init__()
self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(2048, num_labels)
self.relu2 = nn.ReLU()
def forward(self, x):
y = x.view(x.shape[0], -1)
y = self.fc1(y)
y = self.relu1(y)
y = self.fc2(y)
y = self.relu2(y)
return y
"""Define internal NN module that trains on the dataset"""
class SimpleNet(nn.Module):
def __init__(self, img_height=28, img_width=28, num_labels=10, img_channels=1):
super().__init__()
self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)
self.relu1 = nn.ReLU()
def forward(self, x):
y = x.view(x.shape[0], -1)
y = self.fc1(y)
y = self.relu1(y)
return y
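run_UCB1 receives the architecture choice as the IsLeNet string (see the driver block above). The dispatch itself is not part of this diff; a hypothetical sketch of it, with build_child_network as an illustrative name:

def build_child_network(IsLeNet, img_height=28, img_width=28, num_labels=10, img_channels=1):
    # map the IsLeNet string onto one of the classes defined above;
    # run_UCB1's real dispatch may differ
    nets = {'LeNet': LeNet, 'EasyNet': EasyNet, 'SimpleNet': SimpleNet}
    return nets[IsLeNet](img_height, img_width, num_labels, img_channels)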
MetaAugment/main.py:

@@ -106,9 +106,8 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
     if logging:
         return best_acc.item(), acc_log
-    else:
-        print('main.train_child_network best accuracy: ', best_acc)
-        return best_acc.item()
+    return best_acc.item()
 if __name__=='__main__':
     import MetaAugment.child_networks as cn
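Note that run_UCB1 calls the helper with logging=False, so it receives only the best validation accuracy; passing logging=True additionally returns the per-epoch accuracy log, as the main.py hunk above shows.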