Skip to content
Snippets Groups Projects
Commit 929f2dd0 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

mnist 0.01 randomsearch vs gru graph data

parent a26ed543
No related branches found
No related tags found
No related merge requests found
File moved
File moved
File moved
......@@ -233,7 +233,7 @@ class aa_learner:
train_loader, test_loader = create_toy(train_dataset,
test_dataset,
batch_size=64,
n_samples=0.01,
n_samples=0.04,
seed=100)
# train the child network with the dataloaders equipped with our specific policy
......@@ -244,7 +244,7 @@ class aa_learner:
# sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
cost = nn.CrossEntropyLoss(),
max_epochs = 3000000,
early_stop_num = 120,
early_stop_num = 60,
logging = logging)
# if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
......
import pickle
from pprint import pprint
import matplotlib.pyplot as plt
from torch import gru
def get_maxacc(log):
output = []
......@@ -25,6 +26,14 @@ plt.xlabel('number of policies tested')
plt.legend()
plt.show()
plt.plot([acc for pol,acc in rs_list], label='randomsearcher')
plt.plot([acc for pol,acc in gru_list], label='gru learner')
plt.title('Comparing two agents')
plt.ylabel('best accuracy to date')
plt.xlabel('number of policies tested')
plt.legend()
plt.show()
def get_best5(log):
l = sorted(log, reverse=True, key=lambda x:x[1])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment