From 555ef1b4f16e883f2b688df68e42e9bb1ed6ab5e Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Thu, 7 Apr 2022 22:49:54 +0900
Subject: [PATCH] Add new files where GRU_Learner will grow

---
 .../autoaugment_learners/gru_learner.py       | 126 ++++++++++
 .../controller_networks/rnn_controller.py     | 219 ++++++++++++++++++
 2 files changed, 345 insertions(+)
 create mode 100644 MetaAugment/controller_networks/rnn_controller.py

diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index e69de29b..58970e68 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -0,0 +1,126 @@
+import torch
+import numpy as np
+import torchvision.transforms as transforms
+import torchvision.transforms.autoaugment as torchaa
+from torchvision.transforms import functional as F, InterpolationMode
+
+from MetaAugment.main import *
+import MetaAugment.child_networks as cn
+from MetaAugment.autoaugment_learners.autoaugment import *
+from MetaAugment.autoaugment_learners.aa_learner import *
+
+from pprint import pprint
+
+
+# We will use this augmentation_space temporarily. Later on we will need to
+# make sure we are able to add other image functions if the users want.
+augmentation_space = [
+    # (function_name, do_we_need_to_specify_magnitude)
+    ("ShearX", True),
+    ("ShearY", True),
+    ("TranslateX", True),
+    ("TranslateY", True),
+    ("Rotate", True),
+    ("Brightness", True),
+    ("Color", True),
+    ("Contrast", True),
+    ("Sharpness", True),
+    ("Posterize", True),
+    ("Solarize", True),
+    ("AutoContrast", False),
+    ("Equalize", False),
+    ("Invert", False),
+]
+
+
+class gru_learner(aa_learner):
+    # Uses a GRU controller which is updated via Proximal Policy Optimization.
+    # It is the same controller as the one used in
+    # http://arxiv.org/abs/1805.09501
+    # and
+    # http://arxiv.org/abs/1611.01578
+
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
+        '''
+        Args:
+            sp_num: number of subpolicies per policy
+            fun_num: number of image functions in our search space
+            p_bins: number of bins into which we divide the probability interval [0,1]
+            m_bins: number of bins into which we divide the magnitude space
+        '''
+        super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
+
+        # TODO: We should probably use a different way to store results than self.history
+        self.history = []
+
+    def generate_new_policy(self):
+        '''
+        We run the GRU for 10 timesteps to obtain 10 operations, two for
+        each of the 5 subpolicies. At each timestep it outputs a
+        (fun_num + p_bins + m_bins) dimensional vector, and each such
+        vector is put through self.translate_operation_tensor to obtain
+        a readable operation.
+
+        Generates a new policy in the form of:
+            [
+            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
+            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
+            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
+            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
+            ]
+        '''
+        new_policy = []
+
+        for _ in range(self.sp_num):  # generate sp_num subpolicies for each policy
+            ops = []
+            # generate 2 operations for each subpolicy
+            for i in range(2):
+                # if our agent uses discrete representations of probability and magnitude
+                if self.discrete_p_m:
+                    new_op = self.generate_new_discrete_operation()
+                else:
+                    new_op = self.generate_new_continuous_operation()
+                new_op = self.translate_operation_tensor(new_op)
+                ops.append(new_op)
+
+            new_subpolicy = tuple(ops)
+            new_policy.append(new_subpolicy)
+
+        return new_policy
+
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
+        '''
+        Does the loop which is seen in Figure 1 of the AutoAugment paper.
+        In other words, repeat:
+            1. <generate a random policy>
+            2. <see how good that policy is>
+            3. <save how good the policy is in a list/dictionary>
+        '''
+        # test out 15 random policies
+        for _ in range(15):
+            policy = self.generate_new_policy()
+
+            pprint(policy)
+            child_network = child_network_architecture()
+            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                  test_dataset, toy_flag)
+
+            self.history.append((policy, reward))
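+
+
+# Illustrative sketch only: how a single (fun_num + p_bins + m_bins)
+# dimensional controller output, as described in generate_new_policy's
+# docstring, could be split into its three segments. The helper name and
+# defaults here are hypothetical; nothing in the class calls this yet.
+def _split_operation_vector(op_vec, fun_num=14, p_bins=11, m_bins=10):
+    # op_vec: 1-D tensor of length fun_num + p_bins + m_bins
+    fun_logits = op_vec[:fun_num]                    # which image function
+    p_logits = op_vec[fun_num:fun_num + p_bins]      # discretised probability
+    m_logits = op_vec[fun_num + p_bins:]             # discretised magnitude
+    return fun_logits, p_logits, m_logits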
+
+
+if __name__ == '__main__':
+    # We can initialize the train_dataset with its transform as None.
+    # Later on, we will change this object's transform attribute to the
+    # policy that we want to test.
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True,
+                                   download=False, transform=None)
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False,
+                                  download=False,
+                                  transform=torchvision.transforms.ToTensor())
+    child_network = cn.lenet
+
+    gru_agent = gru_learner(discrete_p_m=False)
+    gru_agent.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    pprint(gru_agent.history)
\ No newline at end of file
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/rnn_controller.py
new file mode 100644
index 00000000..1e228fc1
--- /dev/null
+++ b/MetaAugment/controller_networks/rnn_controller.py
@@ -0,0 +1,219 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class LSTMCell(nn.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        super(LSTMCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
+        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        std = 1.0 / math.sqrt(self.hidden_size)
+        for w in self.parameters():
+            w.data.uniform_(-std, std)
+
+    def forward(self, input, hx=None):
+        if hx is None:
+            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
+            hx = (hx, hx)
+
+        # We use hx to pack both the hidden and cell states
+        hx, cx = hx
+
+        # compute all four gates in one pass, then split them
+        hi = self.x2h(input) + self.h2h(hx)
+        i, f, o, g = torch.chunk(hi, 4, dim=-1)
+        i = torch.sigmoid(i)  # input gate
+        f = torch.sigmoid(f)  # forget gate
+        o = torch.sigmoid(o)  # output gate
+        g = torch.tanh(g)     # candidate cell state
+        cy = f * cx + i * g
+        hy = o * torch.tanh(cy)
+
+        return (hy, cy)
+
+
+class GRUCell(nn.Module):
+    def __init__(self, input_size, hidden_size, bias=True):
+        super(GRUCell, self).__init__()
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        # update (z) and reset (r) gates, computed together
+        self.x2h = nn.Linear(input_size, 2 * hidden_size, bias=bias)
+        self.h2h = nn.Linear(hidden_size, 2 * hidden_size, bias=bias)
+
+        # candidate activation (despite the names, x2r/h2r feed the
+        # candidate; the reset gate r is applied to h2r(hx) in forward)
+        self.x2r = nn.Linear(input_size, hidden_size, bias=bias)
+        self.h2r = nn.Linear(hidden_size, hidden_size, bias=bias)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        std = 1.0 / math.sqrt(self.hidden_size)
+        for w in self.parameters():
+            w.data.uniform_(-std, std)
+
+    def forward(self, input, hx=None):
+        if hx is None:
+            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
+
+        z, r = torch.chunk(self.x2h(input) + self.h2h(hx), 2, -1)
+        z = torch.sigmoid(z)
+        r = torch.sigmoid(r)
+        g = torch.tanh(self.h2r(hx) * r + self.x2r(input))
+        hy = z * hx + (1 - z) * g
+
+        return hy
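+
+
+# Illustrative sketch only: how either cell above can be driven over a
+# (batch, seq_len, input_size) tensor. The helper name is hypothetical and
+# nothing in this module calls it; RNNModel below does the same job.
+def _unroll_cell(cell, inputs):
+    hx = None
+    for t in range(inputs.shape[1]):
+        hx = cell(inputs[:, t], hx)
+    # LSTMCell returns a (hidden, cell) tuple; GRUCell returns a tensor
+    return hx[0] if isinstance(hx, tuple) else hx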
+
+
+class RNNModel(nn.Module):
+    def __init__(self, mode, input_size, hidden_size, num_layers, bias, output_size):
+        super(RNNModel, self).__init__()
+        self.mode = mode
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.bias = bias
+        self.output_size = output_size
+
+        self.rnn_cell_list = nn.ModuleList()
+
+        if mode == 'LSTM':
+            self.rnn_cell_list.append(LSTMCell(self.input_size,
+                                               self.hidden_size,
+                                               self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(LSTMCell(self.hidden_size,
+                                                   self.hidden_size,
+                                                   self.bias))
+
+        elif mode == 'GRU':
+            self.rnn_cell_list.append(GRUCell(self.input_size,
+                                              self.hidden_size,
+                                              self.bias))
+            for l in range(1, self.num_layers):
+                self.rnn_cell_list.append(GRUCell(self.hidden_size,
+                                                  self.hidden_size,
+                                                  self.bias))
+
+        else:
+            raise ValueError("Invalid RNN mode selected.")
+
+        # defined for later use; forward currently returns the raw hidden states
+        self.att_fc = nn.Linear(self.hidden_size, 1)
+        self.fc = nn.Linear(self.hidden_size, self.output_size)
+
+    def forward(self, input, hx=None):
+        # input: (batch, seq_len, input_size)
+        h0 = [None] * self.num_layers if hx is None else list(hx)
+
+        X = list(input.permute(1, 0, 2))
+        for j, cell in enumerate(self.rnn_cell_list):
+            hx = h0[j]
+            for i in range(input.shape[1]):
+                hx = cell(X[i], hx)
+                X[i] = hx if self.mode != 'LSTM' else hx[0]
+        outs = X
+
+        # out = outs[-1].squeeze()
+        # out = self.fc(out)
+        # return out
+
+        return outs
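+
+
+# Illustrative sketch only: driving the controller on toy data. The
+# hyperparameters are arbitrary; 35 = fun_num + p_bins + m_bins
+# (14 + 11 + 10) from gru_learner.py. Nothing in this module calls this.
+def _controller_example():
+    controller = RNNModel(mode='GRU', input_size=35, hidden_size=64,
+                          num_layers=1, bias=True, output_size=35)
+    x = torch.randn(2, 10, 35)       # (batch, seq_len, input_size)
+    hidden_states = controller(x)    # list of 10 tensors, each (2, 64)
+    return hidden_states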
selected.") + + + self.fc = nn.Linear(2 * self.hidden_size, self.output_size) + + + def forward(self, input, hx=None): + + outs = [] + outs_rev = [] + + X = list(input.permute(1, 0, 2)) + X_rev = list(input.permute(1, 0, 2)) + X_rev.reverse() + hi = [None] * self.num_layers if hx is None else list(hx) + hi_rev = [None] * self.num_layers if hx is None else list(hx) + for j in range(self.num_layers): + hx = hi[j] + hx_rev = hi_rev[j] + for i in range(input.shape[1]): + hx = self.rnn_cell_list[j](X[i], hx) + X[i] = hx if self.mode != 'LSTM' else hx[0] + hx_rev = self.rnn_cell_list_rev[j](X_rev[i], hx_rev) + X_rev[i] = hx_rev if self.mode != 'LSTM' else hx_rev[0] + outs = X + outs_rev = X_rev + + out = outs[-1].squeeze() + out_rev = outs_rev[0].squeeze() + out = torch.cat((out, out_rev), 1) + + out = self.fc(out) + return out \ No newline at end of file -- GitLab