From 0dcb24dfc84ed08ee43ca03adebfd3775d76dc66 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Sat, 16 Apr 2022 17:16:48 +0100 Subject: [PATCH] add __name__ to UCB1_JC_py --- MetaAugment/UCB1_JC_py.py | 44 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index 48a8573d..8ba8c93d 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -215,26 +215,26 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl # # In[9]: - -batch_size = 32 # size of batch the inner NN is trained with -learning_rate = 1e-1 # fix learning rate -ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) -toy_size = 0.02 # total propeortion of training and test set we use -max_epochs = 100 # max number of epochs that is run if early stopping is not hit -early_stop_num = 10 # max number of worse validation scores before early stopping is triggered -num_policies = 5 # fix number of policies -num_sub_policies = 5 # fix number of sub-policies in a policy -iterations = 100 # total iterations, should be more than the number of policies -IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet - -# generate random policies at start -policies = generate_policies(num_policies, num_sub_policies) - -q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet) - -plt.plot(best_q_values) - -best_q_values = np.array(best_q_values) -save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) -#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True) +if __name__=='__main__': + batch_size = 32 # size of batch the inner NN is trained with + learning_rate = 1e-1 # fix learning rate + ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) + toy_size = 0.02 # total propeortion of training and test set we use + max_epochs = 100 # max number of epochs that is run if early stopping is not hit + early_stop_num = 10 # max number of worse validation scores before early stopping is triggered + num_policies = 5 # fix number of policies + num_sub_policies = 5 # fix number of sub-policies in a policy + iterations = 100 # total iterations, should be more than the number of policies + IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet + + # generate random policies at start + policies = generate_policies(num_policies, num_sub_policies) + + q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet) + + plt.plot(best_q_values) + + best_q_values = np.array(best_q_values) + save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) + #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True) -- GitLab