# util_04_22.py
import torch
import MetaAugment.autoaugment_learners as aal
import pprint
"""
testing gru_learner and randomsearch_learner on
fashionmnist with simple net
and
cifar10 with lenet
"""
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
def run_benchmark(
save_file,
train_dataset,
test_dataset,
child_network_architecture,
agent_arch,
config,
total_iter=150,
):
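    """
    Resume or start a benchmark run: load a saved agent from save_file if one
    exists (otherwise initialise a fresh agent from agent_arch and config),
    then keep evaluating new policies until total_iter policies have been
    tested, saving the agent after every iteration.
    """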
try:
# try to load agent
with open(save_file, 'rb') as f:
agent = torch.load(f, map_location=device)
except FileNotFoundError:
# if agent hasn't been saved yet, initialize the agent
agent = agent_arch(**config)
    # keep running until total_iter different policies have been tested,
    # i.e. until agent.history reaches length total_iter
    while len(agent.history) < total_iter:
print(f'{len(agent.history)} / {total_iter}')
        # run one iteration (test one new policy and update the controller)
agent.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
iterations=1
)
# save agent every iteration
with open(save_file, 'wb+') as f:
torch.save(agent, f)
print('run_benchmark closing')
def get_mega_policy(history, n):
"""
we get the best n policies from an agent's history,
concatenate them to form our best mega policy
Args:
history (list[tuple])
n (int)
Returns:
list[float]: validation accuracies
"""
assert len(history) >= n
# agent.history is a list of (policy(list), val_accuracy(float)) tuples
    sorted_history = sorted(history, key=lambda x: x[1], reverse=True)  # sort by accuracy, descending
best_history = sorted_history[:n]
megapolicy = []
    # also keep track of how good the best policies were individually;
    # the concatenation might perform worse than its parts, but ideally better
    orig_accs = []
    for policy, acc in best_history:
for subpolicy in policy:
megapolicy.append(subpolicy)
orig_accs.append(acc)
return megapolicy, orig_accs
def rerun_best_policy(
agent_pickle,
accs_txt,
train_dataset,
test_dataset,
child_network_architecture,
config,
repeat_num
):
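    """
    Load a saved agent, build a mega policy from its 3 best policies, and
    retrain the child network on that policy repeat_num times, writing the
    resulting accuracies to accs_txt.
    """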
with open(agent_pickle, 'rb') as f:
        agent = torch.load(f, map_location=device)
    megapol, orig_accs = get_mega_policy(agent.history, 3)
print('mega policy to be tested:')
pprint.pprint(megapol)
print(orig_accs)
    accs = []
    for trial in range(repeat_num):
        print(f'{trial}/{repeat_num}')
temp_agent = aal.aa_learner(**config)
accs.append(
temp_agent._test_autoaugment_policy(megapol,
child_network_architecture,
train_dataset,
test_dataset,
logging=False)
)
    with open(accs_txt, 'w') as f:
        f.write(pprint.pformat(megapol))
        f.write('\n' + str(accs))
        f.write(f"\noriginal small policies' accuracies: {orig_accs}")
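# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical example of driving the helpers above
# on FashionMNIST. SimpleChildNet and the (empty) config dict are assumptions,
# not the repo's actual child network or learner settings; substitute the
# architectures and gru_learner/randomsearch_learner kwargs from your own
# experiments.
if __name__ == '__main__':
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as T

    class SimpleChildNet(nn.Module):
        # placeholder child network (assumption): one linear layer on 28x28
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

        def forward(self, x):
            return self.net(x)

    train_dataset = torchvision.datasets.FashionMNIST(
        root='./data', train=True, download=True, transform=T.ToTensor())
    test_dataset = torchvision.datasets.FashionMNIST(
        root='./data', train=False, download=True, transform=T.ToTensor())

    config = {}  # fill with the learner's constructor kwargs (assumed empty here)

    run_benchmark(
        save_file='gru_fashionmnist.pkl',
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        child_network_architecture=SimpleChildNet,
        agent_arch=aal.gru_learner,
        config=config,
    )

    rerun_best_policy(
        agent_pickle='gru_fashionmnist.pkl',
        accs_txt='gru_fashionmnist_accs.txt',
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        child_network_architecture=SimpleChildNet,
        config=config,
        repeat_num=5,
    )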