import torch
import torch.nn as nn
import torchvision
import pygad
import pygad.torchga as torchga

from MetaAugment.autoaugment_learners.aa_learner import aa_learner


class evo_learner(aa_learner):   # class name assumed, following the aa_learner naming convention

    def __init__(self,
                # controller network class evolved by the GA (assumed to be a required argument)
                controller,
                # search space settings
                sp_num=5,
                p_bins=11,   # assumed default: number of probability bins
                m_bins=10,   # assumed default: number of magnitude bins
                discrete_p_m=False,
                exclude_method=[],
                # child network settings
                learning_rate=1e-1,
                max_epochs=float('inf'),
                early_stop_num=20,
                batch_size=8,
                toy_size=1,
                # evolutionary learner specific settings
                num_solutions=5,
                num_parents_mating=3,
                ):
        super().__init__(
            sp_num=sp_num,
            p_bins=p_bins,
            m_bins=m_bins,
            discrete_p_m=discrete_p_m,
            batch_size=batch_size,
            toy_size=toy_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            early_stop_num=early_stop_num,
            exclude_method=exclude_method
        )
        # evolutionary algorithm settings
        self.controller = controller(
            fun_num=self.fun_num,
            p_bins=self.p_bins,
            m_bins=self.m_bins,
            sub_num_pol=self.sp_num
        )
        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
        self.num_parents_mating = num_parents_mating
        self.initial_population = self.torch_ga.population_weights
        self.policy_dict = {}
        # bookkeeping used by the fitness function below (initialisation assumed)
        self.running_policy = []
        self.history_best = []

        assert num_solutions > num_parents_mating, \
            'Number of solutions must be larger than the number of parents mating!'
"""
Selects policy using population and covariance matrices. For this method
we require p_bins = 1, num_sub_pol = 1, m_bins = 1.
Parameters
------------
x -> PyTorch Tensor
Input data for the AutoAugment network
alpha -> Float
Proportion for covariance and population matrices
Returns
-----------
Subpolicy -> [(String, float, float), (String, float, float)]
Subpolicy consisting of two tuples of policies, each with a string associated
to a transformation, a float for a probability, and a float for a magnittude
"""
        section = self.fun_num + self.p_bins + self.m_bins

        y = self.controller.forward(x)
        y_1 = torch.softmax(y[:, :self.fun_num], dim=1)
        y[:, :self.fun_num] = y_1
        y_2 = torch.softmax(y[:, section:section + self.fun_num], dim=1)
        y[:, section:section + self.fun_num] = y_2
        concat = torch.cat((y_1, y_2), dim=1)
        # the statement defining cov_mat was missing here; torch.cov over the
        # concatenated softmax outputs is an assumption
        cov_mat = torch.cov(concat.T)
        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
        shape_store = cov_mat.shape

        counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)

        prob_mat = torch.zeros(shape_store)
        for idx in range(y.shape[0]):
            prob_mat[torch.argmax(y_1[idx])][torch.argmax(y_2[idx])] += 1
        prob_mat = prob_mat / torch.sum(prob_mat)
        # blend the covariance estimate with the empirical co-occurrence frequencies
        cov_mat = (alpha * cov_mat) + ((1 - alpha) * prob_mat)
        cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()

        max_idx = torch.argmax(cov_mat)
        val = (max_idx // shape_store[0])
        max_idx = (val, max_idx - (val * shape_store[0]))
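        # e.g. if the grid were 14 x 14 (fun_num = 14), flattened index 17 would
        # map back to (row 1, column 3), i.e. the transform pair (1, 3)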
        if not self.augmentation_space[max_idx[0]][1]:
            mag1 = None
        if not self.augmentation_space[max_idx[1]][1]:
            mag2 = None
        for idx in range(y.shape[0]):
            if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
                counter += 1
                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
                prob2 += torch.sigmoid(y[idx, section + self.fun_num]).item()
                if mag1 is not None:
                    mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num + 1]).item())
                if mag2 is not None:
                    mag2 += min(9, 10 * torch.sigmoid(y[idx, section + self.fun_num + 1]).item())

        prob1 = round(prob1 / counter, 1) if counter != 0 else 0
        prob2 = round(prob2 / counter, 1) if counter != 0 else 0
        if mag1 is not None:
            mag1 = int(mag1 / counter) if counter != 0 else 0
        if mag2 is not None:
            mag2 = int(mag2 / counter) if counter != 0 else 0
        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1),
                 (self.augmentation_space[max_idx[1]][0], prob2, mag2))]
    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations=15, return_weights=False):
        """
        Runs the GA instance and returns the model weights as a dictionary

        Parameters
        ------------
        return_weights -> bool
            Determines if the weights of the GA network should be returned

        Returns
        ------------
        If return_weights:
            Network weights -> dictionary
        Else:
            Solution -> best GA instance solution
            Solution fitness -> float
            Solution_idx -> int
        """
        self.num_generations = iterations   # assumed: `iterations` sets the number of GA generations
        self._set_up_instance(train_dataset, test_dataset, child_network_architecture)
        self.ga_instance.run()   # the GA must be run before querying the best solution
        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
        if return_weights:
            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
        else:
            return solution, solution_fitness, solution_idx
    def _in_pol_dict(self, new_policy):
        """
        Checks whether a policy has already been evaluated, storing it if not
        """
        new_policy = new_policy[0]
        trans1, trans2 = new_policy[0][0], new_policy[1][0]
        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
        if trans1 in self.policy_dict:
            if trans2 in self.policy_dict[trans1]:
                for test_pol in self.policy_dict[trans1][trans2]:
                    if new_set == test_pol:
                        return True
                self.policy_dict[trans1][trans2].append(new_set)
            else:
                self.policy_dict[trans1][trans2] = [new_set]
        else:
            self.policy_dict[trans1] = {trans2: [new_set]}
        return False
    def _set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
        """
        Initialises GA instance, as well as fitness and on_generation functions
        """

        def fitness_func(solution, sol_idx):
            """
            Defines the fitness function for the parent selection

            Parameters
            --------------
            solution -> GA solution instance (parsed automatically)
            sol_idx -> GA solution index (parsed automatically)

            Returns
            --------------
            fit_val -> float
            """
            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
                                                               weights_vector=solution)
            self.controller.load_state_dict(model_weights_dict)

            train_dataset.transform = torchvision.transforms.ToTensor()
            self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)

            for idx, (test_x, label_x) in enumerate(self.train_loader):
                full_policy = self._get_single_policy_cov(test_x)
            # re-sample until the policy has not been evaluated before
            while self._in_pol_dict(full_policy):
                full_policy = self._get_single_policy_cov(test_x)

            fit_val = self._test_autoaugment_policy(full_policy, child_network_architecture,
                                                    train_dataset, test_dataset)

            self.running_policy.append((full_policy, fit_val))
            # keep only the best self.sp_num (policy, fitness) pairs seen so far
            if len(self.running_policy) > self.sp_num:
                self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True)
                self.running_policy = self.running_policy[:self.sp_num]
            print("appended policy: ", self.running_policy)

            # history_best tracks the best fitness value seen so far, per evaluation
            if len(self.history_best) == 0:
                self.history_best.append(fit_val)
            elif fit_val > self.history_best[-1]:
                self.history_best.append(fit_val)
            else:
                self.history_best.append(self.history_best[-1])

            return fit_val

        def on_generation(ga_instance):
            """
            Prints information on generational fitness

            Parameters
            -------------
            ga_instance -> GA instance

            Returns
            -------------
            None
            """
            print("Generation = {generation}".format(generation=ga_instance.generations_completed))
            print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1]))
            return

        self.ga_instance = pygad.GA(num_generations=self.num_generations,
                                    num_parents_mating=self.num_parents_mating,
                                    initial_population=self.initial_population,
                                    mutation_percent_genes=0.1,
                                    fitness_func=fitness_func,
                                    on_generation=on_generation)
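

# ----------------------------------------------------------------------------
# Minimal usage sketch. The controller and child network below are hypothetical
# placeholders written only to show the expected interfaces (a controller class
# accepting fun_num/p_bins/m_bins/sub_num_pol, and a child classifier class);
# the datasets are standard torchvision MNIST splits.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import torchvision.transforms as transforms

    class ToyController(nn.Module):
        # hypothetical controller: maps a flattened image to one output block per
        # transform in the sub-policy, each of size fun_num + p_bins + m_bins
        def __init__(self, fun_num, p_bins, m_bins, sub_num_pol):
            super().__init__()
            self.out = nn.Linear(28 * 28, 2 * (fun_num + p_bins + m_bins) * sub_num_pol)

        def forward(self, x):
            return self.out(x.view(x.shape[0], -1))

    class ToyChildNetwork(nn.Module):
        # hypothetical child classifier trained on the augmented data
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.fc(x.view(x.shape[0], -1))

    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                               transform=transforms.ToTensor())
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True,
                                              transform=transforms.ToTensor())

    learner = evo_learner(controller=ToyController,
                          sp_num=1,
                          num_solutions=5,
                          num_parents_mating=3)
    # learn() runs the GA for `iterations` generations and returns the best solution
    best = learner.learn(train_dataset, test_dataset,
                         child_network_architecture=ToyChildNetwork,
                         iterations=3)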