Skip to content
Snippets Groups Projects
Commit 2a450edb authored by Max Ramsay King's avatar Max Ramsay King
Browse files

Using test_autoaugment_policy to calcualte the accuracy of subpolicies. ES...

Using test_autoaugment_policy to calcualte the accuracy of subpolicies. ES learner now outputs similar list to randomsearcher
parent f5a10e7e
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,10 @@ import pygad ...@@ -13,6 +13,10 @@ import pygad
import pygad.torchga as torchga import pygad.torchga as torchga
import random import random
import copy import copy
from torchvision.transforms import functional as F, InterpolationMode
from typing import List, Tuple, Optional, Dict
# from MetaAugment.main import * # from MetaAugment.main import *
# import MetaAugment.child_networks as child_networks # import MetaAugment.child_networks as child_networks
...@@ -112,36 +116,36 @@ class LeNet(nn.Module): ...@@ -112,36 +116,36 @@ class LeNet(nn.Module):
# code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
def train_model(full_policy, child_network): # def train_model(full_policy, child_network):
""" # """
Takes in the specific transformation index and probability # Takes in the specific transformation index and probability
""" # """
# transformation = generate_policy(5, ps, mags) # # transformation = generate_policy(5, ps, mags)
train_transform = transforms.Compose([ # train_transform = transforms.Compose([
full_policy, # full_policy,
transforms.ToTensor() # transforms.ToTensor()
]) # ])
batch_size = 32 # batch_size = 32
n_samples = 0.005 # n_samples = 0.005
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform) # train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) # test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) # train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
sgd = optim.SGD(child_network.parameters(), lr=1e-1) # sgd = optim.SGD(child_network.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss() # cost = nn.CrossEntropyLoss()
epoch = 20 # epoch = 20
best_acc = train_child_network(child_network, train_loader, test_loader, # best_acc = train_child_network(child_network, train_loader, test_loader,
sgd, cost, max_epochs=100, print_every_epoch=False) # sgd, cost, max_epochs=100, print_every_epoch=False)
return best_acc # return best_acc
...@@ -168,7 +172,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600 ...@@ -168,7 +172,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
class Evolutionary_learner(): class Evolutionary_learner():
def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None): def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10) self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
self.num_generations = num_generations self.num_generations = num_generations
...@@ -180,6 +184,8 @@ class Evolutionary_learner(): ...@@ -180,6 +184,8 @@ class Evolutionary_learner():
self.mag_bins = mag_bins self.mag_bins = mag_bins
self.fun_num = fun_num self.fun_num = fun_num
self.augmentation_space = augmentation_space self.augmentation_space = augmentation_space
self.train_dataset = train_dataset
self.test_dataset = test_dataset
assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
...@@ -219,7 +225,7 @@ class Evolutionary_learner(): ...@@ -219,7 +225,7 @@ class Evolutionary_learner():
trans, need_mag = self.augmentation_space[idx_ret] trans, need_mag = self.augmentation_space[idx_ret]
p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0 mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
int_pol.append((trans, p_ret, mag)) int_pol.append((trans, p_ret, mag))
full_policy.append(tuple(int_pol)) full_policy.append(tuple(int_pol))
...@@ -253,22 +259,29 @@ class Evolutionary_learner(): ...@@ -253,22 +259,29 @@ class Evolutionary_learner():
""" """
Defines fitness function (accuracy of the model) Defines fitness function (accuracy of the model)
""" """
model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent, model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent,
weights_vector=solution) weights_vector=solution)
self.meta_rl_agent.load_state_dict(model_weights_dict) self.meta_rl_agent.load_state_dict(model_weights_dict)
for idx, (test_x, label_x) in enumerate(train_loader): for idx, (test_x, label_x) in enumerate(train_loader):
full_policy = self.get_full_policy(test_x) full_policy = self.get_full_policy(test_x)
cop_mod = self.new_model() cop_mod = self.new_model()
fit_val = train_model(full_policy, cop_mod)
fit_val = test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]
cop_mod = 0 cop_mod = 0
return fit_val return fit_val
def on_generation(ga_instance): def on_generation(ga_instance):
""" """
Just prints stuff while running Just prints stuff while running
""" """
print("Generation = {generation}".format(generation=self.ga_instance.generations_completed)) print("Generation = {generation}".format(generation=ga_instance.generations_completed))
print("Fitness = {fitness}".format(fitness=self.ga_instance.best_solution()[1])) print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1]))
return return
...@@ -279,6 +292,11 @@ class Evolutionary_learner(): ...@@ -279,6 +292,11 @@ class Evolutionary_learner():
on_generation = on_generation) on_generation = on_generation)
meta_rl_agent = Learner() meta_rl_agent = Learner()
ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space) ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space)
ev_learner.run_instance() ev_learner.run_instance()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment