Skip to content
Snippets Groups Projects
Commit 620ddf58 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

update /test

parent f007b3f3
No related branches found
No related tags found
No related merge requests found
[pytest]
filterwarnings =
error
ignore::UserWarning
ignore:function ham\(\) is deprecated:DeprecationWarning
\ No newline at end of file
...@@ -22,7 +22,8 @@ def test_generate_new_policy(): ...@@ -22,7 +22,8 @@ def test_generate_new_policy():
sp_num=sp_num, sp_num=sp_num,
fun_num=fun_num, fun_num=fun_num,
p_bins=p_bins, p_bins=p_bins,
m_bins=m_bins m_bins=m_bins,
cont_mb_size=2
) )
for _ in range(4): for _ in range(4):
new_policy = agent.generate_new_policy() new_policy = agent.generate_new_policy()
......
import torch
import torchvision
import torchvision.datasets as datasets
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import MetaAugment.main as main
def test_create_toy():
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
for _ in range(20):
train_loader, test_loader = main.create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=1)
p = torch.rand_like(torch.tensor([.0]))
train_loader, test_loader = main.create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=p)
train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
train=True, download=True, transform=None)
test_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
for _ in range(20):
train_loader, test_loader = main.create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=1)
p = torch.rand_like(torch.tensor([.0]))
train_loader, test_loader = main.create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=p)
def test_train_cn():
train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
train=True, download=True,
transform=torchvision.transforms.ToTensor())
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
train=False, download=True,
transform=torchvision.transforms.ToTensor())
cn_architecture = cn.Bad_LeNet
model = cn_architecture()
train_loader, test_loader = main.create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=0.01)
main.train_child_network(
model,
train_loader,
test_loader,
sgd=torch.optim.SGD(model.parameters(),lr=0.1),
cost=torch.nn.CrossEntropyLoss(),
early_stop_flag=True
)
main.train_child_network(
model,
train_loader,
test_loader,
sgd=torch.optim.SGD(model.parameters(),lr=0.1),
cost=torch.nn.CrossEntropyLoss(),
early_stop_flag=False
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment