Skip to content
Snippets Groups Projects
Commit 86b2bf58 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

train_child_network in main.py now logs accuracy

parent 124d6915
No related branches found
No related tags found
No related merge requests found
......@@ -411,6 +411,7 @@ class TrivialAugmentWide(torch.nn.Module):
if __name__=='__main__':
import matplotlib.pyplot as plt
from MetaAugment.main import *
import MetaAugment.child_networks as cn
import torchvision.transforms as transforms
......@@ -465,10 +466,14 @@ if __name__=='__main__':
child_network = cn.lenet()
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
train_dataset
test_autoaugment_policy(subpolicies1)
test_autoaugment_policy(subpolicies2)
\ No newline at end of file
best_acc, acc_log = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
return best_acc, acc_log
_, acc_log1 = test_autoaugment_policy(subpolicies1)
_, acc_log2 = test_autoaugment_policy(subpolicies2)
plt.plot(acc_log1, label='subpolicies1')
plt.plot(acc_log2, label='subpolicies2')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.legend()
plt.show()
\ No newline at end of file
......@@ -35,6 +35,9 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
best_acc=0
early_stop_cnt = 0
# logging accuracy for plotting
acc_log = []
# train child_network and check validation accuracy each epoch
for _epoch in range(max_epochs):
......@@ -74,8 +77,9 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
break
print('main.train_child_network best accuracy: ', best_acc)
return best_acc
acc_log.append(acc)
return best_acc, acc_log
if __name__=='__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment