diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 75f01f8c03202b0842af52165dff16666bed6e58..29b5d29a0ccb54d5525312e157cd320af735b96a 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -101,7 +101,7 @@ class aa_learner: self.history = [] self.augmentation_space = [x for x in augmentation_space if x not in exclude_method] self.fun_num = len(augmentation_space) - self.op_tensor_length = self.fun_num +p_bins+m_bins if discrete_p_m else self.fun_num +2 + self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2 def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):