From d2827f861b7721cc5b9a445ab7e174bb81660ae1 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Fri, 22 Apr 2022 18:34:23 +0100 Subject: [PATCH] change save directory of cifar from fashinmnist to cifar --- MetaAugment/UCB1_JC_py.py | 8 ++++---- auto_augmentation/progress.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index 27322463..ad43c31e 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -144,11 +144,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) elif ds == "CIFAR10": - train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform) elif ds == "CIFAR100": - train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform) elif ds == 'Other': dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform) len_train = int(0.8*len(dataset)) diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py index cc260001..55760cc9 100644 --- a/auto_augmentation/progress.py +++ b/auto_augmentation/progress.py @@ -102,12 +102,12 @@ def response(): test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == "CIFAR10": - train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download) - test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, + train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download) + test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == "CIFAR100": - train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download) - test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, + train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download) + test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=download, transform=torchvision.transforms.ToTensor()) elif ds == 'Other': dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name) -- GitLab