From 929f2dd097a44e12016b11729c11a022ddfd1618 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Fri, 8 Apr 2022 22:33:24 +0900 Subject: [PATCH] mnist 0.01 randomsearch vs gru graph data --- gru_learner.pkl => 0_01pkls/gru_learner.pkl | Bin gru_logs.pkl => 0_01pkls/gru_logs.pkl | Bin .../randomsearch_logs.pkl | Bin MetaAugment/autoaugment_learners/aa_learner.py | 4 ++-- plot_pickles.py | 9 +++++++++ 5 files changed, 11 insertions(+), 2 deletions(-) rename gru_learner.pkl => 0_01pkls/gru_learner.pkl (100%) rename gru_logs.pkl => 0_01pkls/gru_logs.pkl (100%) rename randomsearch_logs.pkl => 0_01pkls/randomsearch_logs.pkl (100%) diff --git a/gru_learner.pkl b/0_01pkls/gru_learner.pkl similarity index 100% rename from gru_learner.pkl rename to 0_01pkls/gru_learner.pkl diff --git a/gru_logs.pkl b/0_01pkls/gru_logs.pkl similarity index 100% rename from gru_logs.pkl rename to 0_01pkls/gru_logs.pkl diff --git a/randomsearch_logs.pkl b/0_01pkls/randomsearch_logs.pkl similarity index 100% rename from randomsearch_logs.pkl rename to 0_01pkls/randomsearch_logs.pkl diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 6e7874e9..909b05b9 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -233,7 +233,7 @@ class aa_learner: train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=64, - n_samples=0.01, + n_samples=0.04, seed=100) # train the child network with the dataloaders equipped with our specific policy @@ -244,7 +244,7 @@ class aa_learner: # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2), cost = nn.CrossEntropyLoss(), max_epochs = 3000000, - early_stop_num = 120, + early_stop_num = 60, logging = logging) # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log) diff --git a/plot_pickles.py b/plot_pickles.py index 7f0f2d7c..90462cfa 100644 --- a/plot_pickles.py +++ b/plot_pickles.py @@ -1,6 +1,7 @@ import pickle from pprint import pprint import matplotlib.pyplot as plt +from torch import gru def get_maxacc(log): output = [] @@ -25,6 +26,14 @@ plt.xlabel('number of policies tested') plt.legend() plt.show() +plt.plot([acc for pol,acc in rs_list], label='randomsearcher') +plt.plot([acc for pol,acc in gru_list], label='gru learner') +plt.title('Comparing two agents') +plt.ylabel('best accuracy to date') +plt.xlabel('number of policies tested') +plt.legend() +plt.show() + def get_best5(log): l = sorted(log, reverse=True, key=lambda x:x[1]) -- GitLab