Skip to content
Snippets Groups Projects
Commit 51613552 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents 7935e672 e06e3cc0
Branches
Tags
No related merge requests found
......@@ -13,6 +13,11 @@ import pygad
import pygad.torchga as torchga
import random
import copy
from torchvision.transforms import functional as F, InterpolationMode
from typing import List, Tuple, Optional, Dict
import heapq
# from MetaAugment.main import *
# import MetaAugment.child_networks as child_networks
......@@ -41,10 +46,11 @@ augmentation_space = [
]
class Learner(nn.Module):
def __init__(self, fun_num=14, p_bins=11, m_bins=10):
def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
self.fun_num = fun_num
self.p_bins = p_bins
self.m_bins = m_bins
self.sub_num_pol = sub_num_pol
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
......@@ -57,7 +63,7 @@ class Learner(nn.Module):
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 5 * 2 * (self.fun_num + self.p_bins + self.m_bins))
self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins))
# Currently using discrete outputs for the probabilities
......@@ -112,36 +118,36 @@ class LeNet(nn.Module):
# code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
def train_model(full_policy, child_network):
"""
Takes in the specific transformation index and probability
"""
# def train_model(full_policy, child_network):
# """
# Takes in the specific transformation index and probability
# """
# transformation = generate_policy(5, ps, mags)
# # transformation = generate_policy(5, ps, mags)
train_transform = transforms.Compose([
full_policy,
transforms.ToTensor()
])
# train_transform = transforms.Compose([
# full_policy,
# transforms.ToTensor()
# ])
batch_size = 32
n_samples = 0.005
# batch_size = 32
# n_samples = 0.005
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
# train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
# test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
# train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
epoch = 20
# sgd = optim.SGD(child_network.parameters(), lr=1e-1)
# cost = nn.CrossEntropyLoss()
# epoch = 20
best_acc = train_child_network(child_network, train_loader, test_loader,
sgd, cost, max_epochs=100, print_every_epoch=False)
# best_acc = train_child_network(child_network, train_loader, test_loader,
# sgd, cost, max_epochs=100, print_every_epoch=False)
return best_acc
# return best_acc
......@@ -168,18 +174,21 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
class Evolutionary_learner():
def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None):
self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol)
self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
self.num_generations = num_generations
self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights
self.train_loader = train_loader
self.sec_model = sec_model
self.child_network = child_network
self.p_bins = p_bins
self.sub_num_pol = sub_num_pol
self.mag_bins = mag_bins
self.fun_num = fun_num
self.augmentation_space = augmentation_space
self.train_dataset = train_dataset
self.test_dataset = test_dataset
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
......@@ -202,30 +211,110 @@ class Evolutionary_learner():
return policies
# Every image has specific operation. Policy for every image (2 (trans., prob., mag) output)
# RNN -> change the end -/- leave for now, ask Javier
# Use mini-batch with current output, get mode transformation -> mean probability and magnitude
# Pass through each image in mini-batch to get one/two (transformation, prob., mag.) tuples
# Average softmax probability (get softmax of the outputs, then average them to get the probability)
# For every batch, store all outputs. Pick top operations
# Every image -> output 2 operation tuples e.g. 14 trans + 1 prob + 1 mag. 32 output total.
# 14 neuron output is then prob. of transformations (softmax + average across dim = 0)
# 1000x28
# Problem 1: have 28, if we pick argmax top 2
# For each image have 28 dim output. Calculate covariance of 1000x28 using np.cov(28_dim_vector.T)
# Give 28x28 covariance matrix. Pick top k pairs (corresponds to largest covariance pairs)
# Once we have pairs, go back to 1000x32 output. Find cases where the largest cov. pairs are used and use those probs and mags
# Covariance matrix -> prob. of occurance (might be bad pairs)
# Pair criteria -> highest softmax prob and probaility of occurence
def get_full_policy(self, x):
"""
Generates the full policy (5 x 2 subpolicies)
"""
section = self.meta_rl_agent.fun_num + self.meta_rl_agent.p_bins + self.meta_rl_agent.m_bins
y = self.meta_rl_agent.forward(x)
section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
y = self.auto_aug_agent.forward(x)
full_policy = []
for pol in range(5):
for pol in range(self.sub_num_pol):
int_pol = []
for _ in range(2):
idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
trans, need_mag = self.augmentation_space[idx_ret]
p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
int_pol.append((trans, p_ret, mag))
full_policy.append(tuple(int_pol))
return full_policy
def get_policy_cov(self, x):
"""
Need p_bins = 1, num_sub_pol = 1, mag_bins = 1
"""
section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
y = self.auto_aug_agent.forward(x) # 1000 x 32
y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
concat = torch.cat((y_1, y_2), dim = 1)
cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
shape_store = cov_mat.shape
cov_mat = torch.reshape(cov_mat, (1, -1)).squeeze()
max_idx = torch.argmax(cov_mat)
val = (max_idx//shape_store[0])
max_idx = (val, max_idx - (val * shape_store[0]))
counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
if self.augmentation_space[max_idx[0]]:
mag1 = None
if self.augmentation_space[max_idx[1]]:
mag2 = None
for idx in range(y.shape[0]):
# print("torch.argmax(y_1[idx]): ", torch.argmax(y_1[idx]))
# print("torch.argmax(y_2[idx]): ", torch.argmax(y_2[idx]))
# print("max idx0: ", max_idx[0])
# print("max idx1: ", max_idx[1])
if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
prob1 += y[idx, self.auto_aug_agent.fun_num+1]
prob2 += y[idx, section+self.auto_aug_agent.fun_num+1]
if mag1 is not None:
mag1 += y[idx, self.auto_aug_agent.fun_num+2]
if mag2 is not None:
mag2 += y[idx, section+self.auto_aug_agent.fun_num+2]
counter += 1
prob1 = prob1/counter if counter != 0 else 0
prob2 = prob2/counter if counter != 0 else 0
if mag1 is not None:
mag1 = mag1/counter
if mag2 is not None:
mag2 = mag2/counter
return [(self.augmentation_space[max_idx[0]], prob1, mag1), (self.augmentation_space[max_idx[1]], prob2, mag2)]
def run_instance(self, return_weights = False):
"""
......@@ -234,7 +323,7 @@ class Evolutionary_learner():
self.ga_instance.run()
solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
if return_weights:
return torchga.model_weights_as_dict(model=self.meta_rl_agent, weights_vector=solution)
return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
else:
return solution, solution_fitness, solution_idx
......@@ -243,7 +332,7 @@ class Evolutionary_learner():
"""
Simple function to create a copy of the secondary model (used for classification)
"""
copy_model = copy.deepcopy(self.sec_model)
copy_model = copy.deepcopy(self.child_network)
return copy_model
......@@ -253,22 +342,30 @@ class Evolutionary_learner():
"""
Defines fitness function (accuracy of the model)
"""
model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent,
model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
weights_vector=solution)
self.meta_rl_agent.load_state_dict(model_weights_dict)
self.auto_aug_agent.load_state_dict(model_weights_dict)
for idx, (test_x, label_x) in enumerate(train_loader):
full_policy = self.get_full_policy(test_x)
# full_policy = self.get_full_policy(test_x)
full_policy = self.get_policy_cov(test_x)
print("full_policy: ", full_policy)
cop_mod = self.new_model()
fit_val = train_model(full_policy, cop_mod)
fit_val = test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]
cop_mod = 0
return fit_val
def on_generation(ga_instance):
"""
Just prints stuff while running
"""
print("Generation = {generation}".format(generation=self.ga_instance.generations_completed))
print("Fitness = {fitness}".format(fitness=self.ga_instance.best_solution()[1]))
print("Generation = {generation}".format(generation=ga_instance.generations_completed))
print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1]))
return
......@@ -279,8 +376,13 @@ class Evolutionary_learner():
on_generation = on_generation)
meta_rl_agent = Learner()
ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space)
auto_aug_agent = Learner()
ev_learner = Evolutionary_learner(auto_aug_agent, train_loader=train_loader, child_network=LeNet(), augmentation_space=augmentation_space, p_bins=1, mag_bins=1, sub_num_pol=1)
ev_learner.run_instance()
......@@ -288,5 +390,5 @@ solution, solution_fitness, solution_idx = ev_learner.ga_instance.best_solution(
print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))
# Fetch the parameters of the best solution.
best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.meta_rl_agent,
best_solution_weights = torchga.model_weights_as_dict(model=ev_learner.auto_aug_agent,
weights_vector=solution)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment