From 86e0fc8941d2683942fa1c5a59360a07bf9aa878 Mon Sep 17 00:00:00 2001
From: Mia Wang <yw21218@ic.ac.uk>
Date: Wed, 13 Apr 2022 23:03:21 +0100
Subject: [PATCH] connect ds and childnetwork selection to flask

---
 app.py                                |  7 ++---
 auto_augmentation/progress.py         | 44 ++++++++++++++++++++++++++-
 auto_augmentation/templates/home.html | 32 +++++++++++--------
 3 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/app.py b/app.py
index 35d7b504..e0f2a3ca 100644
--- a/app.py
+++ b/app.py
@@ -1,12 +1,9 @@
 from flask import Flask
 from auto_augmentation import create_app
 import os
+
 app = create_app()
 port = int(os.environ.get("PORT", 5000))
 
-
-# if __name__ == '__main__':
-#     app.run(host='0.0.0.0',port=port)
-
 if __name__ == '__main__':
-    app.run(debug=True)
\ No newline at end of file
+    app.run(host='0.0.0.0',port=port)
\ No newline at end of file
diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py
index b95acdc1..03d33fad 100644
--- a/auto_augmentation/progress.py
+++ b/auto_augmentation/progress.py
@@ -1,9 +1,51 @@
 from flask import Blueprint, request, render_template, flash, send_file
 import subprocess
 
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data as data_utils
+import torchvision
+import torchvision.datasets as datasets
+
+from matplotlib import pyplot as plt
+from numpy import save, load
+from tqdm import trange
+torch.manual_seed(0)
+# import agents and its functions
+from MetaAugment import UCB1_JC  
+
 bp = Blueprint("progress", __name__)
 
 @bp.route("/user_input", methods=["GET", "POST"])
 def response():
-    
+
+    # hyperparameters to change
+    batch_size = 32       # size of batch the inner NN is trained with
+    learning_rate = 1e-1  # fix learning rate
+    ds = request.args["dataset_selection"]        # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
+    toy_size = 0.02       # total propeortion of training and test set we use
+    max_epochs = 100      # max number of epochs that is run if early stopping is not hit
+    early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
+    num_policies = 5      # fix number of policies
+    num_sub_policies = 5  # fix number of sub-policies in a policy
+    iterations = 100      # total iterations, should be more than the number of policies
+    IsLeNet = request.args["network_selection"]   # using LeNet or EasyNet or SimpleNet ->> default 
+
+    print(f'@@@@@ dataset is: {ds}, network is :{IsLeNet}')
+
+    # generate random policies at start
+    policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
+
+    q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
+
+    plt.plot(best_q_values)
+
+    best_q_values = np.array(best_q_values)
+    save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
+    #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
+
+
     return render_template("progress.html")
\ No newline at end of file
diff --git a/auto_augmentation/templates/home.html b/auto_augmentation/templates/home.html
index a45d6065..7e4fc067 100644
--- a/auto_augmentation/templates/home.html
+++ b/auto_augmentation/templates/home.html
@@ -12,16 +12,24 @@
   <!-- dataset radio button -->
   Or you can select a dataset from our database: <br>
   <input type="radio" id="dataset1"
-    name="dataset_selection" value="MINIST">
-  <label for="dataset1">MINIST dataset</label><br>
+    name="dataset_selection" value="MNIST">
+  <label for="dataset1">MNIST dataset</label><br>
 
   <input type="radio" id="dataset2"
-    name="dataset_selection" value="IMGNET">
-  <label for="dataset2">IMGNET dataset</label><br>
+    name="dataset_selection" value="KMNIST">
+  <label for="dataset2">KMNIST dataset</label><br>
 
   <input type="radio" id="dataset3"
-  name="dataset_selection" value="dataset3">
-  <label for="dataset3">dataset3</label><br><br> 
+    name="dataset_selection" value="FashionMNIST">
+  <label for="dataset3">FashionMNIST dataset</label><br>
+
+  <input type="radio" id="dataset4"
+  name="dataset_selection" value="CIFAR10">
+  <label for="dataset4">CIFAR10 dataset</label><br>
+
+  <input type="radio" id="dataset5"
+  name="dataset_selection" value="CIFAR100">
+  <label for="dataset5">CIFAR100 dataset</label><br><br> 
 
 <!-- --------------------------------------------------------------- -->
 
@@ -35,16 +43,16 @@
   <!-- network selection -->
   Or you can select a dataset from our database: <br>
   <input type="radio" id="network1"
-    name="network_selection" value="EasyNet">
-  <label for="dataset1">EasyNet</label><br>
+    name="network_selection" value="LeNet">
+  <label for="network1">LeNet</label><br>
 
   <input type="radio" id="network2"
-    name="network_selection" value="LeNet">
-  <label for="dataset2">LeNet</label><br>
+  name="network_selection" value="EasyNet">
+  <label for="network2">EasyNet</label><br>
 
   <input type="radio" id="network3"
-  name="network_selection" value="AlexNet">
-  <label for="dataset3">AlexNet</label><br><br> 
+  name="network_selection" value="SimpleNet">
+  <label for="network3">SimpleNet</label><br><br> 
 
 
 
-- 
GitLab