Commit 76afa805 authored by Sun Jin Kim

update aa_learners

parent 30e5131c
@@ -177,7 +177,7 @@ class aa_learner:
             mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
             function = augmentation_space[fun_idx][0]
-            prob = prob_idx/self.p_bins
+            prob = prob_idx/(self.p_bins-1)
 
             indices = (fun_idx, prob_idx, mag)
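The new divisor removes an off-by-one in de-discretising the probability: prob_idx is drawn from {0, ..., p_bins-1}, so dividing by p_bins-1 spreads the bins evenly over [0.0, 1.0], while the old divisor p_bins could never produce a probability of exactly 1. A minimal sketch of the two mappings, assuming a bin count of p_bins=11 for illustration:

    p_bins = 11  # assumed bin count; prob_idx is sampled from {0, ..., p_bins-1}

    old_probs = [i / p_bins for i in range(p_bins)]        # tops out at 10/11 ~= 0.909
    new_probs = [i / (p_bins - 1) for i in range(p_bins)]  # 0.0, 0.1, ..., 1.0

    assert max(old_probs) < 1.0
    assert new_probs[0] == 0.0 and new_probs[-1] == 1.0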
@@ -416,15 +416,6 @@ if __name__=='__main__':
     import MetaAugment.child_networks as cn
     import torchvision.transforms as transforms
 
-    # If you get rid of this next import, the whole thing doesn't work... By the way,
-    # this import also exists at the top of this file.
-    # I think this is because "import torchvision.transforms as transforms" overrides
-    # the import at the top of this file and does some funny stuff... Anyway, we need
-    # to call this import again to get rid of the bug.
-    from torchvision.transforms import functional as F, InterpolationMode
 
     subpolicies1 = [
             (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
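The deleted comment block guessed at an import-shadowing bug. For context, a Python import is just a name binding in the current namespace, so a later import that reuses an alias silently replaces the earlier one; a generic, hypothetical illustration of that mechanism (not necessarily what happened in this file):

    from torchvision.transforms import functional as F  # F -> transforms.functional
    import torch.nn.functional as F                      # F silently rebound elsewhere

    # a later call such as F.rotate(img, 30) would now fail, because
    # torch.nn.functional has no rotate(); repeating the first import,
    # as the deleted code did, restores the expected binding
    from torchvision.transforms import functional as F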
@@ -60,7 +60,8 @@ class gru_learner(aa_learner):
                  early_stop_num=20,
                  # GRU-specific attributes that aren't in other aa_learners
                  alpha=0.2,
-                 cont_mb_size=8):
+                 cont_mb_size=4,
+                 cont_lr=0.1):
         """
         Args:
             alpha (float, optional): Exploration parameter. It is multiplied to
@@ -92,11 +93,13 @@ class gru_learner(aa_learner):
         # GRU-specific attributes that aren't in the general aa_learner
         self.alpha = alpha
         self.cont_mb_size = cont_mb_size
+        self.b = 0.5 # b is the running exponential mean of the rewards, used for
+                     # training stability (see section 3.2 of https://arxiv.org/abs/1611.01578)
 
         # CONTROLLER (GRU NETWORK) SETTINGS
         self.controller = RNNModel(mode='GRU', output_size=self.op_tensor_length,
                                    num_layers=2, bias=True)
-        self.cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)
+        self.cont_optim = torch.optim.SGD(self.controller.parameters(), lr=cont_lr)
 
         self.softmax = torch.nn.Softmax(dim=0)
@@ -181,9 +184,6 @@ class gru_learner(aa_learner):
               child_network_architecture,
               iterations=15,):
 
-        b = 0.5 # b is the running exponential mean of the rewards, used for
-                # training stability (see section 3.2 of https://arxiv.org/abs/1611.01578)
 
         for _ in range(iterations):
             self.cont_optim.zero_grad()
@@ -209,20 +209,14 @@ class gru_learner(aa_learner):
                 self.history.append((policy, reward))
 
                 # gradient accumulation
-                obj += (reward-b)*log_prob
+                obj += (reward-self.b)*log_prob
 
             # update running mean of rewards
-            b = 0.7*b + 0.3*(mb_rewards_sum/self.cont_mb_size)
+            self.b = 0.7*self.b + 0.3*(mb_rewards_sum/self.cont_mb_size)
 
             (-obj).backward() # We put a minus because we want to maximize the
                               # objective, not minimize it.
             self.cont_optim.step()
 
-            # # save the history every epoch as a pickle
-            # with open('gru_logs.pkl', 'wb') as file:
-            #     pickle.dump(self.history, file)
-            # with open('gru_learner.pkl', 'wb') as file:
-            #     pickle.dump(self, file)
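Promoting b from a local variable in learn() to the attribute self.b, initialised once in the constructor, lets the reward baseline persist across calls to learn() instead of resetting to 0.5 every time. The surrounding loop is REINFORCE with an exponential-moving-average baseline (section 3.2 of https://arxiv.org/abs/1611.01578). A self-contained sketch of the pattern, where sample_policy and evaluate are hypothetical stand-ins for sampling from the controller and training a child network:

    import torch

    def reinforce_step(controller, optimizer, sample_policy, evaluate,
                       b, mb_size=4, decay=0.7):
        # One controller update with an exponential-moving-average baseline.
        # sample_policy(controller) -> (policy, log_prob tensor)  [hypothetical]
        # evaluate(policy) -> float reward                        [hypothetical]
        optimizer.zero_grad()
        obj, mb_rewards_sum = 0.0, 0.0
        for _ in range(mb_size):
            policy, log_prob = sample_policy(controller)
            reward = evaluate(policy)
            mb_rewards_sum += reward
            # subtracting the baseline reduces the variance of the gradient
            # estimate without biasing it
            obj = obj + (reward - b) * log_prob
        # update the running mean of rewards, as in the diff above
        b = decay * b + (1 - decay) * (mb_rewards_sum / mb_size)
        (-obj).backward()  # negate because we maximize the objective, not minimize it
        optimizer.step()
        return b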
@@ -253,8 +247,8 @@ if __name__=='__main__':
                         sp_num=7,
                         toy_flag=True,
                         toy_size=0.01,
-                        batch_size=4,
-                        learning_rate=0.05,
+                        batch_size=32,
+                        learning_rate=0.1,
                         max_epochs=float('inf'),
                         early_stop_num=35,
                         )
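Taken together, the commit halves the controller minibatch (cont_mb_size 8 -> 4), exposes the controller's SGD learning rate as cont_lr, and raises the child network's batch size and learning rate in the __main__ demo. A hedged sketch of how the updated constructor would be called (the import path is assumed from the package name used elsewhere in the file; exact signatures depend on the rest of the repository):

    from MetaAugment.autoaugment_learners import gru_learner  # path assumed

    agent = gru_learner(
            sp_num=7,
            toy_flag=True,
            toy_size=0.01,
            batch_size=32,
            learning_rate=0.1,
            max_epochs=float('inf'),
            early_stop_num=35,
            alpha=0.2,          # exploration parameter (unchanged default)
            cont_mb_size=4,     # new default: 4 policies per controller update
            cont_lr=0.1,        # new argument: SGD learning rate of the GRU controller
            )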