Commit 43019f73 authored by Sun Jin Kim

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents c928ebaa 0ff5e107
MetaAugment/autoaugment_learners/AaLearner.py

@@ -94,6 +94,25 @@ class AaLearner:
             ("Equalize", False),
             ("Invert", False),
         ]
+        self.aug_space_dict = {
+            'ShearX': True,
+            'ShearY': True,
+            'TranslateX': True,
+            'TranslateY': True,
+            'Rotate': True,
+            'Brightness': True,
+            'Color': True,
+            'Contrast': True,
+            'Sharpness': True,
+            'Posterize': True,
+            'Solarize': True,
+            'AutoContrast': False,
+            'Equalize': False,
+            'Invert': False
+        }
         self.exclude_method = exclude_method
         self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
@@ -311,7 +330,7 @@ class AaLearner:
                                  train_dataset,
                                  test_dataset,
                                  logging=False,
-                                 print_every_epoch=True):
+                                 print_every_epoch=False):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and
         ...
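The new `aug_space_dict` records, for each operation, whether it takes a magnitude argument. A minimal sketch of how a learner can consume that flag when sampling (the standalone `sample_magnitude` helper is hypothetical; `Genetic_learner.gen_random_subpol` below does the same check inline):

import random

def sample_magnitude(learner, op_name):
    # Ops flagged True (e.g. 'Rotate') take a discrete magnitude bin 0-9;
    # ops flagged False (e.g. 'Invert') take none.
    if learner.aug_space_dict[op_name]:
        return random.randrange(0, 10)
    return None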
MetaAugment/autoaugment_learners/gen_learner.py

import random

from MetaAugment.autoaugment_learners.AaLearner import AaLearner


class Genetic_learner(AaLearner):
    """Searches for AutoAugment subpolicies with a genetic algorithm.

    Subpolicies are encoded as fixed-width bit strings, scored by the accuracy
    of a child network trained under them, and bred by fitness-weighted
    single-point crossover.
    """
    def __init__(self,
                 # search space settings
                 sp_num=5,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
                 exclude_method=[],
                 # child network settings
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
                 batch_size=8,
                 toy_size=1,
                 num_offspring=1,
                 ):
        super().__init__(
            sp_num=sp_num,
            p_bins=p_bins,
            m_bins=m_bins,
            discrete_p_m=discrete_p_m,
            batch_size=batch_size,
            toy_size=toy_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            early_stop_num=early_stop_num,
            exclude_method=exclude_method
        )
        # Map each augmentation to a fixed-width binary index; the width is
        # the number of bits needed to represent the size of the space.
        self.bin_to_aug = {}
        bit_width = len('{0:b}'.format(len(self.augmentation_space)))
        for idx, augmentation in enumerate(self.augmentation_space):
            bin_rep = '{0:b}'.format(idx).zfill(bit_width)
            self.bin_to_aug[bin_rep] = augmentation[0]
        self.just_augs = [x[0] for x in self.augmentation_space]  # operation names only
        # 4-bit encodings for the 11 magnitude bins (0-10).
        self.mag_to_bin = {
            '0': '0000',
            '1': '0001',
            '2': '0010',
            '3': '0011',
            '4': '0100',
            '5': '0101',
            '6': '0110',
            '7': '0111',
            '8': '1000',
            '9': '1001',
            '10': '1010',
        }
        # 4-bit encodings for the 11 probability bins (0.0-1.0 in steps of
        # 0.1). '0'/'0.0' and '1'/'1.0' alias the same codes so that both
        # string forms of the endpoints encode correctly.
        self.prob_to_bin = {
            '0': '0000',
            '0.0': '0000',
            '0.1': '0001',
            '0.2': '0010',
            '0.3': '0011',
            '0.4': '0100',
            '0.5': '0101',
            '0.6': '0110',
            '0.7': '0111',
            '0.8': '1000',
            '0.9': '1001',
            '1.0': '1010',
            '1': '1010',
        }
        # Inverse lookups. For the aliased codes the later key wins ('0000'
        # -> '0.0', '1010' -> '1'); both parse correctly with float().
        self.bin_to_prob = {v: k for k, v in self.prob_to_bin.items()}
        self.bin_to_mag = {v: k for k, v in self.mag_to_bin.items()}
        self.aug_to_bin = {v: k for k, v in self.bin_to_aug.items()}

        self.num_offspring = num_offspring
    def gen_random_subpol(self):
        """Samples a random subpolicy: two (operation, probability, magnitude)
        triples, with magnitude None for operations that take no magnitude."""
        choose_items = [x[0] for x in self.augmentation_space]
        trans1 = str(random.choice(choose_items))
        trans2 = str(random.choice(choose_items))
        prob1 = float(random.randrange(0, 11, 1) / 10)  # 0.0-1.0 in steps of 0.1
        prob2 = float(random.randrange(0, 11, 1) / 10)
        if self.aug_space_dict[trans1]:
            mag1 = int(random.randrange(0, 10, 1))
        else:
            mag1 = None
        if self.aug_space_dict[trans2]:
            mag2 = int(random.randrange(0, 10, 1))
        else:
            mag2 = None
        subpol = ((trans1, prob1, mag1), (trans2, prob2, mag2))
        return subpol
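    # Example draw (RNG outcome assumed for illustration): gen_random_subpol()
    # might return
    #     (('Rotate', 0.7, 3), ('Invert', 0.2, None))
    # where 'Invert' is flagged False in aug_space_dict, so its magnitude is None.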
    def gen_random_policy(self):
        """Samples a policy of self.sp_num random subpolicies."""
        pol = []
        for _ in range(self.sp_num):
            pol.append(self.gen_random_subpol())
        return pol
    def bin_to_subpol(self, bits):
        """Decodes a 24-bit string into a one-element policy
        [((op1, prob1, mag1), (op2, prob2, mag2))], resampling any field whose
        4-bit code is not a valid encoding (this also handles the '1111'
        sentinel written for magnitude-less operations)."""
        pol = []
        for idx in range(2):
            # Each triple occupies 12 bits: 4 for the operation, 4 for the
            # probability, 4 for the magnitude.
            if bits[idx*12:(idx*12)+4] in self.bin_to_aug:
                trans = self.bin_to_aug[bits[idx*12:(idx*12)+4]]
            else:
                trans = random.choice(self.just_augs)
            mag_is_none = not self.aug_space_dict[trans]
            if bits[(idx*12)+4:(idx*12)+8] in self.bin_to_prob:
                prob = float(self.bin_to_prob[bits[(idx*12)+4:(idx*12)+8]])
            else:
                prob = float(random.randrange(0, 11, 1) / 10)
            if bits[(idx*12)+8:(idx*12)+12] in self.bin_to_mag:
                mag = int(self.bin_to_mag[bits[(idx*12)+8:(idx*12)+12]])
            else:
                mag = int(random.randrange(0, 10, 1))
            if mag_is_none:
                mag = None
            pol.append((trans, prob, mag))
        return [tuple(pol)]
    def subpol_to_bin(self, subpol):
        """Encodes a subpolicy ((op1, prob1, mag1), (op2, prob2, mag2)) as a
        24-bit string; a magnitude of None is written as the sentinel '1111'."""
        bits = ''
        trans1, prob1, mag1 = subpol[0]
        trans2, prob2, mag2 = subpol[1]
        bits += self.aug_to_bin[trans1] + self.prob_to_bin[str(prob1)]
        bits += '1111' if mag1 is None else self.mag_to_bin[str(mag1)]
        bits += self.aug_to_bin[trans2] + self.prob_to_bin[str(prob2)]
        bits += '1111' if mag2 is None else self.mag_to_bin[str(mag2)]
        return bits
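    # Round-trip sketch (assumed: a Genetic_learner instance `learner` built
    # with no excluded operations, so the default 14-op space gives 4-bit
    # fields and 24-bit subpolicies):
    #     sp   = (('ShearX', 0.5, 3), ('Invert', 0.2, None))
    #     bits = learner.subpol_to_bin(sp)  # '000001010011' + '110100101111'
    #     learner.bin_to_subpol(bits)       # [(('ShearX', 0.5, 3), ('Invert', 0.2, None))]
    # Note the asymmetry: subpol_to_bin takes one subpolicy, while bin_to_subpol
    # returns it wrapped in a one-element policy list.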
    def choose_parents(self, parents, parents_weights):
        """Picks two distinct parent subpolicies, weighted by accuracy, and
        returns them in binary form."""
        # random.choices returns a one-element list of policies; the second
        # [0] takes the first subpolicy of the sampled policy.
        parent1 = random.choices(parents, parents_weights, k=1)[0][0]
        parent2 = random.choices(parents, parents_weights, k=1)[0][0]
        while parent2 == parent1:
            parent2 = random.choices(parents, parents_weights, k=1)[0][0]
        return (self.subpol_to_bin(parent1), self.subpol_to_bin(parent2))
    def generate_children(self):
        """Breeds self.num_offspring children by single-point crossover of the
        binary encodings of the current top-self.sp_num policies."""
        parent_acc = sorted(self.history, key=lambda x: x[1], reverse=True)[:self.sp_num]
        parents = [x[0] for x in parent_acc]
        parents_weights = [x[1] for x in parent_acc]
        new_pols = []
        for _ in range(self.num_offspring):
            parent1, parent2 = self.choose_parents(parents, parents_weights)
            # Single-point crossover: prefix from parent1, suffix from parent2.
            cross_over = random.randrange(1, len(parent2), 1)
            child = parent1[:cross_over] + parent2[cross_over:]
            new_pols.append(child)
        return new_pols
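    # Crossover sketch (crossover point of 10 assumed for illustration):
    #     parent1 = '000001010011110100101111'
    #     parent2 = '001010100101000100110100'
    #     child   = parent1[:10] + parent2[10:]  # '0000010100' + '01000100110100'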
    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations=10):
        """Runs the genetic search: random subpolicies until self.sp_num have
        been scored, then crossover children of the fittest policies."""
        for idx in range(iterations):
            print("iteration: ", idx)
            if len(self.history) < self.sp_num:
                policy = [self.gen_random_subpol()]
            else:
                policy = self.bin_to_subpol(random.choice(self.generate_children()))
            print("policy: ", policy)
            reward = self._test_autoaugment_policy(policy,
                                                   child_network_architecture,
                                                   train_dataset,
                                                   test_dataset)
            print("reward: ", reward)
            print("new history length: ", len(self.history))
            print("history: ", self.history)
# We can initialize the train_dataset with its transform as None.
# Later on, we will change this object's transform attribute to the policy
# that we want to test
import torchvision
import torchvision.datasets as datasets

import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.gen_learner import Genetic_learner
# train_dataset = datasets.MNIST(root='./datasets/mnist/train',
#                                train=True, download=True, transform=None)
# test_dataset = datasets.MNIST(root='./datasets/mnist/test',
#                               train=False, download=True, transform=torchvision.transforms.ToTensor())

train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                                      train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
                                     train=False, download=True,
                                     transform=torchvision.transforms.ToTensor())
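# A sketch of what the learner is assumed to do internally with these datasets
# (per the comment at the top of this script): build a transform from each
# candidate policy and attach it to the train set before fitting, e.g.
#     train_dataset.transform = torchvision.transforms.Compose(
#         [policy_transform,  # hypothetical transform built from the policy
#          torchvision.transforms.ToTensor()])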
child_network_architecture = cn.lenet
# child_network_architecture = cn.lenet()
agent = Genetic_learner(
    sp_num=2,
    toy_size=0.01,
    batch_size=4,
    learning_rate=0.05,
    max_epochs=float('inf'),
    early_stop_num=10,
)

agent.learn(train_dataset,
            test_dataset,
            child_network_architecture=child_network_architecture,
            iterations=10)
# with open('randomsearch_logs.pkl', 'wb') as file:
#     pickle.dump(agent.history, file)

print(agent.history)