diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index bed50f1420f385161216c32295b9227137506e6e..716c285731c20f107fcbebabe78b2184bff3fea5 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -177,7 +177,7 @@ class aa_learner: mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9 function = augmentation_space[fun_idx][0] - prob = prob_idx/10 + prob = prob_idx/self.p_bins indices = (fun_idx, prob_idx, mag) @@ -207,8 +207,8 @@ class aa_learner: function = augmentation_space[fun_idx][0] - assert 0 <= prob <= 1 - assert 0 <= mag <= self.m_bins-1 + assert 0 <= prob <= 1, prob + assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins) # if the image function does not require a magnitude, we set the magnitude to None if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude @@ -335,6 +335,8 @@ class aa_learner: if isinstance(child_network_architecture, types.FunctionType): child_network = child_network_architecture() + elif isinstance(child_network_architecture, type): + child_network = child_network_architecture() elif isinstance(child_network_architecture, torch.nn.Module): child_network = copy.deepcopy(child_network_architecture) else: