diff --git a/benchmark/scripts/04_22_ci_gru.py b/benchmark/scripts/04_22_ci_gru.py
index 5c4db6bd5ce8f713f8dbc6ea829176454e4ce28a..1a5b0fadca16fbaabe583b3a58abf0ae9d87c4db 100644
--- a/benchmark/scripts/04_22_ci_gru.py
+++ b/benchmark/scripts/04_22_ci_gru.py
@@ -35,12 +35,23 @@ child_network_architecture = cn.LeNet(
                                     )
 
 
-# gru
+save_dir='./benchmark/pickles/04_22_cf_ln_gru'
+
+# gru
 run_benchmark(
-    save_file='./benchmark/pickles/04_22_cf_ln_gru',
+    save_file=save_dir+'.pkl',
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
     agent_arch=aal.gru_learner,
     config=config,
+    )
+
+rerun_best_policy(
+    agent_pickle=save_dir+'.pkl',
+    accs_txt=save_dir+'.txt',
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    repeat_num=5
     )
\ No newline at end of file
diff --git a/benchmark/scripts/04_22_ci_rs.py b/benchmark/scripts/04_22_ci_rs.py
index b98c25fb7826918cfa4f1e6cdb5dc1484cd9c662..21f3a9a3e65eb8e7126ec32641c48726b2f4172c 100644
--- a/benchmark/scripts/04_22_ci_rs.py
+++ b/benchmark/scripts/04_22_ci_rs.py
@@ -34,13 +34,23 @@ child_network_architecture = cn.LeNet(
                                     img_channels=3
                                     )
 
+save_dir='./benchmark/pickles/04_22_cf_ln_rs'
 
 # rs
 run_benchmark(
-    save_file='./benchmark/pickles/04_22_cf_ln_rs',
+    save_file=save_dir+'.pkl',
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
     agent_arch=aal.randomsearch_learner,
     config=config,
     )
+
+rerun_best_policy(
+    agent_pickle=save_dir+'.pkl',
+    accs_txt=save_dir+'.txt',
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    repeat_num=5
+    )
\ No newline at end of file
diff --git a/benchmark/scripts/04_22_fm_gru.py b/benchmark/scripts/04_22_fm_gru.py
index 227918517fef504b8e5b27aaab354c3c2366e1c8..b3a951c0afd3eeb0cd8911a30143f08a61c6e5e4 100644
--- a/benchmark/scripts/04_22_fm_gru.py
+++ b/benchmark/scripts/04_22_fm_gru.py
@@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
 child_network_architecture = cn.SimpleNet
 
 
-# gru
+save_dir='./benchmark/pickles/04_22_fm_sn_gru'
+
+# gru
 run_benchmark(
-    save_file='./benchmark/pickles/04_22_fm_sn_gru.pkl',
+    save_file=save_dir+'.pkl',
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
     agent_arch=aal.gru_learner,
     config=config,
+    )
+
+rerun_best_policy(
+    agent_pickle=save_dir+'.pkl',
+    accs_txt=save_dir+'.txt',
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    repeat_num=5
     )
\ No newline at end of file
diff --git a/benchmark/scripts/04_22_fm_rs.py b/benchmark/scripts/04_22_fm_rs.py
index 33a4b7b26cd4211bdc7b241c0370fc5b5f1abe59..0589630f3906fedfeca72326b1d77fdf9332d5b9 100644
--- a/benchmark/scripts/04_22_fm_rs.py
+++ b/benchmark/scripts/04_22_fm_rs.py
@@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
 child_network_architecture = cn.SimpleNet
 
 
+save_dir='./benchmark/pickles/04_22_fm_sn_rs'
+
 # rs
 run_benchmark(
-    save_file='./benchmark/pickles/04_22_fm_sn_rs.pkl',
+    save_file=save_dir+'.pkl',
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
     agent_arch=aal.randomsearch_learner,
     config=config,
+    )
+
+rerun_best_policy(
+    agent_pickle=save_dir+'.pkl',
+    accs_txt=save_dir+'.txt',
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    repeat_num=5
     )
\ No newline at end of file
diff --git a/benchmark/scripts/util_04_22.py b/benchmark/scripts/util_04_22.py
index 8e39fdd69cef06f52e48d92136e8f617b85dfaf8..86b033ef65efa96782e809136f2793ebaad6b044 100644
--- a/benchmark/scripts/util_04_22.py
+++ b/benchmark/scripts/util_04_22.py
@@ -1,3 +1,4 @@
+from matplotlib.pyplot import get
 import torchvision.datasets as datasets
 import torchvision
 import torch
@@ -5,7 +6,7 @@ import torch
 import MetaAugment.child_networks as cn
 import MetaAugment.autoaugment_learners as aal
 
-
+from pprint import pprint
 """
 testing gru_learner and randomsearch_learner on
 
@@ -56,4 +57,61 @@ def run_benchmark(
     with open(save_file, 'wb+') as f:
         torch.save(agent, f)
 
-    print('run_benchmark closing')
\ No newline at end of file
+    print('run_benchmark closing')
+
+
+def get_mega_policy(history, n=3):
+    """
+    we get the best n policies from an agent's history and
+    concatenate them to form our best mega policy
+
+    Args:
+        history (list[tuple]): list of (policy, val_accuracy) tuples
+        n (int): number of best policies to concatenate (defaults to 3)
+
+    Returns:
+        list: the mega policy, i.e. the subpolicies of the best n policies
+    """
+    assert len(history) >= n
+
+    # agent.history is a list of (policy(list), val_accuracy(float)) tuples
+    sorted_history = sorted(history, key=lambda x:x[1], reverse=True) # sort wrt acc, best first
+
+    best_history = sorted_history[:n]
+
+    megapolicy = []
+    for policy, acc in best_history:
+        for subpolicy in policy:
+            megapolicy.append(subpolicy)
+
+    return megapolicy
+
+
+def rerun_best_policy(
+    agent_pickle,
+    accs_txt,
+    train_dataset,
+    test_dataset,
+    child_network_architecture,
+    repeat_num
+    ):
+
+    with open(agent_pickle, 'rb') as f:
+        agent = torch.load(f, map_location=device)
+
+    megapol = get_mega_policy(agent.history)
+    print('mega policy to be tested:')
+    pprint(megapol)
+
+    accs = []
+    for rep in range(repeat_num):
+        print(f'{rep}/{repeat_num}')
+        accs.append(
+            agent.test_autoaugment_policy(megapol,
+                                          child_network_architecture,
+                                          train_dataset,
+                                          test_dataset,
+                                          logging=False)
+            )
+    with open(accs_txt, 'w') as f:
+        f.write(str(accs))
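
Note: rerun_best_policy writes its repeated-run results by dumping str(accs) into the .txt file next to each pickle. If a quick summary of those files is wanted later, a minimal sketch along these lines could work; summarise_accs is a hypothetical helper, not part of this diff, and it assumes test_autoaugment_policy returns a single float accuracy per run.

import ast
import statistics

def summarise_accs(accs_txt):
    # the .txt file holds a Python-literal list, e.g. "[0.91, 0.9, 0.92, 0.89, 0.9]"
    with open(accs_txt) as f:
        accs = ast.literal_eval(f.read())
    # mean and sample standard deviation over the repeat_num reruns
    return statistics.mean(accs), statistics.stdev(accs)

# example: mean_acc, std_acc = summarise_accs('./benchmark/pickles/04_22_fm_sn_rs.txt')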