Skip to content
Snippets Groups Projects
Commit 5c322005 authored by Max Ramsay King's avatar Max Ramsay King
Browse files

fixed evo learner

parent f117547f
No related branches found
No related tags found
No related merge requests found
Pipeline #272061 failed
from cgi import test
import torch
torch.manual_seed(0)
import torch.nn as nn
......@@ -5,43 +6,39 @@ import pygad
import pygad.torchga as torchga
import copy
import torch
from meta_augment.controller_networks.evo_controller import evo_controller
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
from meta_augment.autoaugment_learners.aa_learner import aa_learner
import meta_augment.child_networks as cn
class evo_learner(aa_learner):
def __init__(self,
# search space settings
sp_num=5,
p_bins=10,
m_bins=10,
discrete_p_m=False,
exclude_method=[],
# child network settings
learning_rate=1e-1,
sp_num=1,
num_solutions = 5,
num_parents_mating = 3,
learning_rate = 1e-1,
max_epochs=float('inf'),
early_stop_num=20,
p_bins = 1,
m_bins = 1,
batch_size=8,
toy_size=1,
# evolutionary learner specific settings
num_solutions=5,
num_parents_mating=3,
controller=None
toy_size=0.1,
fun_num = 14,
exclude_method=[],
controller = None
):
super().__init__(
sp_num=sp_num,
p_bins=p_bins,
m_bins=m_bins,
discrete_p_m=discrete_p_m,
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
exclude_method=exclude_method
)
super().__init__(sp_num,
fun_num,
p_bins,
m_bins,
batch_size=batch_size,
toy_size=toy_size,
learning_rate=learning_rate,
max_epochs=max_epochs,
early_stop_num=early_stop_num,)
self.num_solutions = num_solutions
self.controller = controller
......@@ -51,6 +48,8 @@ class evo_learner(aa_learner):
self.p_bins = p_bins
self.sub_num_pol = sp_num
self.m_bins = m_bins
self.fun_num = fun_num
self.augmentation_space = [x for x in self.augmentation_space if x[0] not in exclude_method]
self.policy_dict = {}
self.policy_result = []
......@@ -58,6 +57,7 @@ class evo_learner(aa_learner):
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
def get_full_policy(self, x):
"""
Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
......@@ -167,10 +167,10 @@ class evo_learner(aa_learner):
prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
if mag1 is not None:
# mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
if mag2 is not None:
# mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
mag2 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
counter += 1
......@@ -240,7 +240,7 @@ class evo_learner(aa_learner):
self.policy_dict[trans1][trans2].append(new_set)
return False
else:
self.policy_dict[trans1][trans2] = [new_set]
self.policy_dict[trans1] = {trans2: [new_set]}
if trans2 in self.policy_dict:
if trans1 in self.policy_dict[trans2]:
for test_pol in self.policy_dict[trans2][trans1]:
......@@ -249,7 +249,7 @@ class evo_learner(aa_learner):
self.policy_dict[trans2][trans1].append(new_set)
return False
else:
self.policy_dict[trans2][trans1] = [new_set]
self.policy_dict[trans2] = {trans1: [new_set]}
def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
......@@ -298,11 +298,11 @@ class evo_learner(aa_learner):
if len(self.policy_result) > self.sp_num:
self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
self.policy_result = self.policy_result[:self.sp_num]
print("Appended policy: ", self.policy_result)
print("appended policy: ", self.policy_result)
if fit_val > self.history_best[self.gen_count]:
print("Best policy: ", full_policy)
print("best policy: ", full_policy)
self.history_best[self.gen_count] = fit_val
self.best_model = model_weights_dict
......@@ -335,4 +335,3 @@ class evo_learner(aa_learner):
mutation_percent_genes = 0.1,
fitness_func=fitness_func,
on_generation = on_generation)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment