Commit e2340569 authored by Sun Jin Kim

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents fd35bf0d 32aeeacd
Pipeline #272265 failed
@@ -99,6 +99,10 @@ class aa_learner:
        self.fun_num = len(self.augmentation_space)
        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2

+        self.num_pols_tested = 0
+        self.policy_record = {}

    def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
@@ -300,6 +304,7 @@ class aa_learner:
            self.history.append((policy, reward))
        """

    def _test_autoaugment_policy(self,
@@ -329,6 +334,8 @@ class aa_learner:
            accuracy (float): best accuracy reached in any
        """

        # we create an instance of the child network that we're going
        # to train. The method of creation depends on the type of
        # input we got for child_network_architecture
@@ -378,8 +385,24 @@ class aa_learner:
                                early_stop_num = self.early_stop_num,
                                logging = logging,
                                print_every_epoch=print_every_epoch)

+        curr_pol = f'pol{self.num_pols_tested}'
+        pol_dict = {}
+        for subpol in policy:
+            subpol = subpol[0]
+            first_trans, first_prob, first_mag = subpol[0]
+            second_trans, second_prob, second_mag = subpol[1]
+            components = (first_prob, first_mag, second_prob, second_mag)
+            # check first_trans membership first to avoid a KeyError on an empty pol_dict
+            if first_trans in pol_dict and second_trans in pol_dict[first_trans]:
+                pol_dict[first_trans][second_trans].append(components)
+            elif first_trans in pol_dict:
+                pol_dict[first_trans][second_trans] = [components]
+            else:
+                pol_dict[first_trans] = {second_trans: [components]}
+        self.policy_record[curr_pol] = (pol_dict, accuracy)
+        self.num_pols_tested += 1

        # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
        return accuracy
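
The added block above builds a small record of every tested policy. As a standalone illustration of the same bookkeeping (the helper name `record_policy` and the toy policy are hypothetical, not from the repository), a minimal sketch might look like this:

```python
# Illustrative sketch of the policy_record bookkeeping added above.
# record_policy, policy_record and the toy policy are hypothetical names,
# not part of the metarl codebase.

policy_record = {}
num_pols_tested = 0

def record_policy(policy, accuracy):
    """Store one tested policy as a nested dict keyed by its two transform names."""
    global num_pols_tested
    pol_dict = {}
    for subpolicy in policy:
        (first_trans, first_prob, first_mag), (second_trans, second_prob, second_mag) = subpolicy
        components = (first_prob, first_mag, second_prob, second_mag)
        # setdefault avoids the KeyError a bare pol_dict[first_trans] lookup would raise
        pol_dict.setdefault(first_trans, {}).setdefault(second_trans, []).append(components)
    policy_record[f'pol{num_pols_tested}'] = (pol_dict, accuracy)
    num_pols_tested += 1

# Example: one policy with two subpolicies, each a pair of (name, probability, magnitude)
toy_policy = [
    (('Rotate', 0.6, 3), ('ShearX', 0.2, 5)),
    (('Color', 0.9, 1), ('Invert', 0.1, None)),
]
record_policy(toy_policy, accuracy=0.71)
print(policy_record['pol0'])
```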
@@ -57,60 +57,11 @@ class evo_learner(aa_learner):
        # store our logs
        self.policy_dict = {}
-        self.policy_result = []
+        self.running_policy = []

        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'

-    def get_full_policy(self, x):
-        """
-        Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
-        output size 5 * 2 * (self.fun_num + self.p_bins + self.m_bins)
-
-        Parameters
-        -----------
-        x -> PyTorch tensor
-            Input data for network
-
-        Returns
-        ----------
-        full_policy -> [((String, float, float), (String, float, float)), ...)
-            Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
-            two transformations, with a probability and magnitude float for each
-        """
-        section = self.fun_num + self.p_bins + self.m_bins
-        y = self.controller.forward(x)
-        full_policy = []
-        for pol in range(self.sp_num):
-            int_pol = []
-            for _ in range(2):
-                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
-                trans, need_mag = self.augmentation_space[idx_ret]
-                if self.p_bins == 1:
-                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
-                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                else:
-                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item()) * 0.1
-                if need_mag:
-                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
-                    if self.m_bins == 1:
-                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
-                    else:
-                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())
-                    mag = int(mag)
-                else:
-                    mag = None
-                int_pol.append((trans, p_ret, mag))
-            full_policy.append(tuple(int_pol))
-        return full_policy

    def get_single_policy_cov(self, x, alpha = 0.5):
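
For context on the removed `get_full_policy` (and the retained `get_single_policy_cov`): each operation owns a slice of length `fun_num + p_bins + m_bins` in the controller output, the transform is the argmax over the first `fun_num` entries, and the probability and magnitude come from the remaining bins. The snippet below is a simplified, hypothetical reimplementation of that decoding, assuming `p_bins == m_bins == 1`; it is not the repository's code:

```python
# Simplified sketch of decoding one operation slice of a controller output.
# Assumes p_bins == m_bins == 1; augmentation_space and decode_operation are
# illustrative stand-ins, not the metarl implementation.
import torch

augmentation_space = [("ShearX", True), ("Rotate", True), ("Invert", False)]
fun_num, p_bins, m_bins = len(augmentation_space), 1, 1
section = fun_num + p_bins + m_bins

def decode_operation(y_slice):
    """y_slice: 1-D tensor of length `section` (already averaged over the batch)."""
    trans, needs_mag = augmentation_space[int(torch.argmax(y_slice[:fun_num]))]
    prob = float(torch.clamp(y_slice[fun_num], 0.0, 1.0))                        # probability clipped to [0, 1]
    mag = int(torch.clamp(y_slice[fun_num + 1], 0, 9)) if needs_mag else None    # magnitude bin clipped to 0-9
    return trans, prob, mag

y = torch.rand(8, 2 * section).mean(dim=0)   # fake batch-averaged output for 2 operations
subpolicy = tuple(decode_operation(y[i * section:(i + 1) * section]) for i in range(2))
print(subpolicy)
```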
@@ -172,10 +123,8 @@ class evo_learner(aa_learner):
                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                if mag1 is not None:
-                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
                    mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                if mag2 is not None:
-                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
                    mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                counter += 1
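
This hunk keeps the accumulate-then-average pattern: probabilities are squashed with a sigmoid, magnitudes are mapped into [0, 9] via min(9, 10 * sigmoid(x)), and both are summed over the batch before dividing by `counter`. A toy sketch of that pattern (`fun_num` and the random `y` are placeholders, not the repository's tensors):

```python
# Sketch of the accumulate-then-average pattern used in get_single_policy_cov.
import torch

fun_num = 14
y = torch.randn(32, 2 * (fun_num + 2))   # fake controller output for a batch of 32

prob1, mag1, counter = 0.0, 0.0, 0
for idx in range(y.shape[0]):
    prob1 += torch.sigmoid(y[idx, fun_num]).item()                  # probability in (0, 1)
    mag1 += min(9, 10 * torch.sigmoid(y[idx, fun_num + 1]).item())  # magnitude capped at 9
    counter += 1

print(prob1 / counter, mag1 / counter)   # batch-averaged probability and magnitude
```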
@@ -213,19 +162,14 @@ class evo_learner(aa_learner):
            Solution_idx -> Int
        """
        self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations+1)]
-        print("itations: ", iterations)
-        self.history_avg = [0 for i in range(iterations+1)]
+        self.history_best = []
        self.gen_count = 0
        self.best_model = 0

        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
-        print("train_dataset: ", train_dataset)

        self.ga_instance.run()
-        self.history_avg = [x / self.num_solutions for x in self.history_avg]
-        print("-----------------------------------------------------------------------------------------------------")

        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
        if return_weights:
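
The surrounding `learn()` method drives a PyGAD genetic algorithm: build the GA instance in `set_up_instance`, call `run()`, then read off `best_solution()`. A minimal, self-contained sketch of that pattern with a toy fitness function in place of `_test_autoaugment_policy` (all numeric settings are illustrative; the two-argument `fitness_func` signature matches PyGAD 2.x, while newer PyGAD versions also pass the GA instance as a first argument):

```python
# Minimal sketch of the PyGAD driving pattern around set_up_instance()/learn().
# The toy fitness function stands in for _test_autoaugment_policy; num_genes and
# the population sizes are illustrative values only.
import pygad
import numpy as np

def fitness_func(solution, solution_idx):
    # stand-in fitness: prefer solutions whose genes sum close to zero
    return -abs(float(np.sum(solution)))

ga_instance = pygad.GA(num_generations=5,
                       num_parents_mating=4,
                       sol_per_pop=10,
                       num_genes=8,
                       fitness_func=fitness_func)
ga_instance.run()
solution, solution_fitness, solution_idx = ga_instance.best_solution()
print(solution_fitness)
```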
@@ -244,18 +188,9 @@ class evo_learner(aa_learner):
                    if new_set == test_pol:
                        return True
                self.policy_dict[trans1][trans2].append(new_set)
+                return False
            else:
                self.policy_dict[trans1] = {trans2: [new_set]}
+                return False
-        if trans2 in self.policy_dict:
-            if trans1 in self.policy_dict[trans2]:
-                for test_pol in self.policy_dict[trans2][trans1]:
-                    if new_set == test_pol:
-                        return True
-                self.policy_dict[trans2][trans1].append(new_set)
-                return False
-            else:
-                self.policy_dict[trans2] = {trans1: [new_set]}

    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
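
After this change, `in_pol_dict` only keys the duplicate check on the (trans1, trans2) ordering, so a reversed pair is no longer treated as a duplicate. A condensed, hypothetical sketch of that check (`seen` and `is_duplicate` are illustrative names, not the repository's):

```python
# Sketch of the simplified duplicate check: policies are keyed only by
# (first transform, second transform), so the reversed ordering is distinct.
seen = {}

def is_duplicate(trans1, trans2, new_set):
    """Return True if this (prob, mag) combination was already recorded, else record it."""
    if trans1 in seen and trans2 in seen[trans1]:
        if new_set in seen[trans1][trans2]:
            return True
        seen[trans1][trans2].append(new_set)
        return False
    seen.setdefault(trans1, {})[trans2] = [new_set]
    return False

print(is_duplicate('Rotate', 'ShearX', (0.6, 3, 0.2, 5)))   # False: first time seen
print(is_duplicate('Rotate', 'ShearX', (0.6, 3, 0.2, 5)))   # True: exact repeat
print(is_duplicate('ShearX', 'Rotate', (0.6, 3, 0.2, 5)))   # False: order matters now
```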
@@ -287,33 +222,34 @@ class evo_learner(aa_learner):
            self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)

            for idx, (test_x, label_x) in enumerate(self.train_loader):
-                # if self.sp_num == 1:
                full_policy = self.get_single_policy_cov(test_x)
-                # else:
-                #     full_policy = self.get_full_policy(test_x)

                while self.in_pol_dict(full_policy):
                    full_policy = self.get_single_policy_cov(test_x)[0]

-                fit_val = self._test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) #) /
-                # + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+                fit_val = self._test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset)

-                self.policy_result.append([full_policy, fit_val])
+                self.history.append((full_policy, fit_val))
+                self.running_policy.append((full_policy, fit_val))

-                if len(self.policy_result) > self.sp_num:
-                    self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
-                    self.policy_result = self.policy_result[:self.sp_num]
-                    print("appended policy: ", self.policy_result)
+                if len(self.running_policy) > self.sp_num:
+                    self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
+                    self.running_policy = self.running_policy[:self.sp_num]
+                    print("appended policy: ", self.running_policy)

-                if fit_val > self.history_best[self.gen_count]:
-                    print("best policy: ", full_policy)
-                    self.history_best[self.gen_count] = fit_val
+                if len(self.history_best) == 0:
+                    self.history_best.append(fit_val)
                    self.best_model = model_weights_dict
+                elif fit_val > self.history_best[-1]:
+                    self.history_best.append(fit_val)
+                    self.best_model = model_weights_dict
+                else:
+                    self.history_best.append(self.history_best[-1])

-                self.history_avg[self.gen_count] += fit_val

                return fit_val
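
The bookkeeping inside this fitness function can be shown in isolation: `running_policy` keeps the `sp_num` highest-fitness (policy, fitness) pairs seen so far, while `history_best` records a running maximum with one entry per evaluation. The helper below is a hypothetical condensation of that logic, not the repository's code:

```python
# Condensed sketch of the fitness-function bookkeeping above: keep the top
# sp_num (policy, fitness) pairs and a running-best curve. Names are illustrative.
sp_num = 3
running_policy = []   # best sp_num (policy, fitness) pairs so far
history_best = []     # running maximum of fitness, one entry per evaluation

def record(policy, fit_val):
    running_policy.append((policy, fit_val))
    if len(running_policy) > sp_num:
        # keep only the sp_num highest-fitness policies
        running_policy.sort(key=lambda x: x[1], reverse=True)
        del running_policy[sp_num:]
    # history_best is monotone: append the new value only if it improves on the last one
    if not history_best or fit_val > history_best[-1]:
        history_best.append(fit_val)
    else:
        history_best.append(history_best[-1])

for i, fit in enumerate([0.4, 0.7, 0.5, 0.8, 0.3]):
    record(f'policy_{i}', fit)
print(running_policy)   # three best (policy, fitness) pairs
print(history_best)     # [0.4, 0.7, 0.7, 0.8, 0.8]
```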