Skip to content
Snippets Groups Projects
Commit 38498217 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Add aa_learner unit tests

parent 6f53474a
No related branches found
No related tags found
No related merge requests found
# MetaRL
Documentation:
https://metaaugment.readthedocs.io/en/latest/
\ No newline at end of file
See
"Sphinx Tutorial" on https://www.sphinx-doc.org/en/master/contents.html
and
https://docs.readthedocs.io/en/stable/tutorial/
if you want to contribute to the docs.
\ No newline at end of file
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torch
import torchvision
import torchvision.datasets as datasets
import random
from tqdm import trange
def test_translate_operation_tensor():
"""
See if aa_learner class's translate_operation_tensor works
by feeding many (valid) inputs in it.
We make a lot of (fun_num+p_bins_m_bins,) size tensors, softmax
them, and feed them through the translate_operation_tensor method
to see if it doesn't break
"""
# discrete_p_m=True
for i in trange(2000):
softmax = torch.nn.Softmax(dim=0)
fun_num = random.randint(1, 14)
p_bins = random.randint(1, 15)
m_bins = random.randint(1, 15)
agent = aal.aa_learner(
sp_num=5,
fun_num=fun_num,
p_bins=p_bins,
m_bins=m_bins,
discrete_p_m=True
)
alpha = i/1000
vector = torch.rand(fun_num+p_bins+m_bins)
fun_t, prob_t, mag_t = vector.split([fun_num, p_bins, m_bins])
fun_t = softmax(fun_t * alpha)
prob_t = softmax(prob_t * alpha)
mag_t = softmax(mag_t * alpha)
softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
agent.translate_operation_tensor(softmaxed_vector)
# discrete_p_m=False
softmax = torch.nn.Softmax(dim=0)
sigmoid = torch.nn.Sigmoid()
for i in trange(2000):
fun_num = random.randint(1, 14)
p_bins = random.randint(1, 15)
m_bins = random.randint(1, 15)
agent = aal.aa_learner(
sp_num=5,
fun_num=fun_num,
p_bins=p_bins,
m_bins=m_bins,
discrete_p_m=False
)
alpha = i/1000
vector = torch.rand(fun_num+2)
fun_t, prob_t, mag_t = vector.split([fun_num, 1, 1])
fun_t = softmax(fun_t * alpha)
prob_t = sigmoid(prob_t)
mag_t = sigmoid(mag_t) * (m_bins-1)
softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
agent.translate_operation_tensor(softmaxed_vector)
def test_test_autoaugment_policy():
agent = aal.aa_learner(
sp_num=5,
fun_num=14,
p_bins=11,
m_bins=10,
discrete_p_m=True,
toy_flag=True,
toy_size=0.004,
max_epochs=20,
early_stop_num=10
)
policy = [
(("Invert", 0.8, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("Invert", 0.8, None)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("Invert", 0.7, None)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
(("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
(("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
(("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
(("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
]
child_network_architecture = cn.SimpleNet
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
acc = agent.test_autoaugment_policy(
policy,
child_network_architecture,
train_dataset,
test_dataset,
logging=False
)
assert isinstance(acc, float)
\ No newline at end of file
Type
`pytest`
in main directory to run all tests in this directory.
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment