import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from pprint import pprint

from MetaAugment.main import train_child_network, create_toy
from MetaAugment.autoaugment_learners.autoaugment import AutoAugment

# We will use this augmentation_space temporarily. Later on we will need to
# make sure we are able to add other image functions if the users want.
augmentation_space = [
    # (function_name, do_we_need_to_specify_magnitude)
    ("ShearX", True),
    ("ShearY", True),
    ("TranslateX", True),
    ("TranslateY", True),
    ("Rotate", True),
    ("Brightness", True),
    ("Color", True),
    ("Contrast", True),
    ("Sharpness", True),
    ("Posterize", True),
    ("Solarize", True),
    ("AutoContrast", False),
    ("Equalize", False),
    ("Invert", False),
]
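
# Each entry above pairs an image function with a flag saying whether it needs a
# magnitude. For example, ("Rotate", True) yields operations like ("Rotate", 0.7, 2),
# while ("Invert", False) yields operations like ("Invert", 0.8, None).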
"""
The parent class for all aa_learner's
"""
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
spdim (int): number of subpolicies per policy
fun_num (int): number of image functions in our search space
p_bins (int): number of bins we divide the interval [0,1] for probabilities
m_bins (int): number of bins we divide the magnitude space
discrete_p_m (boolean): Whether or not the agent should represent probability and
magnitude as discrete variables as the out put of the
controller (A controller can be a neural network, genetic
algorithm, etc.)
self.fun_num = fun_num
self.p_bins = p_bins
self.m_bins = m_bins
self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2
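        # e.g. with the default hyperparameters, the discrete representation is a vector
        # of length 14+11+10 = 35, and the continuous one has length 14+1+1 = 16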
        # whether the controller outputs the probability and magnitude as discrete variables
        self.discrete_p_m = discrete_p_m

        # TODO: We should probably use a different way to store results than self.history
        self.history = []

    def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
        """
        Takes in a tensor representing an operation and returns an actual operation which
        is in the form of:
            ("Invert", 0.8, None)
        or
            ("Contrast", 0.2, 6)

        Args:
            operation_tensor (tensor):
                We expect this tensor to already have been softmaxed.
                Furthermore,
                - If self.discrete_p_m is True, we expect to take in a tensor with
                  dimension (self.fun_num + self.p_bins + self.m_bins)
                - If self.discrete_p_m is False, we expect to take in a tensor with
                  dimension (self.fun_num + 1 + 1)
            return_log_prob (boolean):
                When this is on, we return which indices (of fun, prob, mag) were
                chosen (either randomly or deterministically, depending on argmax).
                This is used, for example, in the gru_learner to calculate the
                probability of the actions that were chosen, which is then logged and
                differentiated.
            argmax (boolean):
                Whether we are taking the argmax of the softmaxed tensors.
                If this is False, we treat the softmaxed outputs as multinomial pdf's.

        Returns:
            operation (tuple):
                An operation in the format that can be directly put into an
                AutoAugment object.
            log_prob (float):
                Used in reinforcement learning updates, such as the proximal policy update
                in the gru_learner.
                Can only be used when self.discrete_p_m is True.
                We add the logged values of the indices of the image_function,
                probability, and magnitude chosen.
                This corresponds to multiplying the non-logged values, then logging
                the product.
        """
        if (not self.discrete_p_m) and return_log_prob:
            raise ValueError("You are not supposed to use return_log_prob=True when the agent's \
                             self.discrete_p_m is False!")

        # make sure shape is correct
        assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape

        # if probability and magnitude are represented as discrete variables
        if self.discrete_p_m:
            fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])

            # make sure they are of the right size
            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
            assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
            assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'

            if argmax==True:
                fun_idx = torch.argmax(fun_t).item()
                prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
                mag = torch.argmax(mag_t).item() # 0 <= m <= 9

            elif argmax==False:
                # we need these to add up to 1 to be valid pdf's of multinomials
                assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)

                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
                mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9

            function = augmentation_space[fun_idx][0]
            prob = prob_idx/10

            indices = (fun_idx, prob_idx, mag)

            # log probability is the sum of the log of the softmax values of the indices
            # (of fun_t, prob_t, mag_t) that we have chosen
            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])

        # if probability and magnitude are represented as continuous variables
        else:
            fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
            # prob and mag are 1-element tensors here; convert them to Python floats
            prob = prob.item()
            mag = mag.item()

            # make sure the shape is correct
            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'

            if argmax==True:
                fun_idx = torch.argmax(fun_t).item()
            elif argmax==False:
                assert torch.sum(fun_t).isclose(torch.ones(1))
                fun_idx = torch.multinomial(fun_t, 1).item()

            function = augmentation_space[fun_idx][0]
            prob = round(prob, 1) # round to the nearest first decimal digit
            mag = round(mag) # round to the nearest integer

        assert 0 <= prob <= 1
        assert 0 <= mag <= self.m_bins-1

        # if the image function does not require a magnitude, we set the magnitude to None
        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
            operation = (function, prob, mag)
        else:
            operation = (function, prob, None)

        if return_log_prob:
            return operation, log_prob
        else:
            return operation
"""
Generate a new policy which can be fed into an AutoAugment object
by calling:
AutoAugment.subpolicies = policy
Args:
none
Returns:
new_policy (list[tuple]):
A new policy generated by the controller. It
has the form of:
[
(("Invert", 0.8, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("Invert", 0.8, None)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("Invert", 0.7, None)),
]
This object can be fed into an AutoAUgment object
by calling: AutoAugment.subpolicies = policy
"""
raise NotImplementedError('generate_new_policy not implemented in aa_learner')
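
    # As an illustration only (not the actual implementation of any subclass), a
    # random-search subclass could implement generate_new_policy() roughly like the
    # sketch below, assuming discrete_p_m=True so that each split of the operation
    # tensor is softmaxed separately before being translated:
    #
    #   def generate_new_policy(self):
    #       new_policy = []
    #       for _ in range(self.sp_num):
    #           ops = []
    #           for _ in range(2):  # each subpolicy consists of two operations
    #               fun_t = torch.softmax(torch.rand(self.fun_num), dim=0)
    #               prob_t = torch.softmax(torch.rand(self.p_bins), dim=0)
    #               mag_t = torch.softmax(torch.rand(self.m_bins), dim=0)
    #               op_tensor = torch.cat([fun_t, prob_t, mag_t])
    #               ops.append(self.translate_operation_tensor(op_tensor))
    #           new_policy.append(tuple(ops))
    #       return new_policy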

    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
        """
        Runs the main loop (of finding a good policy for the given child network,
        training dataset, and test (validation) dataset).

        Does the loop shown in Figure 1 of the AutoAugment paper, which is:
            1. <generate a random policy>
            2. <see how good that policy is>
            3. <save how good the policy is in a list/dictionary>

        Args:
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
            child_network_architecture (type): NOTE that this is not an nn.Module object
                                but the class itself. Therefore, this needs to be, say,
                                `models.LeNet` instead of `models.LeNet()`.
            toy_flag (boolean): whether we want to obtain a toy version of
                                train_dataset and test_dataset and use those.

        Returns:
            none
        """
        # test out 15 random policies
        for _ in range(15):
            policy = self.generate_new_policy()
            pprint(policy)

            child_network = child_network_architecture()
            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
                                                  test_dataset, toy_flag)

            self.history.append((policy, reward))
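
        # Note: once the loop above finishes, the best policy found so far can be read
        # off from self.history, e.g. max(self.history, key=lambda entry: entry[1])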

    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset,
                                toy_flag, logging=False):
        """
        Given a policy (using AutoAugment paper terminology), we train a child network
        using the policy and return the accuracy (how good the policy is for the dataset
        and child network).

        Args:
            policy (list[tuple]): A list of tuples representing a policy.
            child_network (nn.Module)
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
            toy_flag (boolean): Whether we want to obtain a toy version of
                                train_dataset and test_dataset and use those.
            logging (boolean): Whether we want to save logs

        Returns:
            accuracy (float): best accuracy reached in any epoch of training
        """
        # We need to define an object aa_transform which takes in the image and
        # transforms it with the policy (specified in its .policies attribute)
        # in its forward pass
        aa_transform = AutoAugment()
        aa_transform.subpolicies = policy
        train_transform = transforms.Compose([
                                            aa_transform,
                                            transforms.ToTensor()
                                            ])

        # We feed the transformation into the Dataset object
        train_dataset.transform = train_transform

        # create DataLoader objects out of the Dataset objects
        train_loader, test_loader = create_toy(train_dataset,
                                               test_dataset,
                                               batch_size=32,
                                               n_samples=0.5)

        # train the child network with the dataloaders equipped with our specific policy
        accuracy = train_child_network(child_network,
                                       train_loader,
                                       test_loader,
                                       # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                       cost = nn.CrossEntropyLoss(),
                                       logging = logging,
                                       print_every_epoch=True)

        # if logging is True, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
        return accuracy

    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=5):
        """
        I made this to plot a couple of accuracy graphs to help manually tune my gradient
        optimizer hyperparameters.

        Saves a plot of `n` training accuracy graphs overlapped.
        """
        acc_lists = []

        # This is dummy code
        # test out `n` random policies
        for _ in range(n):
            policy = self.generate_new_policy()
            pprint(policy)

            child_network = child_network_architecture()
            reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
                                                            test_dataset, toy_flag, logging=True)

            self.history.append((policy, reward))
            acc_lists.append(acc_list)

        for acc_list in acc_lists:
            plt.plot(acc_list)
        plt.title('I ran 5 random policies to see if there is any sign of \
                   catastrophic failure during training. If there are \
                   any lines which reach significantly lower (>10%) \
                   accuracies, you might want to tune the hyperparameters')
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        # save before show(): calling savefig after show() can write out an empty figure
        plt.savefig('training_graphs_without_policies')
        plt.show()
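

if __name__ == "__main__":
    # A minimal sanity-check sketch (not part of the library API): build a one-hot
    # operation tensor in the discrete representation by hand and translate it.
    # The index 4 corresponds to "Rotate" in augmentation_space above.
    learner = aa_learner(discrete_p_m=True)

    op_tensor = torch.zeros(learner.op_tensor_length)
    op_tensor[4] = 1.0                                     # function: "Rotate"
    op_tensor[learner.fun_num + 7] = 1.0                   # probability bin 7 -> p = 0.7
    op_tensor[learner.fun_num + learner.p_bins + 2] = 1.0  # magnitude bin 2

    # with argmax=True the translation is deterministic; expected: ('Rotate', 0.7, 2)
    print(learner.translate_operation_tensor(op_tensor, argmax=True))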