import torch
import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
from MetaAugment.controller_networks.rnn_controller import RNNModel
from pprint import pprint
import pickle
# We use this augmentation_space temporarily. Later on we will need to
# make sure users are able to add other image functions if they want.
augmentation_space = [
# (function_name, do_we_need_to_specify_magnitude)
("ShearX", True),
("ShearY", True),
("TranslateX", True),
("TranslateY", True),
("Rotate", True),
("Brightness", True),
("Color", True),
("Contrast", True),
("Sharpness", True),
("Posterize", True),
("Solarize", True),
("AutoContrast", False),
("Equalize", False),
("Invert", False),
]
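
# A minimal sketch of the convention above, assuming users will later be able to
# pass their own list: each entry is (function_name, requires_magnitude), where
# the second element tells the learner whether a magnitude bin must be sampled
# for that function. A hypothetical user-defined space restricted to two
# functions could look like:
#
#     custom_augmentation_space = [
#         ("Rotate", True),     # magnitude bin required
#         ("Equalize", False),  # magnitude is ignored (None)
#     ]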
class gru_learner(aa_learner):
"""
    An AutoAugment learner with a GRU controller.

    The original AutoAugment paper (http://arxiv.org/abs/1805.09501)
    uses an LSTM controller updated via Proximal Policy Optimization
    (see Section 3 of the AutoAugment paper).
    The GRU has been shown to be as powerful a sequential model as the
    LSTM while training and evaluating much faster
    (https://arxiv.org/abs/1412.3555), which is why we substitute the
    GRU for the LSTM.
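
    Example:
        A minimal usage sketch (it mirrors the __main__ block at the bottom
        of this file; learn() additionally needs datasets and a child
        network architecture):

        >>> learner = gru_learner(sp_num=5, alpha=0.2)
        >>> policy, log_prob = learner.generate_new_policy()
        >>> # learner.learn(train_dataset, test_dataset, cn.lenet, toy_flag=True)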
"""
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, alpha=0.2):
"""
Args:
            alpha (float): Exploration parameter. The controller's output
                logits are multiplied by this value before being softmaxed.
                The lower the value, the flatter (more uniform) the softmax
                output becomes, and hence the more exploration.
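                For instance, with two logits for illustration, softmax of
                [2.0, 0.0] is roughly [0.88, 0.12], whereas scaling by
                alpha=0.2 first gives softmax([0.4, 0.0]), which is
                approximately [0.60, 0.40], a noticeably flatter distribution.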
"""
        # The GRU controller always produces discretised probability and magnitude
        # bins, so discrete_p_m is kept as True when initialising the parent class.
        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
self.alpha = alpha
self.rnn_output_size = fun_num+p_bins+m_bins
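        # With the default bins (fun_num=14, p_bins=11, m_bins=10) each controller
        # timestep therefore emits a 35-dimensional vector, which
        # generate_new_policy() splits back into its three parts before softmaxing.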
self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size,
num_layers=2, bias=True)
self.softmax = torch.nn.Softmax(dim=0)
def generate_new_policy(self):
"""
        The GRU controller generates a new policy.

        At each timestep, the GRU outputs a
        (fun_num + p_bins + m_bins, ) dimensional tensor which contains
        information on which 'image function' to use and which
        'probability' (prob) and 'magnitude' (mag) to use with it.
        We run the GRU for 2 * sp_num timesteps to obtain one such tensor
        per operation (ten tensors with the default of five subpolicies).

        We then softmax the parts of each tensor which represent the
        choice of function, prob, and mag separately, so that the
        resulting tensor's values sum to 3 (each of the three parts
        sums to 1).

        Each softmaxed tensor is passed to self.translate_operation_tensor
        with return_log_prob=True, which outputs a tuple of the form
        ('img_function_name', prob, mag) together with a float
        representing the log probability of having chosen that particular
        function, prob and mag. We add up the log probabilities of all
        the operations.

        The operations are grouped into a list of sp_num tuples (two
        operations per subpolicy), for example:
        [
            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
        ]
        This list can then be fed into an AutoAugment object, as is done
        in self.learn().

        We return the list and the sum of the log probabilities.
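
        Example (a sketch; the sampled operations depend on the controller's
        current weights):
            >>> learner = gru_learner(sp_num=5)
            >>> policy, log_prob = learner.generate_new_policy()
            >>> len(policy)     # one (op1, op2) pair per subpolicy
            5
            >>> len(policy[0])  # two operations per subpolicy
            2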
"""
log_prob = 0
        # the GRU needs an initial input to start from; we feed it a zero vector
random_input = torch.zeros(self.rnn_output_size, requires_grad=False)
# 2*self.sp_num because we need 2 operations for every subpolicy
vectors = self.controller(input=random_input, time_steps=2*self.sp_num)
        # softmax the function vector, probability vector, and magnitude vector
        # of each timestep
softmaxed_vectors = []
for vector in vectors:
fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
fun_t = self.softmax(fun_t * self.alpha)
prob_t = self.softmax(prob_t * self.alpha)
mag_t = self.softmax(mag_t * self.alpha)
softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
softmaxed_vectors.append(softmaxed_vector)
new_policy = []
for subpolicy_idx in range(self.sp_num):
# the vector corresponding to the first operation of this subpolicy
op1 = softmaxed_vectors[2*subpolicy_idx]
# the vector corresponding to the second operation of this subpolicy
op2 = softmaxed_vectors[2*subpolicy_idx+1]
# translate both vectors
op1, log_prob1 = self.translate_operation_tensor(op1, return_log_prob=True)
op2, log_prob2 = self.translate_operation_tensor(op2, return_log_prob=True)
new_policy.append((op1,op2))
log_prob += (log_prob1+log_prob2)
return new_policy, log_prob
def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
# optimizer for training the GRU controller
cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
        # m is the minibatch size: the number of policies evaluated per controller update
        b = 0.88  # b is a running exponential mean of the rewards, used as a baseline
                  # for training stability (see section 3.2 of https://arxiv.org/abs/1611.01578)
for _ in range(1000):
cont_optim.zero_grad()
            # obj (objective) is $\sum_{k=1}^m (reward_k - b) \sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$,
            # the REINFORCE-with-baseline objective from section 3.2 of
            # https://arxiv.org/abs/1611.01578 (the original AutoAugment paper uses PPO instead)
obj = 0
# sum up the rewards within a minibatch in order to update the running mean, 'b'
mb_rewards_sum = 0
for k in range(m):
                # log_prob is $\sum_{t=1}^T \log P(a_t|a_{(t-1):1};\theta_c)$,
                # used in the policy-gradient objective above
policy, log_prob = self.generate_new_policy()
pprint(policy)
child_network = child_network_architecture()
reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
test_dataset, toy_flag)
mb_rewards_sum += reward
# log
self.history.append((policy, reward))
# gradient accumulation
obj += (reward-b)*log_prob
# update running mean of rewards
b = 0.7*b + 0.3*(mb_rewards_sum/m)
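            # e.g. starting from b = 0.88, a minibatch whose mean reward is 0.90
            # moves the baseline to 0.7*0.88 + 0.3*0.90 = 0.886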
(-obj).backward() # We put a minus because we want to maximize the objective, not
# minimize it.
cont_optim.step()
            # save the history as a pickle after every controller update
with open('gru_logs.pkl', 'wb') as file:
pickle.dump(self.history, file)
with open('gru_learner.pkl', 'wb') as file:
pickle.dump(self, file)
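
# A minimal sketch of how the pickled logs written by learn() can be inspected
# afterwards (assumes at least one controller update has completed, so that the
# files exist):
#
#     import pickle
#     with open('gru_logs.pkl', 'rb') as f:
#         history = pickle.load(f)   # list of (policy, reward) tuples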
if __name__=='__main__':
# We can initialize the train_dataset with its transform as None.
# Later on, we will change this object's transform attribute to the policy
# that we want to test
import torchvision.datasets as datasets
import torchvision
torch.manual_seed(0)
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True,
transform=None)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
    learner = gru_learner()
learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
pprint(learner.history)