Commit 246b6352 authored by Max Ramsay King

updated evo_learner and added exclude_method to aa_learner
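For context, a rough usage sketch of the new exclude_method argument. The transformation names and argument values below are illustrative only and are not part of this commit; exclude_method is assumed to hold operation names that should be dropped from the search space.

# Hypothetical usage sketch -- names and values are placeholders, not committed code.
from MetaAugment.autoaugment_learners.aa_learner import aa_learner

learner = aa_learner(
    sp_num=5,
    p_bins=11,
    m_bins=10,
    discrete_p_m=False,
    exclude_method=['Invert', 'Solarize'],   # operations the search should never propose (example names)
)
# After __init__, learner.augmentation_space omits the excluded operations, and
# learner.fun_num / learner.op_tensor_length shrink to match the filtered space.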

parent d239f59a
@@ -46,7 +46,6 @@ class aa_learner:
def __init__(self,
# parameters that define the search space
sp_num=5,
-fun_num=14,
p_bins=11,
m_bins=10,
discrete_p_m=False,
@@ -57,6 +56,7 @@
learning_rate=1e-1,
max_epochs=float('inf'),
early_stop_num=20,
+exclude_method = [],
):
"""
Args:
@@ -84,11 +84,9 @@
"""
# related to defining the search space
self.sp_num = sp_num
-self.fun_num = fun_num
self.p_bins = p_bins
self.m_bins = m_bins
self.discrete_p_m = discrete_p_m
-self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
# related to training of the child_network
self.batch_size = batch_size
@@ -101,6 +99,9 @@
# TODO: We should probably use a different way to store results than self.history
self.history = []
+self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+self.fun_num = len(self.augmentation_space)
+self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2
def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
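That concludes the aa_learner changes. A minimal self-contained sketch of what the new initialisation computes is below; the operation names are stand-ins, and the real augmentation_space is assumed to be the 14-entry list of (name, requires_magnitude) pairs defined in aa_learner.

# Illustrative sketch only: how fun_num and op_tensor_length follow from the filtered space.
augmentation_space = [(f'Op{i}', True) for i in range(14)]   # stand-in for the real 14 operations
exclude_method = ['Op3', 'Op7']
p_bins, m_bins, discrete_p_m = 11, 10, True

filtered = [x for x in augmentation_space if x[0] not in exclude_method]
fun_num = len(filtered)                                               # 12 after excluding two ops
op_tensor_length = fun_num + p_bins + m_bins if discrete_p_m else fun_num + 2
print(fun_num, op_tensor_length)                                      # 12 33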
@@ -6,22 +6,21 @@ import pygad
import pygad.torchga as torchga
import copy
import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
+from MetaAugment.controller_networks.evo_controller import Evo_learner
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
import MetaAugment.child_networks as cn
+from .aa_learner import aa_learner, augmentation_space
class evo_learner(aa_learner):
def __init__(self,
sp_num=1,
-num_solutions = 10,
-num_parents_mating = 5,
+num_solutions = 5,
+num_parents_mating = 3,
learning_rate = 1e-1,
max_epochs=float('inf'),
early_stop_num=20,
train_loader = None,
child_network = None,
p_bins = 1,
m_bins = 1,
discrete_p_m=False,
@@ -30,6 +29,7 @@ class evo_learner(aa_learner):
toy_size=0.1,
fun_num = 14,
exclude_method=[],
+controller = None
):
super().__init__(sp_num,
@@ -45,24 +45,22 @@ class evo_learner(aa_learner):
early_stop_num=early_stop_num,)
self.num_solutions = num_solutions
-self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+self.controller = controller
+self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights
self.train_loader = train_loader
self.child_network = child_network
self.p_bins = p_bins
self.sub_num_pol = sp_num
self.m_bins = m_bins
self.fun_num = fun_num
self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
self.policy_dict = {}
self.policy_result = []
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
def get_full_policy(self, x):
"""
Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -79,8 +77,8 @@ class evo_learner(aa_learner):
Full policy consisting of tuples of subpolicies. Each subpolicy consisting of
two transformations, with a probability and magnitude float for each
"""
-section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-y = self.auto_aug_agent.forward(x)
+section = self.fun_num + self.p_bins + self.m_bins
+y = self.controller.forward(x)
full_policy = []
for pol in range(self.sub_num_pol):
int_pol = []
@@ -89,8 +87,22 @@ class evo_learner(aa_learner):
trans, need_mag = self.augmentation_space[idx_ret]
-p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+if self.p_bins == 1:
+p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+# p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+else:
+p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)).item() * 0.1
+if need_mag:
+# print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
+if self.m_bins == 1:
+mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+else:
+mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)).item()
+mag = int(mag)
+else:
+mag = None
int_pol.append((trans, p_ret, mag))
full_policy.append(tuple(int_pol))
@@ -117,18 +129,18 @@ class evo_learner(aa_learner):
Subpolicy consisting of two tuples of policies, each with a string associated
to a transformation, a float for a probability, and a float for a magnitude
"""
-section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+section = self.fun_num + self.p_bins + self.m_bins
-y = self.auto_aug_agent.forward(x)
+y = self.controller.forward(x)
-y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1)
-y[:,:self.auto_aug_agent.fun_num] = y_1
-y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+y_1 = torch.softmax(y[:,:self.fun_num], dim = 1)
+y[:,:self.fun_num] = y_1
+y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+y[:,section:section+self.fun_num] = y_2
concat = torch.cat((y_1, y_2), dim = 1)
cov_mat = torch.cov(concat.T)
-cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+cov_mat = cov_mat[:self.fun_num, self.fun_num:]
shape_store = cov_mat.shape
counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -154,26 +166,29 @@ class evo_learner(aa_learner):
for idx in range(y.shape[0]):
if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
if mag1 is not None:
-mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+# mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
if mag2 is not None:
-mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+# mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+mag2 += 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item()
counter += 1
-prob1 = prob1/counter if counter != 0 else 0
-prob2 = prob2/counter if counter != 0 else 0
+prob1 = round(prob1/counter, 1) if counter != 0 else 0
+prob2 = round(prob2/counter, 1) if counter != 0 else 0
if mag1 is not None:
-mag1 = mag1/counter
+mag1 = int(mag1/counter)
if mag2 is not None:
-mag2 = mag2/counter
+mag2 = int(mag2/counter)
-return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
-def learn(self, iterations = 15, return_weights = False):
+def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
"""
Runs the GA instance and returns the model weights as a dictionary
@@ -195,24 +210,52 @@ class evo_learner(aa_learner):
Solution_idx -> Int
"""
self.num_generations = iterations
-self.history_best = [0 for i in range(iterations)]
-self.history_avg = [0 for i in range(iterations)]
+self.history_best = [0 for i in range(iterations+1)]
+print("iterations: ", iterations)
+self.history_avg = [0 for i in range(iterations+1)]
self.gen_count = 0
self.best_model = 0
-self.set_up_instance()
+self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+print("train_dataset: ", train_dataset)
self.ga_instance.run()
-self.history_avg = self.history_avg / self.num_solutions
+self.history_avg = [x / self.num_solutions for x in self.history_avg]
print("-----------------------------------------------------------------------------------------------------")
solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
if return_weights:
-return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
else:
return solution, solution_fitness, solution_idx
-def set_up_instance(self, train_dataset, test_dataset):
+def in_pol_dict(self, new_policy):
+new_policy = new_policy[0]
+trans1, trans2 = new_policy[0][0], new_policy[1][0]
+new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+if trans1 in self.policy_dict:
+if trans2 in self.policy_dict[trans1]:
+for test_pol in self.policy_dict[trans1][trans2]:
+if new_set == test_pol:
+return True
+self.policy_dict[trans1][trans2].append(new_set)
+return False
+else:
+self.policy_dict[trans1][trans2] = [new_set]
+if trans2 in self.policy_dict:
+if trans1 in self.policy_dict[trans2]:
+for test_pol in self.policy_dict[trans2][trans1]:
+if new_set == test_pol:
+return True
+self.policy_dict[trans2][trans1].append(new_set)
+return False
+else:
+self.policy_dict[trans2][trans1] = [new_set]
+def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
"""
Initialises GA instance, as well as fitness and on_generation functions
@@ -233,24 +276,36 @@ class evo_learner(aa_learner):
fit_val -> float
"""
-model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
+model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
weights_vector=solution)
-self.auto_aug_agent.load_state_dict(model_weights_dict)
+self.controller.load_state_dict(model_weights_dict)
self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)
for idx, (test_x, label_x) in enumerate(self.train_loader):
-if self.sp_num == 1:
-full_policy = self.get_single_policy_cov(test_x)
-else:
-full_policy = self.get_full_policy(test_x)
+# if self.sp_num == 1:
+full_policy = self.get_single_policy_cov(test_x)
+# else:
+# full_policy = self.get_full_policy(test_x)
+while self.in_pol_dict(full_policy):
+full_policy = self.get_single_policy_cov(test_x)[0]
# Checkpoint -> save learner as a pickle
-fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-+ self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+fit_val = self.test_autoaugment_policy(full_policy, child_network_architecture, train_dataset, test_dataset) #) /
+# + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+self.policy_result.append([full_policy, fit_val])
+if len(self.policy_result) > self.sp_num:
+self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+self.policy_result = self.policy_result[:self.sp_num]
+print("Appended policy: ", self.policy_result)
if fit_val > self.history_best[self.gen_count]:
print("Best policy: ", full_policy)
self.history_best[self.gen_count] = fit_val
self.best_model = model_weights_dict
@@ -284,6 +339,3 @@ class evo_learner(aa_learner):
fitness_func=fitness_func,
on_generation = on_generation)
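Taken together, the updated calling convention looks roughly like the sketch below. Every concrete object (the controller instance, the datasets, the child-network architecture) is a placeholder, and the evo_learner module path is assumed from the package layout; none of this is asserted by the commit itself.

# Hypothetical post-commit usage sketch; all names below are placeholders.
from MetaAugment.controller_networks.evo_controller import Evo_learner
from MetaAugment.autoaugment_learners.evo_learner import evo_learner   # module path assumed

controller = Evo_learner()                    # constructor arguments omitted here
learner = evo_learner(sp_num=1,
                      num_solutions=5,
                      num_parents_mating=3,
                      controller=controller)

# train_ds, test_ds and ChildCNN stand in for real datasets and a child-network class.
solution, fitness, solution_idx = learner.learn(train_ds, test_ds, ChildCNN,
                                                iterations=15, return_weights=False)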