From 0dcb24dfc84ed08ee43ca03adebfd3775d76dc66 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Sat, 16 Apr 2022 17:16:48 +0100
Subject: [PATCH] add __name__ == '__main__' guard to UCB1_JC_py

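Wrap the top-level experiment driver in an if __name__ == '__main__'
guard so the module can be imported (for example, to reuse
generate_policies and run_UCB1) without kicking off a full training
run as a side effect of the import.
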
---
 MetaAugment/UCB1_JC_py.py | 45 +++++++++++++++++++++++----------------------
 1 file changed, 23 insertions(+), 22 deletions(-)
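
With the guard in place the module can be imported without triggering a
run. A minimal usage sketch, assuming the signatures shown in the hunk
below:

    from MetaAugment.UCB1_JC_py import generate_policies, run_UCB1

    # five policies of five sub-policies each, as in the defaults below
    policies = generate_policies(5, 5)
    q_values, best_q_values = run_UCB1(
        policies, 32, 1e-1, "MNIST", 0.02, 100, 10, 100, "SimpleNet")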

diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index 48a8573d..8ba8c93d 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -215,26 +215,27 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
 
 # # In[9]:
 
-
-batch_size = 32       # size of batch the inner NN is trained with
-learning_rate = 1e-1  # fix learning rate
-ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
-toy_size = 0.02       # total propeortion of training and test set we use
-max_epochs = 100      # max number of epochs that is run if early stopping is not hit
-early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
-num_policies = 5      # fix number of policies
-num_sub_policies = 5  # fix number of sub-policies in a policy
-iterations = 100      # total iterations, should be more than the number of policies
-IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
-
-# generate random policies at start
-policies = generate_policies(num_policies, num_sub_policies)
-
-q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
-
-plt.plot(best_q_values)
-
-best_q_values = np.array(best_q_values)
-save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
-#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
+if __name__ == '__main__':
+    batch_size = 32       # size of batch the inner NN is trained with
+    learning_rate = 1e-1  # fixed learning rate
+    ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
+    toy_size = 0.02       # total proportion of the training and test sets we use
+    max_epochs = 100      # max number of epochs to run if early stopping is not triggered
+    early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+    num_policies = 5      # fixed number of policies
+    num_sub_policies = 5  # fixed number of sub-policies in a policy
+    iterations = 100      # total iterations; should be more than the number of policies
+    IsLeNet = "SimpleNet" # choose the inner NN: LeNet, EasyNet or SimpleNet
+
+    # generate random policies at start
+    policies = generate_policies(num_policies, num_sub_policies)
+
+    q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
+
+    plt.plot(best_q_values)
+
+    best_q_values = np.array(best_q_values)
+    save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
+    #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
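+    plt.show()            # added: display the plot when run as a script; outside a notebook nothing renders without this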
 
-- 
GitLab