diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index 27322463dc6b679751aac4b2483f0e12f4352cdb..ad43c31e58f288de1b2280391e86c7e506cec2e2 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -144,11 +144,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) elif ds == "CIFAR10": - train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform) elif ds == "CIFAR100": - train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform) elif ds == 'Other': dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform) len_train = int(0.8*len(dataset)) diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py index cc260001d96319a309ba719c540f6c075ec934fd..55760cc941d8f525c934066864fd92fc09d00274 100644 --- a/auto_augmentation/progress.py +++ b/auto_augmentation/progress.py @@ -102,12 +102,12 @@ def response(): test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == "CIFAR10": - train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download) - test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, + train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download) + test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == "CIFAR100": - train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download) - test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, + train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download) + test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == 'Other': dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)