Skip to content
Snippets Groups Projects
Commit 2420ae80 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Round up gru_learner's prob and mag values of operations

parent 1d782627
No related branches found
No related tags found
No related merge requests found
......@@ -93,17 +93,17 @@ class aa_learner:
assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
if argmax==True:
fun = torch.argmax(fun_t)
prob = torch.argmax(prob_t) # 0 <= p <= 10
mag = torch.argmax(mag_t) # 0 <= m <= 9
fun = torch.argmax(fun_t).item()
prob = torch.argmax(prob_t).item() # 0 <= p <= 10
mag = torch.argmax(mag_t).item() # 0 <= m <= 9
elif argmax==False:
# we need these to add up to 1 to be valid pdf's of multinomials
assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
fun = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
prob = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
function = augmentation_space[fun][0]
prob = prob/10
......@@ -111,9 +111,9 @@ class aa_learner:
# if probability and magnitude are represented as continuous variables
else:
fun_t, p, m = operation_tensor.split([self.fun_num, 1, 1])
p = operation_tensor[-2].item() # 0 < p < 1
m = operation_tensor[-1].item() # 0 < m < 9
fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
# 0 =< prob =< 1
# 0 =< mag =< 9
# make sure the shape is correct
assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
......@@ -124,11 +124,9 @@ class aa_learner:
assert torch.sum(fun_t).isclose(torch.ones(1))
fun = torch.multinomial(fun_t, 1)
function = augmentation_space[fun][0]
prob = round(p, 1) # round to nearest first decimal digit
mag = round(m) # round to nearest integer
# If argmax is False, we treat operation_tensor as a concatenation of three
# multinomial pdf's.
function = augmentation_space[fun][0]
prob = round(prob, 1) # round to nearest first decimal digit
mag = round(mag) # round to nearest integer
assert 0 <= prob <= 1
assert 0 <= mag <= self.m_bins-1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment