Skip to content
Snippets Groups Projects
Commit 06384f33 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

add rs_learner and gru_learner unit tests

remove gru pickling

make main.train_child_network robust wrt 0

revise aa_learner tutorial
parent 363e8919
No related branches found
No related tags found
No related merge requests found
......@@ -135,7 +135,7 @@ class gru_learner(aa_learner):
This list can then be input into an AutoAugment object
as is done in self.learn()
We return the list and the sum of the log probs
We return a tuple of the list and the sum of the log probs
"""
log_prob = 0
......@@ -218,11 +218,11 @@ class gru_learner(aa_learner):
# minimize it.
self.cont_optim.step()
# save the history every 1 epochs as a pickle
with open('gru_logs.pkl', 'wb') as file:
pickle.dump(self.history, file)
with open('gru_learner.pkl', 'wb') as file:
pickle.dump(self, file)
# # save the history every 1 epochs as a pickle
# with open('gru_logs.pkl', 'wb') as file:
# pickle.dump(self.history, file)
# with open('gru_learner.pkl', 'wb') as file:
# pickle.dump(self, file)
......
......@@ -49,7 +49,7 @@ def train_child_network(child_network,
device = torch.device('cpu')
child_network = child_network.to(device=device)
best_acc=0
best_acc=torch.tensor([0.0])
early_stop_cnt = 0
# logging accuracy for plotting
......
......@@ -30,16 +30,20 @@ Defining the problem setting:
train=True, download=True, transform=None)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test',
train=False, download=True, transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
child_network_architecture = cn.lenet
.. warning::
In earlier versions, we had to write ``child_network=cn.LeNet``
and not ``child_network=cn.LeNet()``. But now we can do both.
In earlier versions, we had to write ``child_network_architecture=cn.LeNet``
and not ``child_network_architecture=cn.LeNet()``. But now we can do both.
Both types of objects can be input into ``aa_learner.learn()``.
More precisely, the ``child_network_architecture`` parameter has to be either
an ``nn.Module``, a ``function`` which returns an ``nn.Module``, or a ``type``
which inherits ``nn.Module``.
A downside (or maybe the upside??) of doing the latter is that
A downside (or maybe the upside??) of doing one of the latter two is that
the same randomly initialized weights are used for every policy.
Using the random search learner to evaluate randomly generated policies: (You
......@@ -62,7 +66,7 @@ can use any other learner in place of random search learner as well)
)
aa_agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_network,
child_network_architecture=child_network_architecture,
iterations=15000)
You can set further hyperparameters when defining an aa_learner.
......
......@@ -5,7 +5,6 @@ import torchvision
import torchvision.datasets as datasets
import random
from tqdm import trange
def test_translate_operation_tensor():
......@@ -22,7 +21,7 @@ def test_translate_operation_tensor():
# discrete_p_m=True
for i in trange(2000):
for i in range(2000):
softmax = torch.nn.Softmax(dim=0)
......@@ -52,7 +51,7 @@ def test_translate_operation_tensor():
# discrete_p_m=False
softmax = torch.nn.Softmax(dim=0)
sigmoid = torch.nn.Sigmoid()
for i in trange(2000):
for i in range(2000):
fun_num = random.randint(1, 14)
......
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torch
import torchvision
import torchvision.datasets as datasets
import random
def test_generate_new_policy():
    """
    Check that gru_learner.generate_new_policy() copes with randomly
    drawn values of sp_num, fun_num, p_bins and m_bins.

    40 random configurations are tried. For each one a fresh gru_learner
    is built and asked for 10 policies; the first element of every value
    it returns must be a list.
    """
    for _trial in range(40):
        # The dict literal is evaluated top to bottom, so the random draws
        # happen in the order sp_num, fun_num, p_bins, m_bins.
        learner_kwargs = {
            'sp_num': random.randint(1, 20),
            'fun_num': random.randint(1, 14),
            'p_bins': random.randint(1, 15),
            'm_bins': random.randint(1, 15),
        }
        agent = aal.gru_learner(**learner_kwargs)

        for _draw in range(10):
            generated = agent.generate_new_policy()
            # Only the first element is type-checked; the whole return
            # value is shown in the failure message.
            assert isinstance(generated[0], list), generated
def test_learn():
    """
    Smoke test for gru_learner.learn().

    Builds the FashionMNIST train and test datasets (downloading them if
    needed), creates a gru_learner and runs learn() for 2 iterations with
    cn.lenet as the child network architecture. There are no assertions:
    the test passes as long as nothing raises.
    """
    train_root = './datasets/fashionmnist/train'
    test_root = './datasets/fashionmnist/test'

    train_dataset = datasets.FashionMNIST(
        root=train_root, train=True, download=True, transform=None)
    test_dataset = datasets.FashionMNIST(
        root=test_root, train=False, download=True,
        transform=torchvision.transforms.ToTensor())

    # cn.lenet is handed over uncalled. The author left the called form
    # as a commented-out alternative:
    # child_network_architecture = cn.lenet()
    learner_settings = dict(
        sp_num=7,
        toy_flag=True,
        toy_size=0.001,
        batch_size=32,
        learning_rate=0.05,
        max_epochs=100,
        early_stop_num=10,
    )
    agent = aal.gru_learner(**learner_settings)

    agent.learn(
        train_dataset,
        test_dataset,
        child_network_architecture=cn.lenet,
        iterations=2,
    )
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torch
import torchvision
import torchvision.datasets as datasets
import random
def test_generate_new_policy():
    """
    Check that randomsearch_learner.generate_new_policy() copes with
    randomly drawn values of sp_num, fun_num, p_bins and m_bins.

    The check is run once with discrete_p_m=True and once with
    discrete_p_m=False. Each run tries 40 random configurations and asks
    every freshly built learner for 10 policies, each of which must be a
    list.
    """
    def check_random_configs(discrete_p_m):
        # One full sweep for a given discrete_p_m setting.
        for _trial in range(40):
            # The dict literal is evaluated top to bottom, so the random
            # draws happen in the order sp_num, fun_num, p_bins, m_bins.
            learner_kwargs = {
                'sp_num': random.randint(1, 20),
                'fun_num': random.randint(1, 14),
                'p_bins': random.randint(1, 15),
                'm_bins': random.randint(1, 15),
                'discrete_p_m': discrete_p_m,
            }
            agent = aal.randomsearch_learner(**learner_kwargs)

            for _draw in range(10):
                generated = agent.generate_new_policy()
                assert isinstance(generated, list), generated

    # Discrete setting first, then the non-discrete one.
    for flag in (True, False):
        check_random_configs(flag)
def test_learn():
    """
    Smoke test for randomsearch_learner.learn().

    Builds the FashionMNIST train and test datasets (downloading them if
    needed), creates a randomsearch_learner and runs learn() for 2
    iterations with cn.lenet as the child network architecture. There are
    no assertions: the test passes as long as nothing raises.
    """
    train_root = './datasets/fashionmnist/train'
    test_root = './datasets/fashionmnist/test'

    train_dataset = datasets.FashionMNIST(
        root=train_root, train=True, download=True, transform=None)
    test_dataset = datasets.FashionMNIST(
        root=test_root, train=False, download=True,
        transform=torchvision.transforms.ToTensor())

    # cn.lenet is handed over uncalled. The author left the called form
    # as a commented-out alternative:
    # child_network_architecture = cn.lenet()
    learner_settings = dict(
        sp_num=7,
        toy_flag=True,
        toy_size=0.001,
        batch_size=32,
        learning_rate=0.05,
        max_epochs=100,
        early_stop_num=10,
    )
    agent = aal.randomsearch_learner(**learner_settings)

    agent.learn(
        train_dataset,
        test_dataset,
        child_network_architecture=cn.lenet,
        iterations=2,
    )
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment