diff --git a/plot_pickles.py b/plot_pickles.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0f2d7ca3bcd1c2e16a4a8ccfcd1fe21ee4bd69 --- /dev/null +++ b/plot_pickles.py @@ -0,0 +1,38 @@ +import pickle +from pprint import pprint +import matplotlib.pyplot as plt + +def get_maxacc(log): + output = [] + maxacc = 0 + for policy, acc in log: + maxacc = max(maxacc, acc) + output.append(maxacc) + return output + +with open('randomsearch_logs.pkl', 'rb') as file: + rs_list = pickle.load(file) + +with open('gru_logs.pkl', 'rb') as file: + gru_list = pickle.load(file) + + +plt.plot(get_maxacc(rs_list), label='randomsearcher') +plt.plot(get_maxacc(gru_list), label='gru learner') +plt.title('Comparing two agents') +plt.ylabel('best accuracy to date') +plt.xlabel('number of policies tested') +plt.legend() +plt.show() + + +def get_best5(log): + l = sorted(log, reverse=True, key=lambda x:x[1]) + return (l[:5]) + +def get_worst5(log): + l = sorted(log, key=lambda x:x[1]) + return (l[:5]) + +pprint(get_best5(rs_list)) +pprint(get_best5(gru_list)) \ No newline at end of file