Skip to content
Snippets Groups Projects
Commit 5c322005 authored by Max Ramsay King's avatar Max Ramsay King
Browse files

Fixed the evolutionary learner (evo_learner)

parent f117547f
No related branches found
No related tags found
No related merge requests found
Checking pipeline status
from cgi import test
import torch import torch
torch.manual_seed(0) torch.manual_seed(0)
import torch.nn as nn import torch.nn as nn
...@@ -5,43 +6,39 @@ import pygad ...@@ -5,43 +6,39 @@ import pygad
import pygad.torchga as torchga import pygad.torchga as torchga
import copy import copy
import torch import torch
from meta_augment.controller_networks.evo_controller import evo_controller
from MetaAugment.autoaugment_learners.aa_learner import aa_learner from meta_augment.autoaugment_learners.aa_learner import aa_learner
import meta_augment.child_networks as cn
class evo_learner(aa_learner): class evo_learner(aa_learner):
def __init__(self, def __init__(self,
# search space settings sp_num=1,
sp_num=5, num_solutions = 5,
p_bins=10, num_parents_mating = 3,
m_bins=10, learning_rate = 1e-1,
discrete_p_m=False,
exclude_method=[],
# child network settings
learning_rate=1e-1,
max_epochs=float('inf'), max_epochs=float('inf'),
early_stop_num=20, early_stop_num=20,
p_bins = 1,
m_bins = 1,
batch_size=8, batch_size=8,
toy_size=1, toy_size=0.1,
# evolutionary learner specific settings fun_num = 14,
num_solutions=5, exclude_method=[],
num_parents_mating=3, controller = None
controller=None
): ):
super().__init__( super().__init__(sp_num,
sp_num=sp_num, fun_num,
p_bins=p_bins, p_bins,
m_bins=m_bins, m_bins,
discrete_p_m=discrete_p_m, batch_size=batch_size,
batch_size=batch_size, toy_size=toy_size,
toy_size=toy_size, learning_rate=learning_rate,
learning_rate=learning_rate, max_epochs=max_epochs,
max_epochs=max_epochs, early_stop_num=early_stop_num,)
early_stop_num=early_stop_num,
exclude_method=exclude_method
)
self.num_solutions = num_solutions self.num_solutions = num_solutions
self.controller = controller self.controller = controller
...@@ -51,6 +48,8 @@ class evo_learner(aa_learner): ...@@ -51,6 +48,8 @@ class evo_learner(aa_learner):
self.p_bins = p_bins self.p_bins = p_bins
self.sub_num_pol = sp_num self.sub_num_pol = sp_num
self.m_bins = m_bins self.m_bins = m_bins
self.fun_num = fun_num
self.augmentation_space = [x for x in self.augmentation_space if x[0] not in exclude_method]
self.policy_dict = {} self.policy_dict = {}
self.policy_result = [] self.policy_result = []
...@@ -58,6 +57,7 @@ class evo_learner(aa_learner): ...@@ -58,6 +57,7 @@ class evo_learner(aa_learner):
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
def get_full_policy(self, x): def get_full_policy(self, x):
""" """
Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
...@@ -167,10 +167,10 @@ class evo_learner(aa_learner): ...@@ -167,10 +167,10 @@ class evo_learner(aa_learner):
prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item() prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
if mag1 is not None: if mag1 is not None:
# mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8) # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item() mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
if mag2 is not None: if mag2 is not None:
# mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8) # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item() mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
counter += 1 counter += 1
...@@ -240,7 +240,7 @@ class evo_learner(aa_learner): ...@@ -240,7 +240,7 @@ class evo_learner(aa_learner):
self.policy_dict[trans1][trans2].append(new_set) self.policy_dict[trans1][trans2].append(new_set)
return False return False
else: else:
self.policy_dict[trans1][trans2] = [new_set] self.policy_dict[trans1] = {trans2: [new_set]}
if trans2 in self.policy_dict: if trans2 in self.policy_dict:
if trans1 in self.policy_dict[trans2]: if trans1 in self.policy_dict[trans2]:
for test_pol in self.policy_dict[trans2][trans1]: for test_pol in self.policy_dict[trans2][trans1]:
...@@ -249,7 +249,7 @@ class evo_learner(aa_learner): ...@@ -249,7 +249,7 @@ class evo_learner(aa_learner):
self.policy_dict[trans2][trans1].append(new_set) self.policy_dict[trans2][trans1].append(new_set)
return False return False
else: else:
self.policy_dict[trans2][trans1] = [new_set] self.policy_dict[trans2] = {trans1: [new_set]}
def set_up_instance(self, train_dataset, test_dataset, child_network_architecture): def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
...@@ -298,11 +298,11 @@ class evo_learner(aa_learner): ...@@ -298,11 +298,11 @@ class evo_learner(aa_learner):
if len(self.policy_result) > self.sp_num: if len(self.policy_result) > self.sp_num:
self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True) self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
self.policy_result = self.policy_result[:self.sp_num] self.policy_result = self.policy_result[:self.sp_num]
print("Appended policy: ", self.policy_result) print("appended policy: ", self.policy_result)
if fit_val > self.history_best[self.gen_count]: if fit_val > self.history_best[self.gen_count]:
print("Best policy: ", full_policy) print("best policy: ", full_policy)
self.history_best[self.gen_count] = fit_val self.history_best[self.gen_count] = fit_val
self.best_model = model_weights_dict self.best_model = model_weights_dict
...@@ -335,4 +335,3 @@ class evo_learner(aa_learner): ...@@ -335,4 +335,3 @@ class evo_learner(aa_learner):
mutation_percent_genes = 0.1, mutation_percent_genes = 0.1,
fitness_func=fitness_func, fitness_func=fitness_func,
on_generation = on_generation) on_generation = on_generation)
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment