From 2420ae8079e3ed4acf7324eb38b3828fe69c9538 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 8 Apr 2022 11:53:34 +0900
Subject: [PATCH] Round up gru_learner's prob and mag values of operations

---
 .../autoaugment_learners/aa_learner.py        | 26 +++++++++----------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 8d1e6430..3a6b3e4c 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -93,17 +93,17 @@ class aa_learner:
             assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'
 
             if argmax==True:
-                fun = torch.argmax(fun_t)
-                prob = torch.argmax(prob_t) # 0 <= p <= 10
-                mag = torch.argmax(mag_t) # 0 <= m <= 9
+                fun = torch.argmax(fun_t).item()
+                prob = torch.argmax(prob_t).item() # 0 <= p <= 10
+                mag = torch.argmax(mag_t).item() # 0 <= m <= 9
             elif argmax==False:
                 # we need these to add up to 1 to be valid pdf's of multinomials
                 assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                 assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                 assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)
-                fun = torch.multinomial(fun_t, 1) # 0 <= fun <= self.fun_num-1
-                prob = torch.multinomial(prob_t, 1) # 0 <= p <= 10
-                mag = torch.multinomial(mag_t, 1) # 0 <= m <= 9
+                fun = torch.multinomial(fun_t, 1).item() # 0 <= fun <= self.fun_num-1
+                prob = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10
+                mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
             function = augmentation_space[fun][0]
             prob = prob/10
@@ -111,9 +111,9 @@ class aa_learner:
 
         # if probability and magnitude are represented as continuous variables
         else:
-            fun_t, p, m = operation_tensor.split([self.fun_num, 1, 1])
-            p = operation_tensor[-2].item() # 0 < p < 1
-            m = operation_tensor[-1].item() # 0 < m < 9
+            fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
+            # 0 =< prob =< 1
+            # 0 =< mag =< 9
 
             # make sure the shape is correct
             assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
@@ -124,11 +124,9 @@ class aa_learner:
                 assert torch.sum(fun_t).isclose(torch.ones(1))
                 fun = torch.multinomial(fun_t, 1)
             
-            function = augmentation_space[fun][0]
-            prob = round(p, 1) # round to nearest first decimal digit
-            mag = round(m) # round to nearest integer
-            # If argmax is False, we treat operation_tensor as a concatenation of three
-            # multinomial pdf's.
+        function = augmentation_space[fun][0]
+        prob = round(prob, 1) # round to nearest first decimal digit
+        mag = round(mag) # round to nearest integer
 
         assert 0 <= prob <= 1
         assert 0 <= mag <= self.m_bins-1
-- 
GitLab