Commit 7d4fe4cc authored by Sun Jin Kim

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents 922f9cf6 bd183477
[(('ShearY', 0.2, 5), ('Rotate', 0.6, 6)),
(('TranslateX', 0.8, 3), ('Posterize', 0.1, 3)),
(('TranslateY', 0.0, 8), ('Equalize', 0.7, None)),
(('Equalize', 0.3, None), ('Contrast', 0.2, 0)),
(('ShearX', 0.4, 5), ('Contrast', 0.2, 8)),
(('TranslateX', 0.9, 3), ('Solarize', 0.4, 5)),
(('Color', 0.2, 4), ('Solarize', 0.6, 8)),
(('ShearX', 0.1, 8), ('Equalize', 0.4, None)),
(('Posterize', 0.7, 5), ('Solarize', 1.0, 4))]
[0.6056999564170837, 0.6329999566078186, 0.6171000003814697, 0.62909996509552]
original small policys accuracies: [0.6236000061035156, 0.6187999844551086, 0.617900013923645]
\ No newline at end of file
[(('Color', 0.9, 3), ('Contrast', 0.8, 3)),
(('Sharpness', 0.9, 0), ('Solarize', 0.3, 7)),
(('Color', 0.0, 6), ('Solarize', 0.4, 3)),
(('Brightness', 0.1, 3), ('Brightness', 0.5, 9)),
(('Solarize', 0.9, 6), ('Rotate', 0.6, 1)),
(('Contrast', 0.7, 3), ('Posterize', 0.9, 4)),
(('Solarize', 0.6, 2), ('Contrast', 0.5, 6)),
(('TranslateX', 0.0, 4), ('AutoContrast', 0.3, None)),
(('Equalize', 0.0, None), ('Brightness', 0.8, 1))]
[0.7490999698638916, 0.8359999656677246, 0.8394999504089355]
original small policys accuracies: [0.8380999565124512, 0.8376999497413635, 0.8376999497413635]
\ No newline at end of file
[(('ShearX', 1.0, 0), ('Color', 0.3, 2)),
(('AutoContrast', 0.0, None), ('Brightness', 0.7, 2)),
(('Invert', 0.1, None), ('Contrast', 0.1, 6)),
(('Solarize', 0.4, 2), ('Contrast', 0.9, 2)),
(('Equalize', 0.0, None), ('Contrast', 0.0, 2)),
(('Rotate', 0.4, 0), ('Posterize', 0.5, 9)),
(('Posterize', 0.7, 3), ('Invert', 0.1, None)),
(('Solarize', 0.6, 1), ('Contrast', 0.0, 0)),
(('Color', 0.2, 6), ('Posterize', 0.4, 7))]
[0.6222999691963196, 0.6868000030517578, 0.8374999761581421, 0.8370999693870544, 0.6934999823570251]
original small policys accuracies: [0.8431999683380127, 0.8393999934196472, 0.8377999663352966]
\ No newline at end of file
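The three text files above each record an evaluated megapolicy followed by two accuracy lists: the test accuracies from re-running the megapolicy (one entry per repeat) and the accuracies its constituent policies achieved on their own during the search. Every subpolicy is a pair of (operation, probability, magnitude) tuples in the AutoAugment convention, with magnitude None for operations such as Equalize, AutoContrast and Invert that take none. The following sketch shows how one such subpolicy could be applied to a PIL image; the magnitude scalings and the small set of mapped operations are assumptions for illustration, not the repository's augmentation code.

import random
from PIL import Image, ImageOps

def apply_subpolicy(img, subpolicy):
    """Apply one subpolicy, i.e. a pair of (op, probability, magnitude) triples, to a PIL image."""
    for name, prob, magnitude in subpolicy:
        if random.random() > prob:          # each operation fires with its own probability
            continue
        if name == 'Equalize':
            img = ImageOps.equalize(img)    # magnitude is None for this op
        elif name == 'AutoContrast':
            img = ImageOps.autocontrast(img)
        elif name == 'Rotate':
            img = img.rotate(3 * magnitude)                    # assumed scaling of the 0-9 magnitude
        elif name == 'Posterize':
            img = ImageOps.posterize(img, 8 - magnitude // 2)  # assumed mapping to bit depth
        elif name == 'Solarize':
            img = ImageOps.solarize(img, 256 - 26 * magnitude) # assumed mapping to threshold
        # remaining ops (ShearX/Y, TranslateX/Y, Color, Contrast, ...) omitted for brevity
    return img

# e.g. the first subpolicy of the first file above, applied to a blank 32x32 image
out = apply_subpolicy(Image.new('RGB', (32, 32)), (('ShearY', 0.2, 5), ('Rotate', 0.6, 6)))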
@@ -40,6 +40,7 @@ run_benchmark(
     child_network_architecture=child_network_architecture,
     agent_arch=aal.gru_learner,
     config=config,
+    total_iter=144
     )

 rerun_best_policy(
@@ -48,5 +49,6 @@ rerun_best_policy(
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
+    config=config,
     repeat_num=5
     )
\ No newline at end of file
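This hunk adds two arguments to the benchmark script: total_iter=144, which presumably caps the number of candidate policies the GRU learner evaluates, and config, which is forwarded to rerun_best_policy so a fresh learner can be rebuilt from it (see the second file below). The toy loop below is a rough, self-contained sketch of the role total_iter plays in an AutoAugment-style search; toy_search and everything inside it are illustrative assumptions, not MetaAugment's actual learn() code.

import random

def toy_search(total_iter, n_subpolicies=3):
    history = []                                   # mirrors agent.history: (policy, accuracy) pairs
    for _ in range(total_iter):
        policy = [(('Rotate', random.random(), random.randint(0, 9)),
                   ('Equalize', random.random(), None))
                  for _ in range(n_subpolicies)]   # sample a candidate policy
        acc = random.random()                      # stand-in for training a child network on it
        history.append((policy, acc))
    return history

history = toy_search(total_iter=144)
print(len(history))                                # 144 candidate policies evaluated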
@@ -6,7 +6,7 @@ import torch
 import MetaAugment.child_networks as cn
 import MetaAugment.autoaugment_learners as aal
-from pprint import pprint
+import pprint

 """
 testing gru_learner and randomsearch_learner on
@@ -75,16 +75,21 @@ def get_mega_policy(history, n):
     assert len(history) >= n

     # agent.history is a list of (policy(list), val_accuracy(float)) tuples
-    sorted_history = sorted(history, key=lambda x:x[1]) # sort wrt acc
+    sorted_history = sorted(history, key=lambda x:x[1], reverse=True) # sort wrt acc
     best_history = sorted_history[:n]

     megapolicy = []
+    # we also want to keep track of how good the best policies were
+    # maybe if we add them all up, they'll become worse! Hopefully better tho
+    orig_accs = []
     for policy,acc in best_history:
         for subpolicy in policy:
             megapolicy.append(subpolicy)
+        orig_accs.append(acc)

-    return megapolicy
+    return megapolicy, orig_accs

 def rerun_best_policy(
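The hunk above fixes and extends get_mega_policy: adding reverse=True sorts the history by validation accuracy in descending order (without it the slice [:n] kept the n worst policies), the subpolicies of the top n policies are concatenated into one megapolicy, and each chosen policy's original accuracy is now returned alongside it. A standalone restatement of that behaviour on toy data, for illustration only:

def get_mega_policy(history, n):
    # history: list of (policy, val_accuracy) tuples, as stored in agent.history
    assert len(history) >= n
    best_history = sorted(history, key=lambda x: x[1], reverse=True)[:n]  # best policies first
    megapolicy, orig_accs = [], []
    for policy, acc in best_history:
        megapolicy.extend(policy)   # concatenate this policy's subpolicies
        orig_accs.append(acc)       # one original accuracy per chosen policy
    return megapolicy, orig_accs

toy_history = [
    ([(('Rotate', 0.6, 6), ('Equalize', 0.7, None))], 0.62),
    ([(('Solarize', 0.4, 5), ('Color', 0.2, 4))], 0.84),
    ([(('ShearX', 0.1, 8), ('Contrast', 0.2, 8))], 0.71),
]
megapol, accs = get_mega_policy(toy_history, n=2)   # keeps the 0.84 and 0.71 policies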
@@ -93,25 +98,30 @@ def rerun_best_policy(
     train_dataset,
     test_dataset,
     child_network_architecture,
+    config,
     repeat_num
     ):

     with open(agent_pickle, 'rb') as f:
-        agent = torch.load(f, map_location=device)
+        agent = torch.load(f)

-    megapol = get_mega_policy(agent.history)
+    megapol, orig_accs = get_mega_policy(agent.history,3)
     print('mega policy to be tested:')
-    pprint(megapol)
+    pprint.pprint(megapol)
+    print(orig_accs)

     accs=[]
     for _ in range(repeat_num):
         print(f'{_}/{repeat_num}')
+        temp_agent = aal.aa_learner(**config)
         accs.append(
-                agent.test_autoaugment_policy(megapol,
+                temp_agent.test_autoaugment_policy(megapol,
                                               child_network_architecture,
                                               train_dataset,
                                               test_dataset,
                                               logging=False)
                     )

     with open(accs_txt, 'w') as f:
+        f.write(pprint.pformat(megapol))
         f.write(str(accs))
+        f.write(f'original small policys accuracies: {orig_accs}')
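Taken together, these changes make rerun_best_policy rebuild a fresh learner from config via aal.aa_learner(**config) instead of reusing the unpickled agent, evaluate the megapolicy repeat_num times, and write three things to accs_txt back to back: the pretty-printed megapolicy, the list of rerun accuracies, and the original accuracies of the constituent policies. Those three unseparated writes are what produce the run-together accuracy files at the top of this commit. Below is a hedged example of a call site; the file paths are placeholders, and the agent_pickle and accs_txt parameter names are inferred from the function body rather than from the start of the signature, which the hunk does not show.

# Hypothetical invocation; train_dataset, test_dataset, child_network_architecture and
# config are assumed to be defined earlier, as in the benchmark script changed above.
rerun_best_policy(
    agent_pickle='outputs/gru_agent.pkl',        # pickled learner exposing .history
    accs_txt='outputs/gru_megapolicy_accs.txt',  # destination of the three writes above
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    child_network_architecture=child_network_architecture,
    config=config,                               # same dict used to build the original learner
    repeat_num=5,
)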