From 5c4e5a1818689446bca70729c525ad93e4048dd5 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 22 Apr 2022 21:27:09 +0100
Subject: [PATCH] main.py: add some tensors to device

---
 MetaAugment/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 5796c94c..af1f311d 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -51,8 +51,8 @@ def train_child_network(child_network,
         device = torch.device('cpu')
     child_network = child_network.to(device=device)
     
-    total_val=torch.tensor([0.0])
-    best_acc=torch.tensor([0.0])
+    total_val=torch.tensor([0.0]).to(device=device)
+    best_acc=torch.tensor([0.0]).to(device=device)
     early_stop_cnt = 0
     
     # logging accuracy for plotting
-- 
GitLab