diff --git a/benchmark/scripts/util_04_22.py b/benchmark/scripts/util_04_22.py index 344feef26b011f3fa350db420a240b038dbeb317..62c0456af78549bfaa0599a1a61c40a7eb78e806 100644 --- a/benchmark/scripts/util_04_22.py +++ b/benchmark/scripts/util_04_22.py @@ -1,9 +1,5 @@ -from matplotlib.pyplot import get -import torchvision.datasets as datasets -import torchvision import torch -import MetaAugment.child_networks as cn import MetaAugment.autoaugment_learners as aal import pprint diff --git a/benchmark/scripts/util_04_26.py b/benchmark/scripts/util_04_26.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e05f593c67bfff0002a7b2862a344b63645a89 --- /dev/null +++ b/benchmark/scripts/util_04_26.py @@ -0,0 +1,52 @@ +from cProfile import label +import torch +import matplotlib.pyplot as plt + +def get_best_acc( + save_file, + ): + """ + Use this to get the best accuracy history of the pickles + + Args: + save_file (str): pickle directory + + Returns: + list[floats]: best_accuracy_list + """ + # try to load agent + with open(save_file, 'rb') as f: + agent = torch.load(f) + history = agent.history + + best_acc_list = [] + best_acc = 0.0 + + for policy, acc in history: + best_acc = max(best_acc, acc) + best_acc_list.append(best_acc) + + return best_acc_list + +plt.plot(get_best_acc('benchmark/pickles/04_22_cf_ln_gru.pkl'), + label='GRU') +print('1 done') +plt.plot(get_best_acc('benchmark/pickles/04_22_cf_ln_rs.pkl'), + label='RandomSearch') +print('2 done') +plt.xlabel('no. of child networks trained') +plt.ylabel('highest accuracy obtained until now') +plt.legend() +plt.show() + + +plt.plot(get_best_acc('benchmark/pickles/04_22_fm_sn_gru.pkl'), + label='GRU') +print('3 done') +plt.plot(get_best_acc('benchmark/pickles/04_22_fm_sn_rs.pkl'), + label='RandomSearch') +print('4 done') +plt.xlabel('no. of child networks trained') +plt.ylabel('highest accuracy obtained until now') +plt.legend() +plt.show() \ No newline at end of file