Skip to content
Snippets Groups Projects
Commit 9c101635 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Add cuda to train_child_network

parent debf9e2d
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,12 @@ def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
def train_child_network(child_network, train_loader, test_loader, sgd,
cost, max_epochs=2000, early_stop_num = 10, logging=False):
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
child_network = child_network.to(device=device)
best_acc=0
early_stop_cnt = 0
......@@ -51,7 +57,12 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
# train child_network
child_network.train()
for idx, (train_x, train_label) in enumerate(train_loader):
label_np = np.zeros((train_label.shape[0], 10))
# onto device
train_x = train_x.to(device=device, dtype=train_x.dtype)
train_label = train_label.to(device=device, dtype=train_label.dtype)
# label_np = np.zeros((train_label.shape[0], 10))
sgd.zero_grad()
predict_y = child_network(train_x.float())
loss = cost(predict_y, train_label.long())
......@@ -64,11 +75,18 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
child_network.eval()
with torch.no_grad():
for idx, (test_x, test_label) in enumerate(test_loader):
# onto device
test_x = test_x.to(device=device, dtype=test_x.dtype)
test_label = test_label.to(device=device, dtype=test_label.dtype)
predict_y = child_network(test_x.float()).detach()
predict_ys = np.argmax(predict_y, axis=-1)
label_np = test_label.numpy()
predict_ys = torch.argmax(predict_y, axis=-1)
# label_np = test_label.numpy()
_ = predict_ys == test_label
correct += np.sum(_.numpy(), axis=-1)
correct += torch.sum(_, axis=-1)
# correct += torch.sum(_.numpy(), axis=-1)
_sum += _.shape[0]
# update best validation accuracy if it was higher, otherwise increase early stop count
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment