diff --git a/MetaAugment/main.py b/MetaAugment/main.py index 5796c94cb349808ad92f028e2c9318ac3fb27316..af1f311d9b266da4572fbda550a2978342c0aad8 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -51,8 +51,8 @@ def train_child_network(child_network, device = torch.device('cpu') child_network = child_network.to(device=device) - total_val=torch.tensor([0.0]) - best_acc=torch.tensor([0.0]) + total_val=torch.tensor([0.0]).to(device=device) + best_acc=torch.tensor([0.0]).to(device=device) early_stop_cnt = 0 # logging accuracy for plotting