diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py index e5ee9f5e93b5799bfe3fd61430d6e6e46c3e2d0f..977946bca32f8273de4c03104fd0f49f2a61f2a9 100644 --- a/MetaAugment/autoaugment_learners/autoaugment.py +++ b/MetaAugment/autoaugment_learners/autoaugment.py @@ -409,6 +409,7 @@ if __name__=='__main__': from MetaAugment.main import * import MetaAugment.child_networks as cn import torchvision.transforms as transforms + from torchvision.transforms import functional as F, InterpolationMode batch_size = 32 n_samples = 0.005 @@ -435,7 +436,7 @@ if __name__=='__main__': aa_transform.policies = policies train_transform = transforms.Compose([ - aa_transform(), + aa_transform, transforms.ToTensor() ]) @@ -455,4 +456,4 @@ if __name__=='__main__': test_autoaugment_policy(policies1) - test_autoaugment_policy(policies2) + test_autoaugment_policy(policies2) \ No newline at end of file