Skip to content
Snippets Groups Projects
Commit 181681da authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

update benchmark/scripts

parent 14aa36ee
No related branches found
No related tags found
No related merge requests found
......@@ -35,12 +35,23 @@ child_network_architecture = cn.LeNet(
)
# gru
save_dir='./benchmark/pickles/04_22_cf_ln_gru'
# rs
run_benchmark(
save_file='./benchmark/pickles/04_22_cf_ln_gru',
save_file=save_dir+'.pkl',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
agent_arch=aal.gru_learner,
config=config,
)
rerun_best_policy(
agent_pickle=save_dir+'.pkl',
accs_txt=save_dir+'.txt',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
repeat_num=5
)
\ No newline at end of file
......@@ -34,13 +34,23 @@ child_network_architecture = cn.LeNet(
img_channels=3
)
save_dir='./benchmark/pickles/04_22_cf_ln_rs'
# rs
run_benchmark(
save_file='./benchmark/pickles/04_22_cf_ln_rs',
save_file=save_dir+'.pkl',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
agent_arch=aal.randomsearch_learner,
config=config,
)
rerun_best_policy(
agent_pickle=save_dir+'.pkl',
accs_txt=save_dir+'.txt',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
repeat_num=5
)
\ No newline at end of file
......@@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
child_network_architecture = cn.SimpleNet
# gru
save_dir='./benchmark/pickles/04_22_fm_sn_gru'
# rs
run_benchmark(
save_file='./benchmark/pickles/04_22_fm_sn_gru.pkl',
save_file=save_dir+'.pkl',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
agent_arch=aal.gru_learner,
config=config,
)
rerun_best_policy(
agent_pickle=save_dir+'.pkl',
accs_txt=save_dir+'.txt',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
repeat_num=5
)
\ No newline at end of file
......@@ -30,12 +30,23 @@ test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
child_network_architecture = cn.SimpleNet
save_dir='./benchmark/pickles/04_22_fm_sn_rs'
# rs
run_benchmark(
save_file='./benchmark/pickles/04_22_fm_sn_rs.pkl',
save_file=save_dir+'.pkl',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
agent_arch=aal.randomsearch_learner,
config=config,
)
rerun_best_policy(
agent_pickle=save_dir+'.pkl',
accs_txt=save_dir+'.txt',
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
repeat_num=5
)
\ No newline at end of file
from matplotlib.pyplot import get
import torchvision.datasets as datasets
import torchvision
import torch
......@@ -5,7 +6,7 @@ import torch
import MetaAugment.child_networks as cn
import MetaAugment.autoaugment_learners as aal
from pprint import pprint
"""
testing gru_learner and randomsearch_learner on
......@@ -56,4 +57,61 @@ def run_benchmark(
with open(save_file, 'wb+') as f:
torch.save(agent, f)
print('run_benchmark closing')
\ No newline at end of file
print('run_benchmark closing')
def get_mega_policy(history, n):
"""
we get the best n policies from an agent's history,
concatenate them to form our best mega policy
Args:
history (list[tuple])
n (int)
Returns:
list[float]: validation accuracies
"""
assert len(history) >= n
# agent.history is a list of (policy(list), val_accuracy(float)) tuples
sorted_history = sorted(history, key=lambda x:x[1]) # sort wrt acc
best_history = sorted_history[:n]
megapolicy = []
for policy,acc in best_history:
for subpolicy in policy:
megapolicy.append(subpolicy)
return megapolicy
def rerun_best_policy(
agent_pickle,
accs_txt,
train_dataset,
test_dataset,
child_network_architecture,
repeat_num
):
with open(agent_pickle, 'rb') as f:
agent = torch.load(f, map_location=device)
megapol = get_mega_policy(agent.history)
print('mega policy to be tested:')
pprint(megapol)
accs=[]
for _ in range(repeat_num):
print(f'{_}/{repeat_num}')
accs.append(
agent.test_autoaugment_policy(megapol,
child_network_architecture,
train_dataset,
test_dataset,
logging=False)
)
with open(accs_txt, 'w') as f:
f.write(str(accs))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment