Skip to content
Snippets Groups Projects
Commit d4968b9a authored by John Carter's avatar John Carter
Browse files

main.py updated in line with UCB1.py file

parent 1a815f18
No related branches found
No related tags found
No related merge requests found
...@@ -186,8 +186,8 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl ...@@ -186,8 +186,8 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
cost = nn.CrossEntropyLoss() cost = nn.CrossEntropyLoss()
best_acc = train_child_network(model, train_loader, test_loader, sgd, best_acc = train_child_network(model, train_loader, test_loader, sgd,
cost, max_epochs, early_stop_num, logging=False, cost, max_epochs, early_stop_num, early_stop_flag,
print_every_epoch=False) average_validation, logging=False, print_every_epoch=False)
# update q_values # update q_values
if policy < num_policies: if policy < num_policies:
......
...@@ -41,6 +41,8 @@ def train_child_network(child_network, ...@@ -41,6 +41,8 @@ def train_child_network(child_network,
cost, cost,
max_epochs=2000, max_epochs=2000,
early_stop_num=10, early_stop_num=10,
early_stop_flag=True,
average_validation=[15,25],
logging=False, logging=False,
print_every_epoch=True): print_every_epoch=True):
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -94,17 +96,26 @@ def train_child_network(child_network, ...@@ -94,17 +96,26 @@ def train_child_network(child_network,
# correct += torch.sum(_.numpy(), axis=-1) # correct += torch.sum(_.numpy(), axis=-1)
_sum += _.shape[0] _sum += _.shape[0]
# update best validation accuracy if it was higher, otherwise increase early stop count
acc = correct / _sum acc = correct / _sum
if average_validation[0] <= _epoch <= average_validation[1]:
total_val += acc
# update best validation accuracy if it was higher, otherwise increase early stop count
if acc > best_acc : if acc > best_acc :
best_acc = acc best_acc = acc
early_stop_cnt = 0 early_stop_cnt = 0
else: else:
early_stop_cnt += 1 early_stop_cnt += 1
# exit if validation gets worse over 10 runs # exit if validation gets worse over 10 runs and using early stopping
if early_stop_cnt >= early_stop_num: if early_stop_cnt >= early_stop_num and early_stop_flag:
break
# exit if using fixed epoch length
if _epoch >= average_validation[1] and not early_stop_flag:
best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
break break
if print_every_epoch: if print_every_epoch:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment