From 6f3f3748c926e7d910dd4dedcd63bd1004c93f74 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 22 Apr 2022 17:46:34 +0100
Subject: [PATCH] fix main.train_cn

---
 MetaAugment/autoaugment_learners/gru_learner.py | 2 +-
 MetaAugment/main.py                             | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index 5f8e19be..c06edec3 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -61,7 +61,7 @@ class gru_learner(aa_learner):
                 # GRU-specific attributes that aren't in all other aa_learners's
                 alpha=0.2,
                 cont_mb_size=4,
-                cont_lr=0.1):
+                cont_lr=0.03):
         """
         Args:
             alpha (float, optional): Exploration parameter. It is multiplied to 
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 61dec50c..5796c94c 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -51,6 +51,7 @@ def train_child_network(child_network,
         device = torch.device('cpu')
     child_network = child_network.to(device=device)
     
+    total_val=torch.tensor([0.0])
     best_acc=torch.tensor([0.0])
     early_stop_cnt = 0
     
-- 
GitLab