diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 5796c94cb349808ad92f028e2c9318ac3fb27316..af1f311d9b266da4572fbda550a2978342c0aad8 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -51,8 +51,8 @@ def train_child_network(child_network,
         device = torch.device('cpu')
     child_network = child_network.to(device=device)
     
-    total_val=torch.tensor([0.0])
-    best_acc=torch.tensor([0.0])
+    total_val=torch.tensor([0.0]).to(device=device)
+    best_acc=torch.tensor([0.0]).to(device=device)
     early_stop_cnt = 0
     
     # logging accuracy for plotting