diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index 48a8573dc4431ab07ecb318aa945a10e1ef2d38d..8ba8c93d7cfd6fb5fad91499e5a46faad7c1d91a 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -215,26 +215,26 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl # # In[9]: - -batch_size = 32 # size of batch the inner NN is trained with -learning_rate = 1e-1 # fix learning rate -ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) -toy_size = 0.02 # total propeortion of training and test set we use -max_epochs = 100 # max number of epochs that is run if early stopping is not hit -early_stop_num = 10 # max number of worse validation scores before early stopping is triggered -num_policies = 5 # fix number of policies -num_sub_policies = 5 # fix number of sub-policies in a policy -iterations = 100 # total iterations, should be more than the number of policies -IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet - -# generate random policies at start -policies = generate_policies(num_policies, num_sub_policies) - -q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet) - -plt.plot(best_q_values) - -best_q_values = np.array(best_q_values) -save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) -#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True) +if __name__=='__main__': + batch_size = 32 # size of batch the inner NN is trained with + learning_rate = 1e-1 # fix learning rate + ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) + toy_size = 0.02 # total propeortion of training and test set we use + max_epochs = 100 # max number of epochs that is run if early stopping is not hit + early_stop_num = 10 # max number of worse validation scores before early stopping is triggered + num_policies = 5 # fix number of policies + num_sub_policies = 5 # fix number of sub-policies in a policy + iterations = 100 # total iterations, should be more than the number of policies + IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet + + # generate random policies at start + policies = generate_policies(num_policies, num_sub_policies) + + q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet) + + plt.plot(best_q_values) + + best_q_values = np.array(best_q_values) + save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) + #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)