# 04_22_ci_gru.py
# Author: Sun Jin Kim
# Benchmark: GRU aa_learner on CIFAR-10 with a LeNet child network.
from pathlib import Path

import torch
import torchvision
import torchvision.datasets as datasets

import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
from .util_04_22 import *
# aa_learner config: hyper-parameters passed to run_benchmark for this run.
# NOTE(review): per-key notes below are inferred from the key names only —
# confirm against MetaAugment.autoaugment_learners before relying on them.
config = dict(
    sp_num=3,             # presumably the number of sub-policies per policy
    learning_rate=1e-1,
    # Full-dataset run. The alternate setting kept by the author was
    # toy_flag=True together with toy_size=0.001 (presumably a small data
    # subset for quick debugging runs).
    toy_flag=False,
    batch_size=32,
    max_epochs=100,
    early_stop_num=10,    # presumably early-stopping patience — TODO confirm
)
# --- Data and child network: CIFAR-10 classified by LeNet ------------------
# Both splits deliberately share one root folder: torchvision's CIFAR10 ships
# the train and test batches in a single archive, so reusing the root should
# avoid a second download. (The folder is named ".../train" even though the
# test split is read from it too.)
cifar10_root = './datasets/cifar10/train'

# Training split: no transform is attached here.
# NOTE(review): presumably the aa_learner applies its augmentation policy and
# tensor conversion itself — confirm in MetaAugment.autoaugment_learners.
train_dataset = datasets.CIFAR10(
    root=cifar10_root,
    train=True,
    download=True,
    transform=None,
)

# Test split: only converted to tensors, no augmentation.
test_dataset = datasets.CIFAR10(
    root=cifar10_root,
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# LeNet sized for CIFAR-10 inputs: 32x32 RGB images, 10 classes.
child_network_architecture = cn.LeNet(
    img_height=32,
    img_width=32,
    num_labels=10,
    img_channels=3,
)

# Path *prefix* (not a directory) for this run's output files; a file
# extension is appended wherever it is used.
save_dir = './benchmark/pickles/04_22_cf_ln_gru'
# gru: search for an augmentation policy with the GRU-based learner
# (aal.gru_learner). run_benchmark receives "<save_dir>.pkl" as save_file;
# rerun_best_policy then reads that same file back as agent_pickle.
pickle_path = save_dir + '.pkl'

# Create the output folder before the (long) search starts, so writing the
# results cannot fail on a missing directory at the very end. Idempotent, and
# harmless if run_benchmark already creates it.
Path(save_dir).parent.mkdir(parents=True, exist_ok=True)

run_benchmark(
    save_file=pickle_path,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    child_network_architecture=child_network_architecture,
    agent_arch=aal.gru_learner,
    config=config,
)

# Re-evaluate the best policy found above, repeated repeat_num times; judging
# by the parameter name, the accuracies go to "<save_dir>.txt".
rerun_best_policy(
    agent_pickle=pickle_path,
    accs_txt=save_dir + '.txt',
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    child_network_architecture=child_network_architecture,
    repeat_num=5,
)