Skip to content
Snippets Groups Projects
Commit e75c71a4 authored by Max Ramsay King's avatar Max Ramsay King
Browse files

Alternative genetic algorithm

parent bbccc6d8
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,24 @@ np.random.seed(0) ...@@ -22,6 +22,24 @@ np.random.seed(0)
random.seed(0) random.seed(0)
# augmentation_space = [
# # (function_name, do_we_need_to_specify_magnitude)
# ("ShearX", True),
# ("ShearY", True),
# ("TranslateX", True),
# ("TranslateY", True),
# ("Rotate", True),
# ("Brightness", True),
# ("Color", True),
# ("Contrast", True),
# ("Sharpness", True),
# ("Posterize", True),
# ("Solarize", True),
# ("AutoContrast", False),
# ("Equalize", False),
# ("Invert", False),
# ]
class Learner(nn.Module): class Learner(nn.Module):
def __init__(self, num_transforms = 3): def __init__(self, num_transforms = 3):
super().__init__() super().__init__()
...@@ -38,6 +56,7 @@ class Learner(nn.Module): ...@@ -38,6 +56,7 @@ class Learner(nn.Module):
self.fc3 = nn.Linear(84, 13) self.fc3 = nn.Linear(84, 13)
# self.sig = nn.Sigmoid() # self.sig = nn.Sigmoid()
def forward(self, x): def forward(self, x):
y = self.conv1(x) y = self.conv1(x)
y = self.relu1(y) y = self.relu1(y)
...@@ -60,7 +79,6 @@ class Learner(nn.Module): ...@@ -60,7 +79,6 @@ class Learner(nn.Module):
p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0)) p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0))
return (idx_ret, p_ret) return (idx_ret, p_ret)
# return (torch.argmax(y[0:3]), y[torch.argmax(y[3:])])
class LeNet(nn.Module): class LeNet(nn.Module):
def __init__(self): def __init__(self):
...@@ -253,7 +271,7 @@ class Evolutionary_learner(): ...@@ -253,7 +271,7 @@ class Evolutionary_learner():
self.num_parents_mating = num_parents_mating self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights self.initial_population = self.torch_ga.population_weights
self.train_loader = train_loader self.train_loader = train_loader
self.backup_model = sec_model self.sec_model = sec_model
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
...@@ -269,7 +287,7 @@ class Evolutionary_learner(): ...@@ -269,7 +287,7 @@ class Evolutionary_learner():
return solution, solution_fitness, solution_idx return solution, solution_fitness, solution_idx
def new_model(self): def new_model(self):
copy_model = copy.deepcopy(self.backup_model) copy_model = copy.deepcopy(self.sec_model)
return copy_model return copy_model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment