From d2827f861b7721cc5b9a445ab7e174bb81660ae1 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 22 Apr 2022 18:34:23 +0100
Subject: [PATCH] change save directory of cifar from fashinmnist to cifar

---
 MetaAugment/UCB1_JC_py.py     | 8 ++++----
 auto_augmentation/progress.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index 27322463..ad43c31e 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -144,11 +144,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
             train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
             test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform)
         elif ds == 'Other':
             dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform)
             len_train = int(0.8*len(dataset))
diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py
index cc260001..55760cc9 100644
--- a/auto_augmentation/progress.py
+++ b/auto_augmentation/progress.py
@@ -102,12 +102,12 @@ def response():
                 test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
                                                 download=download, transform=torchvision.transforms.ToTensor())
             elif ds == "CIFAR10":
-                train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
-                test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/fashionmnist/test', train=False,
+                train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
+                test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
                                                 download=download, transform=torchvision.transforms.ToTensor())
             elif ds == "CIFAR100":
-                train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
-                test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False,
+                train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
+                test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
                                                 download=download, transform=torchvision.transforms.ToTensor())
             elif ds == 'Other':
                 dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
-- 
GitLab