Skip to content
Snippets Groups Projects
Commit bcfdeeec authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

add test_exclude_method

parent 24808e5a
No related branches found
No related tags found
No related merge requests found
Pipeline #272071 failed
......@@ -202,8 +202,11 @@ celerybeat.pid
# SageMath parsed files
*.sage.py
# we don't want dataset directories
**/test
**/train
# but there is a unit test folder in main that we DO want to track
!test
# user uplaod
/react_backend/child_networks
......
......@@ -329,7 +329,9 @@ class aa_learner:
accuracy (float): best accuracy reached in any
"""
# we create an instance of the child network that we're going
# to train. The method of creation depends on the type of
# input we got for child_network_architecture
if isinstance(child_network_architecture, types.FunctionType):
child_network = child_network_architecture()
elif isinstance(child_network_architecture, type):
......
......@@ -116,4 +116,41 @@ def test_test_autoaugment_policy():
)
assert isinstance(acc, float)
def test_exclude_method():
"""
we want to see if the exclude_methods
parameter is working properly in aa_learners
"""
exclude_method = [
'ShearX',
'Color',
'Brightness',
'Contrast'
]
agent = aal.gru_learner(
exclude_method=exclude_method
)
for _ in range(200):
new_pol, _ = agent.generate_new_policy()
print(new_pol)
for (op1, op2) in new_pol:
image_function_1 = op1[0]
image_function_2 = op2[0]
assert image_function_1 not in exclude_method
assert image_function_2 not in exclude_method
agent = aal.randomsearch_learner(
exclude_method=exclude_method
)
for _ in range(200):
new_pol, _ = agent.generate_new_policy()
print(new_pol)
for (op1, op2) in new_pol:
image_function_1 = op1[0]
image_function_2 = op2[0]
assert image_function_1 not in exclude_method
assert image_function_2 not in exclude_method
\ No newline at end of file
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torchvision
import torchvision.datasets as datasets
from pprint import pprint
def test_evo_learner():
child_network_architecture = cn.SimpleNet
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
learner = aal.evo_learner(
# parameters that define the search space
sp_num=5,
p_bins=11,
m_bins=10,
discrete_p_m=True,
exclude_method=['ShearX'],
# hyperparameters for when training the child_network
batch_size=8,
toy_size=0.0001,
learning_rate=1e-1,
max_epochs=float('inf'),
early_stop_num=30,
# evolutionary learner specific settings
num_solutions=3,
num_parents_mating=2,
)
# learn on the 3 policies we generated
learner.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_network_architecture,
iterations=2
)
if __name__=="__main__":
test_evo_learner()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment