From d4968b9a501ab8e264a1317ac689be84e9c44633 Mon Sep 17 00:00:00 2001
From: John Carter <jac202@ic.ac.uk>
Date: Fri, 22 Apr 2022 11:50:51 +0100
Subject: [PATCH] main.py updated in line with UCB1.py file

---
 MetaAugment/UCB1_JC_py.py |  4 ++--
 MetaAugment/main.py       | 17 ++++++++++++++---
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index e75c4853..27322463 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -186,8 +186,8 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         cost = nn.CrossEntropyLoss()
 
         best_acc = train_child_network(model, train_loader, test_loader, sgd,
-                         cost, max_epochs, early_stop_num, logging=False,
-                         print_every_epoch=False)
+                         cost, max_epochs, early_stop_num, early_stop_flag,
+			 average_validation, logging=False, print_every_epoch=False)
 
         # update q_values
         if policy < num_policies:
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index b9642879..61dec50c 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -41,6 +41,8 @@ def train_child_network(child_network,
                         cost,
                         max_epochs=2000,
                         early_stop_num=10,
+                        early_stop_flag=True,
+                        average_validation=[15,25],
                         logging=False,
                         print_every_epoch=True):
     if torch.cuda.is_available():
@@ -94,17 +96,26 @@ def train_child_network(child_network,
                 # correct += torch.sum(_.numpy(), axis=-1)
                 _sum += _.shape[0]
         
-        # update best validation accuracy if it was higher, otherwise increase early stop count
+        
         acc = correct / _sum
 
+        if average_validation[0] <= _epoch <= average_validation[1]:
+            total_val += acc
+
+	# update best validation accuracy if it was higher, otherwise increase early stop count
         if acc > best_acc :
             best_acc = acc
             early_stop_cnt = 0
         else:
             early_stop_cnt += 1
 
-        # exit if validation gets worse over 10 runs
-        if early_stop_cnt >= early_stop_num:
+        # exit if validation gets worse over 10 runs and using early stopping
+        if early_stop_cnt >= early_stop_num and early_stop_flag:
+            break
+
+        # exit if using fixed epoch length
+        if _epoch >= average_validation[1] and not early_stop_flag:
+            best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
             break
         
         if print_every_epoch:
-- 
GitLab