-
Sun Jin Kim authoredSun Jin Kim authored
test_GruLearner.py 1.73 KiB
import autoaug.autoaugment_learners as aal
import autoaug.child_networks as cn
import torch
import torchvision
import torchvision.datasets as datasets
import random
def test__generate_new_policy():
"""
make sure GruLearner._generate_new_policy() is robust
with respect to different values of sp_num, fun_num,
p_bins, and m_bins
"""
for _ in range(40):
sp_num = random.randint(1,20)
p_bins = random.randint(2, 15)
m_bins = random.randint(2, 15)
agent = aal.GruLearner(
sp_num=sp_num,
p_bins=p_bins,
m_bins=m_bins,
cont_mb_size=2
)
for _ in range(4):
new_policy = agent._generate_new_policy()
assert isinstance(new_policy[0], list), new_policy
def test_learn():
"""
tests the GruLearner.learn() method
"""
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
child_network_architecture = cn.lenet
# child_network_architecture = cn.lenet()
agent = aal.GruLearner(
sp_num=7,
toy_size=0.001,
batch_size=32,
learning_rate=0.05,
max_epochs=100,
early_stop_num=10,
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_network_architecture,
iterations=2)