Commit d6f15d90 authored by Sun Jin Kim

cleanup

parent df39cc2d
@@ -217,7 +217,8 @@ class aa_learner:
         which is:
             1. <generate a random policy>
             2. <see how good that policy is>
-            3. <save how good the policy is in a list/dictionary>
+            3. <save how good the policy is in a list/dictionary and
+               (if applicable) update the controller (e.g. RL agent)>

         Args:
             train_dataset (torchvision.dataset.vision.VisionDataset)
@@ -234,16 +235,17 @@ class aa_learner:
         """
         # This is dummy code
        # test out 15 random policies
-        for _ in range(15):
-            policy = self.generate_new_policy()
+        # for _ in range(15):
+        #     policy = self.generate_new_policy()

-            pprint(policy)
-            child_network = child_network_architecture()
-            reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
-                                                  test_dataset, toy_flag)
+        #     pprint(policy)
+        #     child_network = child_network_architecture()
+        #     reward = self.test_autoaugment_policy(policy, child_network, train_dataset,
+        #                                           test_dataset, toy_flag)

-            self.history.append((policy, reward))
+        #     self.history.append((policy, reward))

     def test_autoaugment_policy(self, policy, child_network, train_dataset, test_dataset,
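The commit comments out the dummy random-search loop in aa_learner.learn. For reference, a minimal sketch of the three docstring steps (generate, evaluate, record) written against the methods visible in this diff (generate_new_policy, test_autoaugment_policy, self.history) could look as follows; this is a sketch reconstructed from the removed lines, not the repository's exact code.

    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
        # Sketch: random search over 15 candidate policies.
        for _ in range(15):
            policy = self.generate_new_policy()                 # 1. generate a random policy
            child_network = child_network_architecture()
            reward = self.test_autoaugment_policy(policy, child_network,   # 2. evaluate it
                                                  train_dataset, test_dataset, toy_flag)
            self.history.append((policy, reward))               # 3. record it; a subclass with a
                                                                #    controller would also update it here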
@@ -71,7 +71,8 @@ class gru_learner(aa_learner):
         contains information regarding which 'image function' to use,
         which value of 'probability(prob)' and 'magnitude(mag)' to use.
-        We run the GRU for 10 timesteps to obtain 10 of such tensors.
+        We run the GRU for 2*self.sp_num timesteps to obtain 2*self.sp_num
+        of such tensors.
         We then softmax the parts of the tensor which represent the
         choice of function, prob, and mag separately, so that the
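The docstring describes one GRU output per timestep, with the function, prob, and mag parts softmaxed separately. A hypothetical illustration of that slicing for a single timestep is sketched below; the slice sizes fun_num, p_bins, and m_bins are assumed names and values, not taken from the diff.

    import torch
    import torch.nn.functional as F

    # Assumed discretisation sizes for illustration only.
    fun_num, p_bins, m_bins = 14, 11, 10
    raw = torch.randn(fun_num + p_bins + m_bins)   # one raw GRU output for one timestep

    fun_probs = F.softmax(raw[:fun_num], dim=0)                   # which 'image function' to use
    prob_probs = F.softmax(raw[fun_num:fun_num + p_bins], dim=0)  # which 'prob' value to use
    mag_probs = F.softmax(raw[fun_num + p_bins:], dim=0)          # which 'mag' value to use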
@@ -135,11 +136,11 @@ class gru_learner(aa_learner):
         return new_policy, log_prob

-    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, m=8):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag, mb_size=8):
         # optimizer for training the GRU controller
         cont_optim = torch.optim.SGD(self.controller.parameters(), lr=1e-2)

-        m = 8 # minibatch size
+        mb_size = 8 # minibatch size
         b = 0.88 # b is the running exponential mean of the rewards, used for training stability
                  # (see section 3.2 of https://arxiv.org/abs/1611.01578)
@@ -153,7 +154,7 @@ class gru_learner(aa_learner):
             # sum up the rewards within a minibatch in order to update the running mean, 'b'
             mb_rewards_sum = 0

-            for k in range(m):
+            for k in range(mb_size):
                 # log_prob is $\sum_{t=1}^T log(P(a_t|a_{(t-1):1};\theta_c))$, used in PPO
                 policy, log_prob = self.generate_new_policy()
@@ -170,7 +171,7 @@ class gru_learner(aa_learner):
                 obj += (reward-b)*log_prob

             # update running mean of rewards
-            b = 0.7*b + 0.3*(mb_rewards_sum/m)
+            b = 0.7*b + 0.3*(mb_rewards_sum/mb_size)

             (-obj).backward() # We put a minus because we want to maximize the objective, not
                               # minimize it.
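Taken together, these hunks show the renamed mb_size threading through a policy-gradient update with a running reward baseline b. A hedged sketch of one such controller update, assembled from the pieces visible above, is given below; controller_update_step is a hypothetical helper name, not a method of the repository.

    def controller_update_step(self, child_network_architecture, train_dataset,
                               test_dataset, toy_flag, cont_optim, b, mb_size=8):
        # Sketch of one gru_learner controller update, reconstructed from the hunks above.
        cont_optim.zero_grad()
        obj = 0
        mb_rewards_sum = 0
        for k in range(mb_size):
            # log_prob is the sum over timesteps of log P(a_t | a_{1:t-1}; theta_c)
            policy, log_prob = self.generate_new_policy()
            child_network = child_network_architecture()
            reward = self.test_autoaugment_policy(policy, child_network,
                                                  train_dataset, test_dataset, toy_flag)
            mb_rewards_sum += reward
            obj += (reward - b) * log_prob                 # baseline-subtracted policy-gradient term
        b = 0.7 * b + 0.3 * (mb_rewards_sum / mb_size)     # running mean of rewards (see sec. 3.2 of
                                                           # https://arxiv.org/abs/1611.01578)
        (-obj).backward()                                  # minimize -obj, i.e. maximize obj
        cont_optim.step()
        return b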