Skip to content
Snippets Groups Projects
Commit 5f05e0d7 authored by Max Ramsay King's avatar Max Ramsay King
Browse files

Updated the ES learner to be more in line with the random search funcitons

parent 08b20c1f
No related branches found
No related tags found
No related merge requests found
......@@ -14,36 +14,56 @@ import pygad.torchga as torchga
import random
import copy
from MetaAugment.main import *
# import MetaAugment.child_networks as child_networks
# from MetaAugment.main import *
# import MetaAugment.child_networks as child_networks
np.random.seed(0)
random.seed(0)
# augmentation_space = [
# # (function_name, do_we_need_to_specify_magnitude)
# ("ShearX", True),
# ("ShearY", True),
# ("TranslateX", True),
# ("TranslateY", True),
# ("Rotate", True),
# ("Brightness", True),
# ("Color", True),
# ("Contrast", True),
# ("Sharpness", True),
# ("Posterize", True),
# ("Solarize", True),
# ("AutoContrast", False),
# ("Equalize", False),
# ("Invert", False),
# ]
augmentation_space = [
# (function_name, do_we_need_to_specify_magnitude)
("ShearX", True),
("ShearY", True),
("TranslateX", True),
("TranslateY", True),
("Rotate", True),
("Brightness", True),
("Color", True),
("Contrast", True),
("Sharpness", True),
("Posterize", True),
("Solarize", True),
("AutoContrast", False),
("Equalize", False),
("Invert", False),
]
class Learner(nn.Module):
def __init__(self, num_transforms = 3):
def __init__(self, fun_num=14, p_bins=11, m_bins=10):
self.fun_num = fun_num
self.p_bins = p_bins
self.m_bins = m_bins
self.augmentation_space = [
# (function_name, do_we_need_to_specify_magnitude)
("ShearX", True),
("ShearY", True),
("TranslateX", True),
("TranslateY", True),
("Rotate", True),
("Brightness", True),
("Color", True),
("Contrast", True),
("Sharpness", True),
("Posterize", True),
("Solarize", True),
("AutoContrast", False),
("Equalize", False),
("Invert", False),
]
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.relu1 = nn.ReLU()
......@@ -55,11 +75,9 @@ class Learner(nn.Module):
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, num_transforms + 21)
# self.sig = nn.Sigmoid()
# Currently using discrete outputs for the probabilities
self.fc3 = nn.Linear(84, 5 * 2 * (self.fun_num + self.p_bins + self.m_bins))
# Currently using discrete outputs for the probabilities
def forward(self, x):
y = self.conv1(x)
......@@ -78,10 +96,22 @@ class Learner(nn.Module):
return y
def get_idx(self, x):
section = self.fun_num + self.p_bins + self.m_bins
y = self.forward(x)
idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0))
p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0))
return (idx_ret, p_ret)
full_policy = []
for pol in range(5 * 2):
int_pol = []
idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
trans, need_mag = self.augmentation_space[idx_ret]
p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
int_pol.append((trans, p_ret, mag))
if pol % 2 != 0:
full_policy.append(tuple(int_pol))
return full_policy
class LeNet(nn.Module):
......@@ -118,44 +148,27 @@ class LeNet(nn.Module):
# code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
def train_model(transform_idx, p, child_network):
def train_model(full_policy, child_network):
"""
Takes in the specific transformation index and probability
"""
if transform_idx == 0:
transform_train = torchvision.transforms.Compose(
[
torchvision.transforms.RandomVerticalFlip(p),
torchvision.transforms.ToTensor(),
]
)
elif transform_idx == 1:
transform_train = torchvision.transforms.Compose(
[
torchvision.transforms.RandomHorizontalFlip(p),
torchvision.transforms.ToTensor(),
]
)
else:
transform_train = torchvision.transforms.Compose(
[
torchvision.transforms.RandomGrayscale(p),
torchvision.transforms.ToTensor(),
]
)
# transformation = generate_policy(5, ps, mags)
train_transform = transforms.Compose([
full_policy,
transforms.ToTensor()
])
batch_size = 32
n_samples = 0.005
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
# child_network = child_networks.lenet()
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
epoch = 20
......@@ -191,20 +204,37 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
class Evolutionary_learner():
def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None):
self.meta_rl_agent = network
def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14):
self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
self.num_generations = num_generations
self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights
self.train_loader = train_loader
self.sec_model = sec_model
self.p_bins = p_bins
self.mag_bins = mag_bins
self.fun_num = fun_num
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
self.set_up_instance()
def generate_policy(self, sp_num, ps, mags):
policies = []
for subpol in range(sp_num):
sub = []
for idx in range(2):
transformation = augmentation_space[(2*subpol) + idx]
p = ps[(2*subpol) + idx]
mag = mags[(2*subpol) + idx]
sub.append((transformation, p, mag))
policies.append(tuple(sub))
return policies
def run_instance(self, return_weights = False):
self.ga_instance.run()
solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
......@@ -213,12 +243,14 @@ class Evolutionary_learner():
else:
return solution, solution_fitness, solution_idx
def new_model(self):
copy_model = copy.deepcopy(self.sec_model)
return copy_model
def set_up_instance(self):
def fitness_func(solution, sol_idx):
"""
Defines fitness function (accuracy of the model)
......@@ -227,9 +259,9 @@ class Evolutionary_learner():
weights_vector=solution)
self.meta_rl_agent.load_state_dict(model_weights_dict)
for idx, (test_x, label_x) in enumerate(train_loader):
trans_idx, p = self.meta_rl_agent.get_idx(test_x)
full_policy = self.meta_rl_agent.get_idx(test_x)
cop_mod = self.new_model()
fit_val = train_model(trans_idx, p, cop_mod)
fit_val = train_model(full_policy, cop_mod)
cop_mod = 0
return fit_val
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment