From 2420ae8079e3ed4acf7324eb38b3828fe69c9538 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Fri, 8 Apr 2022 11:53:34 +0900 Subject: [PATCH] Round up gru_learner's prob and mag values of operations --- .../autoaugment_learners/aa_learner.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 8d1e6430..3a6b3e4c 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -93,17 +93,17 @@ class aa_learner: assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}' if argmax==True: - fun = torch.argmax(fun_t) - prob = torch.argmax(prob_t) # 0 <= p <= 10 - mag = torch.argmax(mag_t) # 0 <= m <= 9 + fun = torch.argmax(fun_t).item() + prob = torch.argmax(prob_t).item() # 0 <= p <= 10 + mag = torch.argmax(mag_t).item() # 0 <= m <= 9 elif argmax==False: # we need these to add up to 1 to be valid pdf's of multinomials assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t) assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t) assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t) - fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1 - prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10 - mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9 + fun = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1 + prob = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10 + mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9 function = augmentation_space[fun][0] prob = prob/10 @@ -111,9 +111,9 @@ class aa_learner: # if probability and magnitude are represented as continuous variables else: - fun_t, p, m = operation_tensor.split([self.fun_num, 1, 1]) - p = operation_tensor[-2].item() # 0 < p < 1 - m = operation_tensor[-1].item() # 0 < m < 9 + fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1]) + # 0 =< prob =< 1 + # 0 =< mag =< 9 # make sure the shape is correct assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}' @@ -124,11 +124,9 @@ class aa_learner: assert torch.sum(fun_t).isclose(torch.ones(1)) fun = torch.multinomial(fun_t, 1) - function = augmentation_space[fun][0] - prob = round(p, 1) # round to nearest first decimal digit - mag = round(m) # round to nearest integer - # If argmax is False, we treat operation_tensor as a concatenation of three - # multinomial pdf's. + function = augmentation_space[fun][0] + prob = round(prob, 1) # round to nearest first decimal digit + mag = round(mag) # round to nearest integer assert 0 <= prob <= 1 assert 0 <= mag <= self.m_bins-1 -- GitLab