diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 3a6b3e4c5d11f1cddaac0e1d18b372ef45f14584..6e7874e94a82c5d77740e219261d14ab8f33b4de 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -9,6 +9,8 @@ from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
 import torchvision.transforms as transforms
 from pprint import pprint
 
+import matplotlib.pyplot as plt
+
 
 # We will use this augmentation_space temporarily. Later on we will need to
 # make sure we are able to add other image functions if the users want.
@@ -59,7 +61,7 @@ class aa_learner:
 
         self.history = []
 
-    def translate_operation_tensor(self, operation_tensor, argmax=False):
+    def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
         '''
         takes in a tensor representing an operation and returns an actual operation which
         is in the form of:
@@ -76,10 +78,28 @@ class aa_learner:
                 - If self.discrete_p_m is False, we expect to take in a tensor with
                 dimension (self.fun_num + 1 + 1)
 
+            return_log_prob (boolean):
+                When this is on, we also return the log-probability of the (function,
+                probability, magnitude) indices that were chosen (either randomly or
+                deterministically, depending on argmax). This is used, for example, in
+                the gru_learner to calculate the log-probability of the chosen actions,
+                which is needed for the policy-gradient update.
+
             argmax (boolean):
                 Whether we are taking the argmax of the softmaxed tensors.
                 If this is False, we treat the softmaxed outputs as multinomial pdf's.
+
+        Returns:
+            operation (tuple):
+                A (function_name, probability, magnitude) tuple in a format that can
+                be directly put into an AutoAugment object.
+            log_prob (torch.Tensor):
+                Only returned when return_log_prob=True. The log-probability of the
+                chosen function, probability and magnitude indices.
+
         '''
+        if (not self.discrete_p_m) and return_log_prob:
+            raise ValueError("You are not supposed to use return_log_prob=True when the "
+                             "agent's self.discrete_p_m is False!")
+
         # make sure shape is correct
         assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape
 
@@ -92,53 +112,66 @@ class aa_learner:
             assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
             assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
 
+
             if argmax==True:
-                fun = torch.argmax(fun_t).item()
-                prob = torch.argmax(prob_t).item() # 0 <= p <= 10
+                fun_idx = torch.argmax(fun_t).item()
+                prob_idx = torch.argmax(prob_t).item() # 0 <= p <= 10
                 mag = torch.argmax(mag_t).item() # 0 <= m <= 9
             elif argmax==False:
                 # we need these to add up to 1 to be valid pdf's of multinomials
                 assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                 assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                 assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
-                fun = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
-                prob = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
+
+                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
+                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
                 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
-            function = augmentation_space[fun][0]
-            prob = prob/10
+            function = augmentation_space[fun_idx][0]
+            prob = prob_idx/10
+
+            indices = (fun_idx, prob_idx, mag)
+
+            # log probability is the sum of the log of the softmax values of the indices
+            # (of fun_t, prob_t, mag_t) that we have chosen
+            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])
 
         # if probability and magnitude are represented as continuous variables
         else:
             fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
+            prob = prob.item() # 0 <= prob <= 1
+            mag = mag.item() # 0 <= mag <= 9
 
             # make sure the shape is correct
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
 
             if argmax==True:
-                fun = torch.argmax(fun_t)
+                fun_idx = torch.argmax(fun_t).item()
             elif argmax==False:
                 assert torch.sum(fun_t).isclose(torch.ones(1))
-                fun = torch.multinomial(fun_t, 1)
+                fun_idx = torch.multinomial(fun_t, 1).item()
+            prob = round(prob, 1) # round to nearest first decimal digit
+            mag = round(mag) # round to nearest integer
 
-            function = augmentation_space[fun][0]
-            prob = round(prob, 1) # round to nearest first decimal digit
-            mag = round(mag) # round to nearest integer
+            function = augmentation_space[fun_idx][0]
 
         assert 0 <= prob <= 1
         assert 0 <= mag <= self.m_bins-1
 
         # if the image function does not require a magnitude, we set the magnitude to None
-        if augmentation_space[fun][1] == True: # if the image function has a magnitude
-            return (function, prob, mag)
+        if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
+            operation = (function, prob, mag)
         else:
-            return (function, prob, None)
-
-
-
+            operation = (function, prob, None)
+
+        if return_log_prob:
+            return operation, log_prob
+        else:
+            return operation
+
 
     def generate_new_policy(self):
         '''
@@ -176,7 +209,8 @@ class aa_learner:
 
         self.history.append((policy, reward))
 
-    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset, toy_flag):
+    def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset,
+                                toy_flag, logging=False):
         '''
         Given a policy (using AutoAugment paper terminology), we train a child network using
         the policy and return the accuracy (how good the policy is for the dataset and
@@ -198,7 +232,7 @@ class aa_learner:
 
         # create Dataloader objects out of the Dataset objects
         train_loader, test_loader = create_toy(train_dataset, test_dataset,
-                                                batch_size=32,
+                                                batch_size=64,
                                                 n_samples=0.01,
                                                 seed=100)
 
@@ -206,9 +240,12 @@ class aa_learner:
 
         accuracy = train_child_network(child_network, train_loader, test_loader,
-                                        sgd = optim.SGD(child_network.parameters(), lr=1e-1),
+                                        sgd = optim.SGD(child_network.parameters(), lr=3e-1),
+                                        # sgd = optim.Adadelta(child_network.parameters(), lr=1e-2),
                                         cost = nn.CrossEntropyLoss(),
-                                        max_epochs = 100,
-                                        early_stop_num = 15,
-                                        logging = False)
+                                        max_epochs = 3000000,
+                                        early_stop_num = 120,
+                                        logging = logging)
+
+        # if logging is True, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
\ No newline at end of file
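
The discrete branch of translate_operation_tensor above is the core of how a controller output becomes an AutoAugment operation: the length-(fun_num + p_bins + m_bins) vector is split into three softmaxed blocks, an index is drawn from each (by argmax or multinomial sampling), and the indices are mapped to a (function, probability, magnitude) tuple plus, optionally, their summed log-probability. The following is a minimal standalone sketch of that decode path; it is not part of the diff, and toy_augmentation_space is an invented stand-in for the module-level augmentation_space used in aa_learner.py.

```python
# Sketch of the discrete decode path, assuming fun_num=14, p_bins=11, m_bins=10.
import torch

fun_num, p_bins, m_bins = 14, 11, 10
# stand-in for augmentation_space: (name, requires_magnitude) pairs
toy_augmentation_space = [(f'Op{i}', True) for i in range(fun_num)]

# a fake controller output: one operation tensor of length fun_num + p_bins + m_bins
logits = torch.randn(fun_num + p_bins + m_bins)
fun_t, prob_t, mag_t = logits.split([fun_num, p_bins, m_bins])
fun_t, prob_t, mag_t = fun_t.softmax(0), prob_t.softmax(0), mag_t.softmax(0)

# sample one index from each categorical distribution (the argmax=False branch)
fun_idx = torch.multinomial(fun_t, 1).item()
prob_idx = torch.multinomial(prob_t, 1).item()   # 0..10 -> probability 0.0..1.0
mag_idx = torch.multinomial(mag_t, 1).item()     # 0..9

operation = (toy_augmentation_space[fun_idx][0], prob_idx / 10, mag_idx)
# log-probability of the sampled indices, as returned when return_log_prob=True
log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag_idx])
print(operation, log_prob.item())
```
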
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index f003f1f13b06de1c8f7b26de98f52bfffc17b635..377064a2c107c573bf8ae9a89630b22d8ee51d6c 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -5,6 +5,7 @@ from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 from MetaAugment.controller_networks.rnn_controller import RNNModel
 
 from pprint import pprint
+import pickle
 
 
 
@@ -36,15 +37,18 @@ class gru_learner(aa_learner):
     # and
     # http://arxiv.org/abs/1611.01578
 
-    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True):
+    def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True, alpha=0.2):
         '''
         Args:
            spdim: number of subpolicies per policy
            fun_num: number of image functions in our search space
            p_bins: number of bins we divide the interval [0,1] for probabilities
            m_bins: number of bins we divide the magnitude space
+
+            alpha: Softmax temperature for the controller outputs. The lower this
+                value, the flatter the distributions and hence the more exploration.
         '''
         super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m=True)
+        self.alpha = alpha
 
         self.rnn_output_size = fun_num+p_bins+m_bins
         self.controller = RNNModel(mode='GRU', output_size=self.rnn_output_size,
@@ -66,8 +70,10 @@ class gru_learner(aa_learner):
             (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
             ]
         '''
+        log_prob = 0
+
-        # we need a random input to put in
-        random_input = torch.rand(self.rnn_output_size, requires_grad=False)
+        # the controller is fed a fixed, all-zero initial input
+        random_input = torch.zeros(self.rnn_output_size, requires_grad=False)
 
         # 2*self.sp_num because we need 2 operations for every subpolicy
         vectors = self.controller(input=random_input, time_steps=2*self.sp_num)
@@ -76,15 +82,13 @@ class gru_learner(aa_learner):
         # of each timestep
         softmaxed_vectors = []
         for vector in vectors:
-            print(vector)
             fun_t, prob_t, mag_t = vector.split([self.fun_num, self.p_bins, self.m_bins])
-            fun_t = self.softmax(fun_t)
-            prob_t = self.softmax(prob_t)
-            mag_t = self.softmax(mag_t)
+            fun_t = self.softmax(fun_t * self.alpha)
+            prob_t = self.softmax(prob_t * self.alpha)
+            mag_t = self.softmax(mag_t * self.alpha)
             softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))
             softmaxed_vectors.append(softmaxed_vector)
-        print(softmaxed_vectors)
 
         new_policy = []
 
         for subpolicy_idx in range(self.sp_num):
@@ -94,16 +98,16 @@ class gru_learner(aa_learner):
             op2 = softmaxed_vectors[2*subpolicy_idx+1]
 
             # translate both vectors
-            op1 = self.translate_operation_tensor(op1)
-            op2 = self.translate_operation_tensor(op2)
+            op1, log_prob1 = self.translate_operation_tensor(op1, return_log_prob=True)
+            op2, log_prob2 = self.translate_operation_tensor(op2, return_log_prob=True)
 
-            print('new subpol:', (op1, op2))
             new_policy.append((op1,op2))
+            log_prob += (log_prob1+log_prob2)
 
-        return new_policy
+        return new_policy, log_prob
 
 
-    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
         '''
         Does the loop which is seen in Figure 1 in the AutoAugment paper.
         In other words, repeat:
@@ -111,16 +115,52 @@
             2. <see how good that policy is>
             3. <save how good the policy is in a list/dictionary>
         '''
-        # test out 15 random policies
-        for _ in range(15):
-            policy = self.generate_new_policy()
+        # optimizer for training the GRU controller
+        cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
+
+        # m is the minibatch size: the number of policies evaluated per controller update
+        b = 0.88 # b is the running exponential mean of the rewards, used for training stability
+                 # (see section 3.2 of https://arxiv.org/abs/1611.01578)
+
+        for _ in range(1000):
+            cont_optim.zero_grad()
+
+            # obj (the objective) is $\sum_{k=1}^m (reward_k-b) \sum_{t=1}^T log(P(a_t|a_{(t-1):1};\theta_c))$,
+            # i.e. the REINFORCE policy-gradient objective with baseline b
+            obj = 0
+
+            # sum up the rewards within a minibatch in order to update the running mean, 'b'
+            mb_rewards_sum = 0
+
+            for k in range(m):
+                # log_prob is $\sum_{t=1}^T log(P(a_t|a_{(t-1):1};\theta_c))$, used in the REINFORCE update
+                policy, log_prob = self.generate_new_policy()
 
-            pprint(policy)
-            child_network = child_network_architecture()
-            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
-                                                test_dataset, toy_flag)
+                pprint(policy)
+                child_network = child_network_architecture()
+                reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                    test_dataset, toy_flag)
+                mb_rewards_sum += reward
+
+                # log
+                self.history.append((policy, reward))
+
+                # gradient accumulation
+                obj += (reward-b)*log_prob
+
+            # update running mean of rewards
+            b = 0.7*b + 0.3*(mb_rewards_sum/m)
+
+            (-obj).backward() # We put a minus because we want to maximize the objective, not
+                              # minimize it.
+            cont_optim.step()
+
+            # save the history as a pickle after every controller update
+            with open('gru_logs.pkl', 'wb') as file:
+                pickle.dump(self.history, file)
+
 
-            self.history.append((policy, reward))
 
 
 if __name__=='__main__':
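
The new learn loop above is, in effect, REINFORCE with a moving-average baseline (section 3.2 of https://arxiv.org/abs/1611.01578): for each minibatch of m sampled policies it accumulates (reward - b) * log_prob and then takes one SGD step on the negated sum. The following is a self-contained toy sketch of that update shape; it is not part of the diff, the GRU controller is replaced by a single categorical distribution over four actions, and fake_reward is an invented stand-in for the child network's test accuracy.

```python
# Minimal REINFORCE-with-baseline sketch mirroring the structure of gru_learner.learn().
import torch

logits = torch.zeros(4, requires_grad=True)       # stands in for the GRU controller's parameters
cont_optim = torch.optim.SGD([logits], lr=1e-2)

def fake_reward(action: int) -> float:
    # pretend action 2 yields the best validation accuracy
    return 1.0 if action == 2 else 0.1

b, m = 0.5, 8                                     # running reward baseline, minibatch size
for _ in range(100):
    cont_optim.zero_grad()
    obj, mb_rewards_sum = 0.0, 0.0
    for _k in range(m):
        probs = torch.softmax(logits, dim=0)
        action = torch.multinomial(probs, 1).item()
        log_prob = torch.log(probs[action])       # one action here; the learner sums over timesteps
        reward = fake_reward(action)
        mb_rewards_sum += reward
        obj = obj + (reward - b) * log_prob       # gradient accumulation, as in the diff
    b = 0.7 * b + 0.3 * (mb_rewards_sum / m)      # exponential moving average of rewards
    (-obj).backward()                             # minimise the negative objective
    cont_optim.step()
```
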
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index e82f6aba18b14941f3066632f0343dd1df49f285..a5e971c13ef4ef490ed7c5b413949ff00b6e7c00 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -5,6 +5,8 @@ import MetaAugment.child_networks as cn
 from MetaAugment.autoaugment_learners.aa_learner import aa_learner
 
 from pprint import pprint
+import matplotlib.pyplot as plt
+import pickle
 
 
 
@@ -84,7 +86,7 @@ class randomsearch_learner(aa_learner):
         fun_p_m[random_fun] = 1
 
         fun_p_m[-2] = np.random.uniform() # 0<prob<1
-        fun_p_m[-1] = np.random.uniform() * (self.m_bins-1) # 0<mag<9
+        fun_p_m[-1] = np.random.uniform() * (self.m_bins-0.0000001) - 0.4999999 # -0.5<mag<9.5, so rounding gives an integer in [0, m_bins-1]
 
         return fun_p_m
 
@@ -129,7 +131,7 @@ class randomsearch_learner(aa_learner):
             3. <save how good the policy is in a list/dictionary>
         '''
-        # test out 15 random policies
-        for _ in range(15):
+        # test out 1500 random policies
+        for _ in range(1500):
             policy = self.generate_new_policy()
 
             pprint(policy)
@@ -139,9 +141,41 @@ class randomsearch_learner(aa_learner):
 
             self.history.append((policy, reward))
 
+            # save the history as a pickle every 10 policies
+            if _%10==1:
+                with open('randomsearch_logs.pkl', 'wb') as file:
+                    pickle.dump(self.history, file)
+
 
-if __name__=='__main__':
+    def demo_plot(self, train_dataset, test_dataset, child_network_architecture, toy_flag, n=50):
+        '''
+        I made this to plot a few accuracy graphs to help with manually tuning the
+        gradient optimizer's hyperparameters.
+        '''
+        acc_lists = []
+
+        # train n random policies and log the accuracy curve of each child network
+        for _ in range(n):
+            policy = self.generate_new_policy()
+
+            pprint(policy)
+            child_network = child_network_architecture()
+            reward, acc_list = self.test_autoaugment_policy(policy, child_network, train_dataset,
+                                                test_dataset, toy_flag, logging=True)
+
+            self.history.append((policy, reward))
+            acc_lists.append(acc_list)
+
+        for acc_list in acc_lists:
+            plt.plot(acc_list)
+        plt.title('I ran {} random policies to see if there is any sign of '
+                  'catastrophic failure during training'.format(n))
+        plt.savefig('random_policies')
+        plt.show()
+
+
+if __name__=='__main__':
     # We can initialize the train_dataset with its transform as None.
     # Later on, we will change this object's transform attribute to the policy
    # that we want to test
@@ -154,7 +188,7 @@ if __name__=='__main__':
                                         train=False, download=True,
                                         transform=torchvision.transforms.ToTensor())
     child_network = cn.lenet
-
-    rs_learner = randomsearch_learner(discrete_p_m=False)
+    rs_learner = randomsearch_learner(discrete_p_m=True)
     rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
+    # rs_learner.demo_plot(train_dataset, test_dataset, child_network, toy_flag=True)
     pprint(rs_learner.history)
\ No newline at end of file
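
The magnitude-sampling change in randomsearch_learner above draws from (-0.5, m_bins - 0.5) instead of [0, m_bins - 1], presumably so that the rounding step in translate_operation_tensor hits every magnitude bin with equal probability rather than giving the two endpoint bins half the mass. The following quick check illustrates that effect; it is not part of the diff, and m_bins=10 is assumed, matching the learners' default.

```python
# Compare the distribution of rounded magnitudes under the old and new sampling schemes.
import numpy as np

m_bins = 10
n = 100_000
old = np.round(np.random.uniform(size=n) * (m_bins - 1))
new = np.round(np.random.uniform(size=n) * (m_bins - 1e-7) - 0.4999999)

print(np.bincount(old.astype(int), minlength=m_bins) / n)  # ~0.056 at the two ends, ~0.111 inside
print(np.bincount(new.astype(int), minlength=m_bins) / n)  # ~0.1 in every bin
```
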
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index 5b0e04e47202272e25df81dabf84a39ed7050a1a..b39b4a21658c63d03e4dd6b1d251d4587546e633 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-torch.manual_seed(0)
 import torch.nn as nn
 import torch.optim as optim
 import torchvision
diff --git a/randomsearch_logs.pkl b/randomsearch_logs.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..b475be1198d4b25e8aa6f715e1f31d6945c210ff
Binary files /dev/null and b/randomsearch_logs.pkl differ
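
Both learners now periodically pickle self.history, a list of (policy, reward) tuples, to randomsearch_logs.pkl or gru_logs.pkl (the former is also checked in above as a binary file). The following is an assumed usage sketch of how such a log might be inspected afterwards; it is not part of the diff, and the file name and top-5 printout are illustrative.

```python
# Load a saved history pickle and print the best policies found so far.
import pickle

with open('randomsearch_logs.pkl', 'rb') as f:
    history = pickle.load(f)   # list of (policy, reward) tuples

# each policy is a list of subpolicies in AutoAugment format;
# reward is the child network's test accuracy on the toy dataset
for policy, reward in sorted(history, key=lambda pr: float(pr[1]), reverse=True)[:5]:
    print(float(reward), policy)
```
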