From d4968b9a501ab8e264a1317ac689be84e9c44633 Mon Sep 17 00:00:00 2001 From: John Carter <jac202@ic.ac.uk> Date: Fri, 22 Apr 2022 11:50:51 +0100 Subject: [PATCH] main.py updated in line with UCB1.py file --- MetaAugment/UCB1_JC_py.py | 4 ++-- MetaAugment/main.py | 17 ++++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index e75c4853..27322463 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -186,8 +186,8 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl cost = nn.CrossEntropyLoss() best_acc = train_child_network(model, train_loader, test_loader, sgd, - cost, max_epochs, early_stop_num, logging=False, - print_every_epoch=False) + cost, max_epochs, early_stop_num, early_stop_flag, + average_validation, logging=False, print_every_epoch=False) # update q_values if policy < num_policies: diff --git a/MetaAugment/main.py b/MetaAugment/main.py index b9642879..61dec50c 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -41,6 +41,8 @@ def train_child_network(child_network, cost, max_epochs=2000, early_stop_num=10, + early_stop_flag=True, + average_validation=[15,25], logging=False, print_every_epoch=True): if torch.cuda.is_available(): @@ -94,17 +96,26 @@ def train_child_network(child_network, # correct += torch.sum(_.numpy(), axis=-1) _sum += _.shape[0] - # update best validation accuracy if it was higher, otherwise increase early stop count + acc = correct / _sum + if average_validation[0] <= _epoch <= average_validation[1]: + total_val += acc + + # update best validation accuracy if it was higher, otherwise increase early stop count if acc > best_acc : best_acc = acc early_stop_cnt = 0 else: early_stop_cnt += 1 - # exit if validation gets worse over 10 runs - if early_stop_cnt >= early_stop_num: + # exit if validation gets worse over 10 runs and using early stopping + if early_stop_cnt >= early_stop_num and early_stop_flag: + break + + # exit if using fixed epoch length + if _epoch >= average_validation[1] and not early_stop_flag: + best_acc = total_val / (average_validation[1] - average_validation[0] + 1) break if print_every_epoch: -- GitLab