# test_aa_learner.py - tests for MetaAugment.autoaugment_learners (authored by Sun Jin Kim)
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torch
import torchvision
import torchvision.datasets as datasets
import random
def test__translate_operation_tensor():
    """
    Check that aa_learner._translate_operation_tensor survives many valid inputs.

    Two modes are exercised, each with freshly built learners whose p_bins and
    m_bins are drawn at random:

    * discrete_p_m=True: a (fun_num + p_bins + m_bins,) tensor whose three
      segments (function, probability, magnitude) are each softmaxed
      independently.
    * discrete_p_m=False: a (fun_num + 2,) tensor. The function segment is
      softmaxed, the probability entry is squashed into (0, 1) by a sigmoid,
      and the magnitude entry is a sigmoid scaled into (0, m_bins - 1).

    The test passes as long as no call raises; the translated output itself is
    not inspected.

    All randomness comes from locally seeded generators. A failure is
    therefore reproducible, and the global ``random``/``torch`` RNG state that
    other tests may rely on is left untouched.
    """
    n_trials = 2000
    fun_num = 14
    softmax = torch.nn.Softmax(dim=0)
    sigmoid = torch.nn.Sigmoid()
    rng = random.Random(0)
    gen = torch.Generator().manual_seed(0)

    # discrete_p_m=True
    for i in range(n_trials):
        p_bins = rng.randint(2, 15)
        m_bins = rng.randint(2, 15)
        agent = aal.aa_learner(
            sp_num=5,
            p_bins=p_bins,
            m_bins=m_bins,
            discrete_p_m=True,
        )
        # alpha sweeps from 0 to just under 2. At alpha == 0 every softmax is
        # uniform; larger alpha sharpens the distributions, so a range of
        # shapes is fed to the translator.
        alpha = i / 1000
        vector = torch.rand(fun_num + p_bins + m_bins, generator=gen)
        fun_t, prob_t, mag_t = vector.split([fun_num, p_bins, m_bins])
        softmaxed_vector = torch.cat((
            softmax(fun_t * alpha),
            softmax(prob_t * alpha),
            softmax(mag_t * alpha),
        ))
        agent._translate_operation_tensor(softmaxed_vector)

    # discrete_p_m=False
    for i in range(n_trials):
        p_bins = rng.randint(1, 15)
        m_bins = rng.randint(1, 15)
        agent = aal.aa_learner(
            sp_num=5,
            p_bins=p_bins,
            m_bins=m_bins,
            discrete_p_m=False,
        )
        alpha = i / 1000
        vector = torch.rand(fun_num + 2, generator=gen)
        fun_t, prob_t, mag_t = vector.split([fun_num, 1, 1])
        softmaxed_vector = torch.cat((
            softmax(fun_t * alpha),
            sigmoid(prob_t),
            # Scale the magnitude into (0, m_bins - 1); with m_bins == 1 this
            # is exactly 0.
            sigmoid(mag_t) * (m_bins - 1),
        ))
        agent._translate_operation_tensor(softmaxed_vector)
def test__test_autoaugment_policy():
    """
    Smoke-test aa_learner._test_autoaugment_policy on FashionMNIST.

    Evaluates a fixed, hand-written policy of ten sub-policies with a
    SimpleNet child network and checks only that the returned accuracy is a
    Python float. toy_size=0.004 presumably restricts training to a small
    fraction of the data; confirm against aa_learner.

    Side effect: downloads FashionMNIST under ./datasets/fashionmnist/ (into
    separate train/ and test/ roots) if it is not already present.
    """
    learner_config = dict(
        sp_num=5,
        p_bins=11,
        m_bins=10,
        discrete_p_m=True,
        toy_size=0.004,
        max_epochs=20,
        early_stop_num=10,
    )
    agent = aal.aa_learner(**learner_config)

    # Each entry is a pair of (operation name, probability, magnitude)
    # triples. The magnitude is None for Invert / AutoContrast / Equalize,
    # presumably because those operations take no magnitude.
    fixed_policy = [
        (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
        (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
        (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
        (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
        (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
        (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
        (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
        (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
        (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
        (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)),
    ]
    child_arch = cn.SimpleNet

    # The train split is loaded without a transform and the test split with
    # ToTensor(), exactly as the learner is handed them here.
    fmnist_train = datasets.FashionMNIST(
        root='./datasets/fashionmnist/train',
        train=True,
        download=True,
        transform=None,
    )
    fmnist_test = datasets.FashionMNIST(
        root='./datasets/fashionmnist/test',
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    accuracy = agent._test_autoaugment_policy(
        fixed_policy,
        child_arch,
        fmnist_train,
        fmnist_test,
        logging=False,
    )
    assert isinstance(accuracy, float)
def _assert_policies_exclude(generate_policy, excluded, n_trials=200):
    """
    Draw ``n_trials`` policies and assert that none uses an excluded operation.

    Args:
        generate_policy: zero-argument callable returning one policy, i.e. an
            iterable of (op1, op2) pairs whose element ``[0]`` is the
            image-operation name.
        excluded: container of operation names that must never appear.
        n_trials: how many policies to draw and check.
    """
    for _trial in range(n_trials):
        policy = generate_policy()
        for op1, op2 in policy:
            for op in (op1, op2):
                # Report the whole policy on failure instead of printing
                # every policy on success.
                assert op[0] not in excluded, (
                    f"excluded operation {op[0]!r} found in policy {policy!r}"
                )


def test_exclude_method():
    """
    Check that the ``exclude_method`` argument of the aa_learners is honoured.

    A gru_learner and a randomsearch_learner are each built with the same
    exclusion list. Every policy they generate over 200 draws must avoid all
    excluded operations.
    """
    exclude_method = [
        'ShearX',
        'Color',
        'Brightness',
        'Contrast',
    ]
    # Set for O(1) membership checks; the original list is what the
    # learners receive.
    excluded = set(exclude_method)

    gru_agent = aal.gru_learner(
        exclude_method=exclude_method
    )

    def gru_policy():
        # gru_learner._generate_new_policy returns a 2-tuple whose first
        # element is the policy. Unpacking (rather than indexing) keeps the
        # original check that exactly two values come back.
        policy, _extra = gru_agent._generate_new_policy()
        return policy

    _assert_policies_exclude(gru_policy, excluded)

    random_agent = aal.randomsearch_learner(
        exclude_method=exclude_method
    )
    # randomsearch_learner._generate_new_policy returns the policy directly.
    _assert_policies_exclude(random_agent._generate_new_policy, excluded)