Commit e08b708c authored by Max Ramsay King

Made the get_full_policy function a method of the ES class rather than a feature of the network

parent 799152f6
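In short, the Learner network no longer decodes its own output into an augmentation policy; the evolutionary-search class does, and it now receives the augmentation space through its constructor. A minimal usage sketch of the resulting call pattern, assuming the Learner, LeNet and Evolutionary_learner classes and the train_loader defined in this file (the sketch is illustrative and not part of the commit):

# Illustrative sketch only. Assumes the classes defined in this file plus the
# augmentation_space list that this commit removes from Learner.
augmentation_space = [
    # (function_name, do_we_need_to_specify_magnitude)
    ("ShearX", True), ("ShearY", True), ("TranslateX", True), ("TranslateY", True),
    ("Rotate", True), ("Brightness", True), ("Color", True), ("Contrast", True),
    ("Sharpness", True), ("Posterize", True), ("Solarize", True),
    ("AutoContrast", False), ("Equalize", False), ("Invert", False),
]

network = Learner(14, p_bins=11, m_bins=10)

# Before this commit the network decoded its own output:
#     full_policy = network.get_idx(batch_x)
# After it, the ES wrapper owns both the search space and the decoding:
evo_learner = Evolutionary_learner(network=network, train_loader=train_loader,
                                   sec_model=LeNet(), augmentation_space=augmentation_space)
batch_x, _ = next(iter(train_loader))
full_policy = evo_learner.get_full_policy(batch_x)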
@@ -46,24 +46,6 @@ class Learner(nn.Module):
        self.p_bins = p_bins
        self.m_bins = m_bins
        self.augmentation_space = [
            # (function_name, do_we_need_to_specify_magnitude)
            ("ShearX", True),
            ("ShearY", True),
            ("TranslateX", True),
            ("TranslateY", True),
            ("Rotate", True),
            ("Brightness", True),
            ("Color", True),
            ("Contrast", True),
            ("Sharpness", True),
            ("Posterize", True),
            ("Solarize", True),
            ("AutoContrast", False),
            ("Equalize", False),
            ("Invert", False),
        ]
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
@@ -95,24 +77,6 @@ class Learner(nn.Module):
        return y

    def get_idx(self, x):
        section = self.fun_num + self.p_bins + self.m_bins
        y = self.forward(x)
        full_policy = []
        for pol in range(5 * 2):
            int_pol = []
            idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
            trans, need_mag = self.augmentation_space[idx_ret]
            p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
            mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
            int_pol.append((trans, p_ret, mag))
            if pol % 2 != 0:
                full_policy.append(tuple(int_pol))
        return full_policy

class LeNet(nn.Module):
    def __init__(self):
@@ -204,7 +168,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
class Evolutionary_learner():
    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14):
    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None):
        self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
        self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
        self.num_generations = num_generations
@@ -215,6 +179,7 @@ class Evolutionary_learner():
        self.p_bins = p_bins
        self.mag_bins = mag_bins
        self.fun_num = fun_num
        self.augmentation_space = augmentation_space

        assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
@@ -222,6 +187,9 @@ class Evolutionary_learner():
    def generate_policy(self, sp_num, ps, mags):
        """
        """
        policies = []
        for subpol in range(sp_num):
            sub = []
@@ -235,7 +203,33 @@ class Evolutionary_learner():
        return policies

    def get_full_policy(self, x):
        """
        Generates the full policy (5 x 2 subpolicies)
        """
        section = self.meta_rl_agent.fun_num + self.meta_rl_agent.p_bins + self.meta_rl_agent.m_bins
        y = self.meta_rl_agent.forward(x)
        full_policy = []
        for pol in range(5):
            int_pol = []
            for _ in range(2):
                idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
                trans, need_mag = self.augmentation_space[idx_ret]
                p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
                int_pol.append((trans, p_ret, mag))
            full_policy.append(tuple(int_pol))
        return full_policy

    def run_instance(self, return_weights = False):
        """
        Runs the GA instance and returns the model weights as a dictionary
        """
        self.ga_instance.run()
        solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
        if return_weights:
@@ -245,6 +239,9 @@ class Evolutionary_learner():
    def new_model(self):
        """
        Simple function to create a copy of the secondary model (used for classification)
        """
        copy_model = copy.deepcopy(self.sec_model)
        return copy_model
@@ -259,7 +256,7 @@ class Evolutionary_learner():
                                                               weights_vector=solution)
            self.meta_rl_agent.load_state_dict(model_weights_dict)
            for idx, (test_x, label_x) in enumerate(train_loader):
                full_policy = self.meta_rl_agent.get_idx(test_x)
                full_policy = self.meta_rl_agent.get_full_policy(test_x)
                cop_mod = self.new_model()
                fit_val = train_model(full_policy, cop_mod)
                cop_mod = 0
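For reference, the object that get_full_policy hands to train_model is a list of five two-transform subpolicies, each transform a (name, probability, magnitude) triple. A hedged illustration of its shape, with made-up values (in the code the probability and magnitude come back as 0-dim torch tensors rather than plain numbers):

# Hypothetical return value of get_full_policy: 5 subpolicies of 2 transforms each.
# Probabilities land on a 0.1 grid (0.1 * argmax over the p_bins outputs); the
# magnitude is 0 whenever the transform needs none (Invert, Equalize, AutoContrast).
example_policy = [
    (("ShearX", 0.3, 5), ("Invert", 0.7, 0)),
    (("Rotate", 0.5, 3), ("Equalize", 0.2, 0)),
    (("Color", 0.4, 7), ("Posterize", 0.6, 4)),
    (("Solarize", 0.1, 8), ("Contrast", 0.8, 2)),
    (("TranslateX", 0.9, 1), ("AutoContrast", 0.2, 0)),
]
# Inside the fitness function this is what gets passed on:
#     fit_val = train_model(example_policy, cop_mod)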