From 929f2dd097a44e12016b11729c11a022ddfd1618 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 8 Apr 2022 22:33:24 +0900
Subject: [PATCH] mnist 0.01 randomsearch vs gru graph data

---
 gru_learner.pkl => 0_01pkls/gru_learner.pkl         | Bin
 gru_logs.pkl => 0_01pkls/gru_logs.pkl               | Bin
 .../randomsearch_logs.pkl                           | Bin
 MetaAugment/autoaugment_learners/aa_learner.py      |   4 ++--
 plot_pickles.py                                     |   9 +++++++++
 5 files changed, 11 insertions(+), 2 deletions(-)
 rename gru_learner.pkl => 0_01pkls/gru_learner.pkl (100%)
 rename gru_logs.pkl => 0_01pkls/gru_logs.pkl (100%)
 rename randomsearch_logs.pkl => 0_01pkls/randomsearch_logs.pkl (100%)

diff --git a/gru_learner.pkl b/0_01pkls/gru_learner.pkl
similarity index 100%
rename from gru_learner.pkl
rename to 0_01pkls/gru_learner.pkl
diff --git a/gru_logs.pkl b/0_01pkls/gru_logs.pkl
similarity index 100%
rename from gru_logs.pkl
rename to 0_01pkls/gru_logs.pkl
diff --git a/randomsearch_logs.pkl b/0_01pkls/randomsearch_logs.pkl
similarity index 100%
rename from randomsearch_logs.pkl
rename to 0_01pkls/randomsearch_logs.pkl
diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 6e7874e9..909b05b9 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -233,7 +233,7 @@ class aa_learner:
         train_loader, test_loader = create_toy(train_dataset,
                                                 test_dataset,
                                                 batch_size=64,
-                                                n_samples=0.01,
+                                                n_samples=0.04,
                                                 seed=100)
 
         # train the child network with the dataloaders equipped with our specific policy
@@ -244,7 +244,7 @@ class aa_learner:
                                     # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                     cost = nn.CrossEntropyLoss(),
                                     max_epochs = 3000000, 
-                                    early_stop_num = 120, 
+                                    early_stop_num = 60, 
                                     logging = logging)
         
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
diff --git a/plot_pickles.py b/plot_pickles.py
index 7f0f2d7c..90462cfa 100644
--- a/plot_pickles.py
+++ b/plot_pickles.py
@@ -1,6 +1,7 @@
 import pickle
 from pprint import pprint
 import matplotlib.pyplot as plt
+from torch import gru
 
 def get_maxacc(log):
     output = []
@@ -25,6 +26,14 @@ plt.xlabel('number of policies tested')
 plt.legend()
 plt.show()
 
+plt.plot([acc for pol,acc in rs_list], label='randomsearcher')
+plt.plot([acc for pol,acc in gru_list], label='gru learner')
+plt.title('Comparing two agents')
+plt.ylabel('best accuracy to date')
+plt.xlabel('number of policies tested')
+plt.legend()
+plt.show()
+
 
 def get_best5(log):
     l = sorted(log, reverse=True, key=lambda x:x[1])
-- 
GitLab