From 083be6414a4d4b4dfb7aaa97858cb701bd9d931c Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Wed, 16 Feb 2022 10:43:26 +0000
Subject: [PATCH] move all datasets to the datasets/mnist folder. NOTE: pass
 download=True to re-download the MNIST data
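
Because the MNIST root has moved, copies previously downloaded under
./MetaAugment/{train,test} will no longer be found. A minimal sketch of
re-downloading into the new layout (paths as used in this patch;
download=True is a no-op once the files are already present at root):

    from torchvision import datasets, transforms

    # fetch MNIST into the new dataset layout introduced by this patch
    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True,
                                   download=True, transform=transforms.ToTensor())
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False,
                                  download=True, transform=transforms.ToTensor())

The toy loaders are built by the create_toy helper defined elsewhere in
MetaAugment. Judging from its call sites, it is assumed to look roughly like
the sketch below (illustrative only; the real implementation may differ):

    import torch
    from torch.utils.data import DataLoader, Subset

    def create_toy(train_dataset, test_dataset, batch_size, n_samples):
        # keep a random n_samples fraction of each dataset
        n_train = int(n_samples * len(train_dataset))
        n_test = int(n_samples * len(test_dataset))
        train_idx = torch.randperm(len(train_dataset))[:n_train].tolist()
        test_idx = torch.randperm(len(test_dataset))[:n_test].tolist()
        train_loader = DataLoader(Subset(train_dataset, train_idx),
                                  batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(Subset(test_dataset, test_idx),
                                 batch_size=batch_size)
        return train_loader, test_loader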

---
 MetaAugment/CP2_Max.py                      |   8 ++++----
 MetaAugment/UCB1_JC.ipynb                   |   4 ++--
 MetaAugment/main.py                         |   7 ++++---
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index df325504..1ed7b470 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -90,8 +90,8 @@ def train_model(transform_idx, p):
     batch_size = 32
     n_samples = 0.005
 
-    train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=False, transform=transform_train)
-    test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
+    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
     # create toy dataset from above uploaded data
     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples)
@@ -142,8 +142,8 @@ def callback_generation(ga_instance):
 # ORGANISING DATA
 
 # transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'RandomRotation']
-train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
-test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
+train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
+test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
 n_samples = 0.02
 # shuffle, then take the first n_samples fraction of the training dataset
 shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
index ac88adec..d3bbda39 100644
--- a/MetaAugment/UCB1_JC.ipynb
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -144,8 +144,8 @@
         "            torchvision.transforms.ToTensor()])\n",
         "\n",
         "        # open data and apply these transformations\n",
-        "        train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "        test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
+        "        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)\n",
+        "        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
         "\n",
         "        # create toy dataset from above uploaded data\n",
         "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
diff --git a/MetaAugment/main.py b/MetaAugment/main.py
index c68f4b1a..1f2de939 100644
--- a/MetaAugment/main.py
+++ b/MetaAugment/main.py
@@ -82,7 +82,7 @@ class AA_Learner:
     def __init__(self, controller):
         self.controller = controller
 
-    def learn(self, dataset, child_network, toy_flag):
+    def learn(self, train_dataset, test_dataset, child_network, toy_flag):
         '''
         Does what is seen in Figure 1 of the AutoAugment paper.
 
@@ -94,9 +94,10 @@ class AA_Learner:
         while not good_policy_found:
             policy = self.controller.pop_policy()
 
-            train_loader, test_loader = prepare_dataset(dataset, policy, toy_flag)
+            train_loader, test_loader = create_toy(train_dataset, test_dataset,
+                                                    batch_size=32, n_samples=0.005)
 
-            reward = train_model(child_network, train_loader, test_loader, sgd, cost, epoch)
+            reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch)
 
             self.controller.update(reward, policy)
         
-- 
GitLab