Skip to content
Snippets Groups Projects
Commit f2e6f159 authored by Max Ramsay King's avatar Max Ramsay King
Browse files

added the policy-accuracy record

parent 3bf4ae16
No related branches found
No related tags found
No related merge requests found
Pipeline #272213 failed
......@@ -99,6 +99,8 @@ class aa_learner:
self.fun_num = len(self.augmentation_space)
self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2
self.num_pols_tested = 0
self.policy_record = {}
def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
......@@ -329,6 +331,8 @@ class aa_learner:
accuracy (float): best accuracy reached in any
"""
# we create an instance of the child network that we're going
# to train. The method of creation depends on the type of
# input we got for child_network_architecture
......@@ -378,8 +382,24 @@ class aa_learner:
early_stop_num = self.early_stop_num,
logging = logging,
print_every_epoch=print_every_epoch)
curr_pol = f'pol{self.num_pols_tested}'
pol_dict = {}
for subpol in policy:
subpol = subpol[0]
first_trans, first_prob, first_mag = subpol[0]
second_trans, second_prob, second_mag = subpol[1]
components = (first_prob, first_mag, second_prob, second_mag)
if second_trans in pol_dict[first_trans]:
pol_dict[first_trans][second_trans].append(components)
else:
pol_dict[first_trans]= {second_trans: [components]}
self.policy_record[curr_pol] = (pol_dict, accuracy)
self.num_pols_tested += 1
# if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
return accuracy
......
......@@ -244,18 +244,9 @@ class evo_learner(aa_learner):
if new_set == test_pol:
return True
self.policy_dict[trans1][trans2].append(new_set)
return False
else:
self.policy_dict[trans1] = {trans2: [new_set]}
if trans2 in self.policy_dict:
if trans1 in self.policy_dict[trans2]:
for test_pol in self.policy_dict[trans2][trans1]:
if new_set == test_pol:
return True
self.policy_dict[trans2][trans1].append(new_set)
return False
else:
self.policy_dict[trans2] = {trans1: [new_set]}
return False
def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment