From 5221e2a0c2591f90dafe84f759eab7c027243cc9 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 8 Apr 2022 21:58:04 +0900
Subject: [PATCH] Add plot_pickles, for the presentation

---
 plot_pickles.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)
 create mode 100644 plot_pickles.py

diff --git a/plot_pickles.py b/plot_pickles.py
new file mode 100644
index 00000000..7f0f2d7c
--- /dev/null
+++ b/plot_pickles.py
@@ -0,0 +1,38 @@
+import pickle
+from pprint import pprint
+import matplotlib.pyplot as plt
+
+def get_maxacc(log):
+    output = []
+    maxacc = 0
+    for policy, acc in log:
+        maxacc = max(maxacc, acc)
+        output.append(maxacc)
+    return output
+
+with open('randomsearch_logs.pkl', 'rb') as file:
+    rs_list = pickle.load(file)
+
+with open('gru_logs.pkl', 'rb') as file:
+    gru_list = pickle.load(file)
+
+
+plt.plot(get_maxacc(rs_list), label='randomsearcher')
+plt.plot(get_maxacc(gru_list), label='gru learner')
+plt.title('Comparing two agents')
+plt.ylabel('best accuracy to date')
+plt.xlabel('number of policies tested')
+plt.legend()
+plt.show()
+
+
+def get_best5(log):
+    l = sorted(log, reverse=True, key=lambda x:x[1])
+    return (l[:5])
+
+def get_worst5(log):
+    l = sorted(log, key=lambda x:x[1])
+    return (l[:5])
+
+pprint(get_best5(rs_list))
+pprint(get_best5(gru_list))
\ No newline at end of file
-- 
GitLab