From 9db06748640e7bec3509b510358328248c9d7fb6 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Mon, 18 Apr 2022 01:58:13 -0700
Subject: [PATCH] added fail_dataset.html that triggers if one of the files in
 the upload dataset doesnt follow the form class_{x}

---
 backend/auto_augmentation/progress.py            | 16 ++++++++++------
 .../templates/fail_dataset.html                  | 10 ++++++++++
 backend/progress.html                            |  8 --------
 3 files changed, 20 insertions(+), 14 deletions(-)
 create mode 100644 backend/auto_augmentation/templates/fail_dataset.html
 delete mode 100644 backend/progress.html

diff --git a/backend/auto_augmentation/progress.py b/backend/auto_augmentation/progress.py
index b1ea73e7..d63688d6 100644
--- a/backend/auto_augmentation/progress.py
+++ b/backend/auto_augmentation/progress.py
@@ -46,24 +46,28 @@ def response():
         num_sub_policies = 5  # fix number of sub-policies in a policy
         iterations = 5      # total iterations, should be more than the number of policies
         IsLeNet = request.form.get("network_selection")   # using LeNet or EasyNet or SimpleNet ->> default 
-        print("HERHERHERHEHR")
-        
+
         # if user upload datasets and networks, save them in the database
         if ds == 'Other':
             ds_folder = request.files['dataset_upload']
             ds_name_zip = ds_folder.filename
+            ds_name = ds_name_zip.split('.')[0]
             ds_folder.save('./MetaAugment/datasets/'+ ds_name_zip)
             with zipfile.ZipFile('./MetaAugment/datasets/'+ ds_name_zip, 'r') as zip_ref:
                 zip_ref.extractall('./MetaAugment/datasets/upload_dataset/')
-                print("zip_ref name: ", zip_ref.namelist())
-            ds_name = ds_name_zip.split('.')[0]
-            print("DATASET NAMe: ", ds_name_zip)
             os.remove(f'./MetaAugment/datasets/{ds_name_zip}')
 
         else: 
             ds_name = None
 
-        
+        for (dirpath, dirnames, filenames) in os.walk(f'./MetaAugment/datasets/upload_dataset/{ds_name}/'):
+            for dirname in dirnames:
+                if dirname[0:6] != 'class_':
+                    return render_template("fail_dataset.html")
+                else:
+                    pass
+
+
         if IsLeNet == 'Other':
             childnetwork = request.files['network_upload']
             childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename)
diff --git a/backend/auto_augmentation/templates/fail_dataset.html b/backend/auto_augmentation/templates/fail_dataset.html
new file mode 100644
index 00000000..53c67ea3
--- /dev/null
+++ b/backend/auto_augmentation/templates/fail_dataset.html
@@ -0,0 +1,10 @@
+<!doctype html>
+<html>
+
+<h1>Dataset failure!</h1>
+
+<body>
+    Your dataset format is not in accordance with PyTorch requirements. Please see  <a href="https://pytorch.org/vision/main/generated/torchvision.datasets.DatasetFolder.html#torchvision.datasets.DatasetFolder">here</a> for guidance. 
+
+</body>
+</html>
\ No newline at end of file
diff --git a/backend/progress.html b/backend/progress.html
deleted file mode 100644
index 9bad71aa..00000000
--- a/backend/progress.html
+++ /dev/null
@@ -1,8 +0,0 @@
-{% extends "structure.html" %}
-{% block title%}Home{% endblock %}
-{% block body %}
-<h1>Loading</h1>
-      <progress value = "65" max = "100"/>
-
-{% endblock %}
-
-- 
GitLab