diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index e5ee9f5e93b5799bfe3fd61430d6e6e46c3e2d0f..977946bca32f8273de4c03104fd0f49f2a61f2a9 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -409,6 +409,7 @@ if __name__=='__main__':
     from MetaAugment.main import *
     import MetaAugment.child_networks as cn
     import torchvision.transforms as transforms
+    from torchvision.transforms import functional as F, InterpolationMode
 
     batch_size = 32
     n_samples = 0.005
@@ -435,7 +436,7 @@ if __name__=='__main__':
         aa_transform.policies = policies
 
         train_transform = transforms.Compose([
-                                                aa_transform(),
+                                                aa_transform,
                                                 transforms.ToTensor()
                                             ])
 
@@ -455,4 +456,4 @@ if __name__=='__main__':
     
 
     test_autoaugment_policy(policies1)
-    test_autoaugment_policy(policies2)
+    test_autoaugment_policy(policies2)
\ No newline at end of file