Newer
Older
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torchvision
import torchvision.datasets as datasets
from pprint import pprint
def test_ucb_learner():
    """Smoke-test ``aal.ucb_learner`` end to end on FashionMNIST.

    Builds a learner with a small initial pool of augmentation policies and
    checks that every policy has a matching entry in ``avg_accs``. It then
    runs ``learn`` on that pool, grows the pool with ``make_more_policies``,
    runs ``learn`` again, and re-checks that the two lists stay aligned.

    NOTE(review): ``download=True`` fetches FashionMNIST into ./datasets on
    first run, and ``learn`` trains child networks. Expect this test to be
    slow and to need network access.
    """
    child_network_architecture = cn.SimpleNet

    # Only the test split is converted to tensors here. The training split is
    # passed raw (transform=None); presumably ucb_learner applies its policies
    # (and tensor conversion) to training images itself — TODO confirm.
    train_dataset = datasets.FashionMNIST(
        root='./datasets/fashionmnist/train',
        train=True,
        download=True,
        transform=None,
    )
    test_dataset = datasets.FashionMNIST(
        root='./datasets/fashionmnist/test',
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    learner = aal.ucb_learner(
        # parameters that define the search space
        sp_num=5,
        p_bins=11,
        m_bins=10,
        discrete_p_m=True,
        # hyperparameters for when training the child_network
        batch_size=8,
        learning_rate=1e-1,
        max_epochs=float('inf'),
        early_stop_num=30,
        # ucb_learner specific hyperparameter: size of the initial policy
        # pool (the learn() call below expects 3 generated policies).
        # NOTE(review): keyword name restored from the MetaAugment API —
        # confirm it is still `num_policies` in ucb_learner.__init__.
        num_policies=3,
    )
    pprint(learner.policies)
    # Every policy must have a matching accuracy slot.
    assert len(learner.policies) == len(learner.avg_accs), \
        (len(learner.policies), len(learner.avg_accs))

    # learn on the 3 policies we generated
    learner.learn(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        child_network_architecture=child_network_architecture,
        iterations=5,
    )

    # let's say we want to explore more policies:
    # we generate more new policies
    learner.make_more_policies(n=4)
    # Growing the pool must keep policies and their accuracies aligned.
    assert len(learner.policies) == len(learner.avg_accs), \
        (len(learner.policies), len(learner.avg_accs))

    # and let's explore how good those are as well
    learner.learn(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        child_network_architecture=child_network_architecture,
        iterations=7,
    )