diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index bed50f1420f385161216c32295b9227137506e6e..716c285731c20f107fcbebabe78b2184bff3fea5 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -177,7 +177,7 @@ class aa_learner:
                 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9
 
             function = augmentation_space[fun_idx][0]
-            prob = prob_idx/10
+            prob = prob_idx/self.p_bins
 
             indices = (fun_idx, prob_idx, mag)
 
@@ -207,8 +207,8 @@ class aa_learner:
             
         function = augmentation_space[fun_idx][0]
 
-        assert 0 <= prob <= 1
-        assert 0 <= mag <= self.m_bins-1
+        assert 0 <= prob <= 1, prob
+        assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins)
         
         # if the image function does not require a magnitude, we set the magnitude to None
         if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude
@@ -335,6 +335,8 @@ class aa_learner:
         
         if isinstance(child_network_architecture, types.FunctionType):
             child_network = child_network_architecture()
+        elif isinstance(child_network_architecture, type):
+            child_network = child_network_architecture()
         elif isinstance(child_network_architecture, torch.nn.Module):
             child_network = copy.deepcopy(child_network_architecture)
         else: