diff --git a/benchmark/pickles/04_22_cf_ln_gru.txt b/benchmark/pickles/04_22_cf_ln_gru.txt new file mode 100644 index 0000000000000000000000000000000000000000..b77d909a632ea76aa98e70af053eb2e09409df4d --- /dev/null +++ b/benchmark/pickles/04_22_cf_ln_gru.txt @@ -0,0 +1,9 @@ +[(('ShearY', 0.2, 5), ('Rotate', 0.6, 6)), + (('TranslateX', 0.8, 3), ('Posterize', 0.1, 3)), + (('TranslateY', 0.0, 8), ('Equalize', 0.7, None)), + (('Equalize', 0.3, None), ('Contrast', 0.2, 0)), + (('ShearX', 0.4, 5), ('Contrast', 0.2, 8)), + (('TranslateX', 0.9, 3), ('Solarize', 0.4, 5)), + (('Color', 0.2, 4), ('Solarize', 0.6, 8)), + (('ShearX', 0.1, 8), ('Equalize', 0.4, None)), + (('Posterize', 0.7, 5), ('Solarize', 1.0, 4))][0.6056999564170837, 0.6329999566078186, 0.6171000003814697, 0.62909996509552]original small policys accuracies: [0.6236000061035156, 0.6187999844551086, 0.617900013923645] \ No newline at end of file diff --git a/benchmark/pickles/04_22_cf_ln_rs.pkl b/benchmark/pickles/04_22_cf_ln_rs.pkl index c25ff45ad7003d1fe589fc22525cb16b7df1b644..ae6e66cfd5f3c4e4c6940293df987eba7aaaf6d3 100644 Binary files a/benchmark/pickles/04_22_cf_ln_rs.pkl and b/benchmark/pickles/04_22_cf_ln_rs.pkl differ diff --git a/benchmark/pickles/04_22_fm_sn_gru.txt b/benchmark/pickles/04_22_fm_sn_gru.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c00a102009ae63e25ae1437b8bab3b004a890f0 --- /dev/null +++ b/benchmark/pickles/04_22_fm_sn_gru.txt @@ -0,0 +1,9 @@ +[(('Color', 0.9, 3), ('Contrast', 0.8, 3)), + (('Sharpness', 0.9, 0), ('Solarize', 0.3, 7)), + (('Color', 0.0, 6), ('Solarize', 0.4, 3)), + (('Brightness', 0.1, 3), ('Brightness', 0.5, 9)), + (('Solarize', 0.9, 6), ('Rotate', 0.6, 1)), + (('Contrast', 0.7, 3), ('Posterize', 0.9, 4)), + (('Solarize', 0.6, 2), ('Contrast', 0.5, 6)), + (('TranslateX', 0.0, 4), ('AutoContrast', 0.3, None)), + (('Equalize', 0.0, None), ('Brightness', 0.8, 1))][0.7490999698638916, 0.8359999656677246, 0.8394999504089355]original small policys accuracies: [0.8380999565124512, 0.8376999497413635, 0.8376999497413635] \ No newline at end of file diff --git a/benchmark/pickles/04_22_fm_sn_rs.pkl b/benchmark/pickles/04_22_fm_sn_rs.pkl index 452e2fb5458f750f6cc475e2f20c40ff138ad742..ada525c3a479b7652cc5c4551ea7bb637063a6f5 100644 Binary files a/benchmark/pickles/04_22_fm_sn_rs.pkl and b/benchmark/pickles/04_22_fm_sn_rs.pkl differ diff --git a/benchmark/pickles/04_22_fm_sn_rs.txt b/benchmark/pickles/04_22_fm_sn_rs.txt new file mode 100644 index 0000000000000000000000000000000000000000..90b89260b6f0db6578f7310d0f672fd6748f6197 --- /dev/null +++ b/benchmark/pickles/04_22_fm_sn_rs.txt @@ -0,0 +1,9 @@ +[(('ShearX', 1.0, 0), ('Color', 0.3, 2)), + (('AutoContrast', 0.0, None), ('Brightness', 0.7, 2)), + (('Invert', 0.1, None), ('Contrast', 0.1, 6)), + (('Solarize', 0.4, 2), ('Contrast', 0.9, 2)), + (('Equalize', 0.0, None), ('Contrast', 0.0, 2)), + (('Rotate', 0.4, 0), ('Posterize', 0.5, 9)), + (('Posterize', 0.7, 3), ('Invert', 0.1, None)), + (('Solarize', 0.6, 1), ('Contrast', 0.0, 0)), + (('Color', 0.2, 6), ('Posterize', 0.4, 7))][0.6222999691963196, 0.6868000030517578, 0.8374999761581421, 0.8370999693870544, 0.6934999823570251]original small policys accuracies: [0.8431999683380127, 0.8393999934196472, 0.8377999663352966] \ No newline at end of file diff --git a/benchmark/scripts/04_22_fm_gru.py b/benchmark/scripts/04_22_fm_gru.py index b3a951c0afd3eeb0cd8911a30143f08a61c6e5e4..799e439ef22f51cef57b42e807905648800a4710 100644 --- a/benchmark/scripts/04_22_fm_gru.py +++ b/benchmark/scripts/04_22_fm_gru.py @@ -40,6 +40,7 @@ run_benchmark( child_network_architecture=child_network_architecture, agent_arch=aal.gru_learner, config=config, + total_iter=144 ) rerun_best_policy( @@ -48,5 +49,6 @@ rerun_best_policy( train_dataset=train_dataset, test_dataset=test_dataset, child_network_architecture=child_network_architecture, + config=config, repeat_num=5 ) \ No newline at end of file diff --git a/benchmark/scripts/util_04_22.py b/benchmark/scripts/util_04_22.py index 86b033ef65efa96782e809136f2793ebaad6b044..344feef26b011f3fa350db420a240b038dbeb317 100644 --- a/benchmark/scripts/util_04_22.py +++ b/benchmark/scripts/util_04_22.py @@ -6,7 +6,7 @@ import torch import MetaAugment.child_networks as cn import MetaAugment.autoaugment_learners as aal -from pprint import pprint +import pprint """ testing gru_learner and randomsearch_learner on @@ -75,16 +75,21 @@ def get_mega_policy(history, n): assert len(history) >= n # agent.history is a list of (policy(list), val_accuracy(float)) tuples - sorted_history = sorted(history, key=lambda x:x[1]) # sort wrt acc + sorted_history = sorted(history, key=lambda x:x[1], reverse=True) # sort wrt acc best_history = sorted_history[:n] megapolicy = [] + # we also want to keep track of how good the best policies were + # maybe if we add them all up, they'll become worse! Hopefully better tho + orig_accs = [] + for policy,acc in best_history: for subpolicy in policy: megapolicy.append(subpolicy) + orig_accs.append(acc) - return megapolicy + return megapolicy, orig_accs def rerun_best_policy( @@ -93,25 +98,30 @@ def rerun_best_policy( train_dataset, test_dataset, child_network_architecture, + config, repeat_num ): with open(agent_pickle, 'rb') as f: - agent = torch.load(f, map_location=device) + agent = torch.load(f) - megapol = get_mega_policy(agent.history) + megapol, orig_accs = get_mega_policy(agent.history,3) print('mega policy to be tested:') - pprint(megapol) - + pprint.pprint(megapol) + print(orig_accs) + accs=[] for _ in range(repeat_num): print(f'{_}/{repeat_num}') + temp_agent = aal.aa_learner(**config) accs.append( - agent.test_autoaugment_policy(megapol, + temp_agent.test_autoaugment_policy(megapol, child_network_architecture, train_dataset, test_dataset, logging=False) ) with open(accs_txt, 'w') as f: + f.write(pprint.pformat(megapol)) f.write(str(accs)) + f.write(f'original small policys accuracies: {orig_accs}')