Commit 555ef1b4 authored by Sun Jin Kim

Add new files where GRU_Learner will grow

parent 92d97847
import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision import datasets
import torchvision.transforms.autoaugment as torchaa
from torchvision.transforms import functional as F, InterpolationMode
from MetaAugment.main import *
import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.autoaugment import *
from MetaAugment.autoaugment_learners.aa_learner import *
from pprint import pprint
# We will use this augmentation_space temporarily. Later on we will need to
# make sure users are able to add other image functions if they want.
augmentation_space = [
# (function_name, do_we_need_to_specify_magnitude)
("ShearX", True),
("ShearY", True),
("TranslateX", True),
("TranslateY", True),
("Rotate", True),
("Brightness", True),
("Color", True),
("Contrast", True),
("Sharpness", True),
("Posterize", True),
("Solarize", True),
("AutoContrast", False),
("Equalize", False),
("Invert", False),
]
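
# Hedged sketch (illustrative, not part of the learner): how one entry of
# augmentation_space plus a discrete magnitude bin could be turned into a
# concrete torchvision call. `_apply_rotate` is a hypothetical helper, and
# the [-30, 30] degree range is an assumption, not a value from this repo.
def _apply_rotate(img, magnitude_bin, m_bins=10):
    # map bin index {0, ..., m_bins - 1} linearly onto an angle in [-30, 30]
    angle = (magnitude_bin / (m_bins - 1)) * 60.0 - 30.0
    return F.rotate(img, angle)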
class gru_learner(aa_learner):
    '''
    Uses a GRU controller which is updated via Proximal Policy Optimization.
    This is the same controller as the ones used in
    http://arxiv.org/abs/1805.09501
    and
    http://arxiv.org/abs/1611.01578
    '''
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
        '''
        Args:
            sp_num: number of subpolicies per policy
            fun_num: number of image functions in our search space
            p_bins: number of bins into which we divide the interval [0,1] for probabilities
            m_bins: number of bins into which we divide the magnitude space
            discrete_p_m: whether to use discrete representations of probability and magnitude
        '''
super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
# TODO: We should probably use a different way to store results than self.history
self.history = []
def generate_new_policy(self):
        '''
        Eventually, the GRU controller will be run for 2 * sp_num timesteps
        (10 with the defaults) to obtain one operation per timestep. At each
        timestep it outputs a (fun_num + p_bins + m_bins) dimensional vector,
        which is then translated into an operation tuple via
        self.translate_operation_tensor (see the sketch below the class).

        Generates a new policy in the form of
            [
            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
            ]
        '''
new_policy = []
for _ in range(self.sp_num): # generate sp_num subpolicies for each policy
ops = []
# generate 2 operations for each subpolicy
            for _ in range(2):
# if our agent uses discrete representations of probability and magnitude
if self.discrete_p_m:
new_op = self.generate_new_discrete_operation()
else:
new_op = self.generate_new_continuous_operation()
new_op = self.translate_operation_tensor(new_op)
ops.append(new_op)
new_subpolicy = tuple(ops)
new_policy.append(new_subpolicy)
return new_policy
def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
        '''
        Performs the loop shown in Figure 1 of the AutoAugment paper.
        In other words, repeat:
            1. <generate a new policy>
            2. <evaluate how good that policy is>
            3. <append the policy and its reward to self.history>
        '''
# test out 15 random policies
for _ in range(15):
policy = self.generate_new_policy()
pprint(policy)
child_network = child_network_architecture()
reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
test_dataset, toy_flag)
self.history.append((policy, reward))
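
# Hedged sketch of the sampling step described in generate_new_policy's
# docstring: splitting one (fun_num + p_bins + m_bins)-dimensional controller
# output into three categorical distributions and sampling one operation.
# `gru_out` and this helper are hypothetical stand-ins for one timestep of the
# (not yet implemented) GRU controller; they are not part of the original file.
def _sample_operation_from_logits(gru_out, fun_num=14, p_bins=11, m_bins=10):
    fun_logits, p_logits, m_logits = torch.split(gru_out, [fun_num, p_bins, m_bins])
    fun_idx = torch.distributions.Categorical(logits=fun_logits).sample()
    p_idx = torch.distributions.Categorical(logits=p_logits).sample()
    m_idx = torch.distributions.Categorical(logits=m_logits).sample()
    # fun_idx indexes into augmentation_space; p_idx and m_idx index into the
    # probability and magnitude bins respectively
    return fun_idx.item(), p_idx.item(), m_idx.item()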
if __name__ == '__main__':
    # We can initialize the train_dataset with its transform as None.
    # Later on, we will change this object's transform attribute to the policy
    # that we want to test.
    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True,
                                   download=False, transform=None)
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False,
                                  download=False, transform=transforms.ToTensor())
    child_network = cn.lenet

    agent = gru_learner(discrete_p_m=False)
    agent.learn(train_dataset, test_dataset, child_network, toy_flag=True)
    pprint(agent.history)
import torch
import torch.nn as nn
import math
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=True):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
        # one linear map each from the input and the hidden state, producing
        # all four gate pre-activations (i, f, o, g) in a single matrix multiply
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
self.reset_parameters()
def reset_parameters(self):
std = 1.0 / math.sqrt(self.hidden_size)
for w in self.parameters():
w.data.uniform_(-std, std)
def forward(self, input, hx=None):
if hx is None:
hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
hx = (hx, hx)
        # We use hx to pack both the hidden and cell states
        hx, cx = hx
        gates = self.x2h(input) + self.h2h(hx)
        i, f, o, g = torch.chunk(gates, 4, dim=-1)
        i = torch.sigmoid(i)  # input gate
        f = torch.sigmoid(f)  # forget gate
        o = torch.sigmoid(o)  # output gate
        g = torch.tanh(g)     # candidate cell state
        cy = f * cx + i * g          # new cell state
        hy = o * torch.tanh(cy)      # new hidden state
        return (hy, cy)
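
# Illustrative usage sketch (an assumption, not in the original file): a single
# LSTMCell step returns a (hidden, cell) tuple, unlike GRUCell below, which
# returns a single tensor.
def _lstm_cell_demo(batch=4, input_size=8, hidden_size=16):
    cell = LSTMCell(input_size, hidden_size)
    x = torch.randn(batch, input_size)
    hy, cy = cell(x)  # hx defaults to zero hidden and cell states
    assert hy.shape == cy.shape == (batch, hidden_size)
    return hy, cy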
class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=True):
super(GRUCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
        # x2h/h2h produce the update and reset gate pre-activations together;
        # x2r/h2r produce the candidate-state pre-activations
        self.x2h = nn.Linear(input_size, 2 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 2 * hidden_size, bias=bias)
        self.x2r = nn.Linear(input_size, hidden_size, bias=bias)
        self.h2r = nn.Linear(hidden_size, hidden_size, bias=bias)
self.reset_parameters()
def reset_parameters(self):
std = 1.0 / math.sqrt(self.hidden_size)
for w in self.parameters():
w.data.uniform_(-std, std)
def forward(self, input, hx=None):
if hx is None:
hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        # update gate z and reset gate r
        z, r = torch.chunk(self.x2h(input) + self.h2h(hx), 2, -1)
        z = torch.sigmoid(z)
        r = torch.sigmoid(r)
        # candidate hidden state, with the reset gate applied to the old state
        g = torch.tanh(self.h2r(hx) * r + self.x2r(input))
        # interpolate between the old hidden state and the candidate
        hy = z * hx + (1 - z) * g
        return hy
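
# Illustrative usage sketch (an assumption, not in the original file): unroll
# GRUCell over a random sequence and check that the hidden state keeps shape
# (batch, hidden_size) at every step.
def _gru_cell_demo(batch=4, seq_len=6, input_size=8, hidden_size=16):
    cell = GRUCell(input_size, hidden_size)
    x = torch.randn(batch, seq_len, input_size)
    hx = None
    for t in range(seq_len):
        hx = cell(x[:, t, :], hx)
        assert hx.shape == (batch, hidden_size)
    return hx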
class RNNModel(nn.Module):
def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
super(RNNModel, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.output_size = output_size
self.rnn_cell_list = nn.ModuleList()
if mode == 'LSTM':
self.rnn_cell_list.append(LSTMCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list.append(LSTMCell(self.hidden_size,
self.hidden_size,
self.bias))
elif mode == 'GRU':
self.rnn_cell_list.append(GRUCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list.append(GRUCell(self.hidden_size,
self.hidden_size,
self.bias))
else:
raise ValueError("Invalid RNN mode selected.")
        # att_fc is declared but not used in forward() yet; fc maps a hidden
        # state to the output size
        self.att_fc = nn.Linear(self.hidden_size, 1)
        self.fc = nn.Linear(self.hidden_size, self.output_size)
    def forward(self, input, hx=None):
        # input: (batch, seq_len, feature). Run one layer at a time, feeding
        # each layer's per-timestep hidden states to the next layer.
        h0 = [None] * self.num_layers if hx is None else list(hx)
        X = list(input.permute(1, 0, 2))
        for j, cell in enumerate(self.rnn_cell_list):
            hx = h0[j]
            for i in range(input.shape[1]):
                hx = cell(X[i], hx)
                X[i] = hx if self.mode != 'LSTM' else hx[0]
        # Return the top layer's hidden state at every timestep. self.fc is
        # kept for mapping a hidden state to output_size later,
        # e.g. self.fc(X[-1].squeeze()).
        return X
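
# Illustrative usage sketch (an assumption, not in the original file): RNNModel
# returns the top layer's hidden state at every timestep, one tensor per step.
def _rnn_model_demo(batch=4, seq_len=6, input_size=8, hidden_size=16):
    model = RNNModel('GRU', input_size, hidden_size,
                     num_layers=2, bias=True, output_size=10)
    x = torch.randn(batch, seq_len, input_size)
    outs = model(x)
    assert len(outs) == seq_len
    assert outs[-1].shape == (batch, hidden_size)
    return outs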
class BidirRecurrentModel(nn.Module):
def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
super(BidirRecurrentModel, self).__init__()
self.mode = mode
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.output_size = output_size
self.rnn_cell_list = nn.ModuleList()
self.rnn_cell_list_rev = nn.ModuleList()
if mode == 'LSTM':
self.rnn_cell_list.append(LSTMCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list.append(LSTMCell(self.hidden_size,
self.hidden_size,
self.bias))
self.rnn_cell_list_rev.append(LSTMCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list_rev.append(LSTMCell(self.hidden_size,
self.hidden_size,
self.bias))
elif mode == 'GRU':
self.rnn_cell_list.append(GRUCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list.append(GRUCell(self.hidden_size,
self.hidden_size,
self.bias))
self.rnn_cell_list_rev.append(GRUCell(self.input_size,
self.hidden_size,
self.bias))
for l in range(1, self.num_layers):
self.rnn_cell_list_rev.append(GRUCell(self.hidden_size,
self.hidden_size,
self.bias))
else:
raise ValueError("Invalid RNN mode selected.")
self.fc = nn.Linear(2 * self.hidden_size, self.output_size)
    def forward(self, input, hx=None):
        # Run the sequence forwards and backwards through separate cell stacks
        X = list(input.permute(1, 0, 2))
        X_rev = list(input.permute(1, 0, 2))
        X_rev.reverse()
        hi = [None] * self.num_layers if hx is None else list(hx)
        hi_rev = [None] * self.num_layers if hx is None else list(hx)
        for j in range(self.num_layers):
            hx = hi[j]
            hx_rev = hi_rev[j]
            for i in range(input.shape[1]):
                hx = self.rnn_cell_list[j](X[i], hx)
                X[i] = hx if self.mode != 'LSTM' else hx[0]
                hx_rev = self.rnn_cell_list_rev[j](X_rev[i], hx_rev)
                X_rev[i] = hx_rev if self.mode != 'LSTM' else hx_rev[0]
        # Concatenate the final hidden state of each direction. X_rev is kept
        # in reversed time order, so its last element (not its first) is the
        # reverse direction's final state.
        out = X[-1].squeeze()
        out_rev = X_rev[-1].squeeze()
        out = torch.cat((out, out_rev), 1)
        out = self.fc(out)
        return out
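
# Illustrative usage sketch (an assumption, not in the original file): the
# bidirectional model concatenates both directions' final hidden states and
# maps them to output_size.
def _bidir_demo(batch=4, seq_len=6, input_size=8, hidden_size=16, output_size=10):
    model = BidirRecurrentModel('GRU', input_size, hidden_size,
                                num_layers=1, bias=True, output_size=output_size)
    x = torch.randn(batch, seq_len, input_size)
    out = model(x)
    assert out.shape == (batch, output_size)
    return out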