From 86e0fc8941d2683942fa1c5a59360a07bf9aa878 Mon Sep 17 00:00:00 2001 From: Mia Wang <yw21218@ic.ac.uk> Date: Wed, 13 Apr 2022 23:03:21 +0100 Subject: [PATCH] connect ds and childnetwork selection to flask --- app.py | 7 ++--- auto_augmentation/progress.py | 44 ++++++++++++++++++++++++++- auto_augmentation/templates/home.html | 32 +++++++++++-------- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/app.py b/app.py index 35d7b504..e0f2a3ca 100644 --- a/app.py +++ b/app.py @@ -1,12 +1,9 @@ from flask import Flask from auto_augmentation import create_app import os + app = create_app() port = int(os.environ.get("PORT", 5000)) - -# if __name__ == '__main__': -# app.run(host='0.0.0.0',port=port) - if __name__ == '__main__': - app.run(debug=True) \ No newline at end of file + app.run(host='0.0.0.0',port=port) \ No newline at end of file diff --git a/auto_augmentation/progress.py b/auto_augmentation/progress.py index b95acdc1..03d33fad 100644 --- a/auto_augmentation/progress.py +++ b/auto_augmentation/progress.py @@ -1,9 +1,51 @@ from flask import Blueprint, request, render_template, flash, send_file import subprocess +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data as data_utils +import torchvision +import torchvision.datasets as datasets + +from matplotlib import pyplot as plt +from numpy import save, load +from tqdm import trange +torch.manual_seed(0) +# import agents and its functions +from MetaAugment import UCB1_JC + bp = Blueprint("progress", __name__) @bp.route("/user_input", methods=["GET", "POST"]) def response(): - + + # hyperparameters to change + batch_size = 32 # size of batch the inner NN is trained with + learning_rate = 1e-1 # fix learning rate + ds = request.args["dataset_selection"] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) + toy_size = 0.02 # total propeortion of training and test set we use + max_epochs = 100 # max number of epochs that is run if early stopping is not hit + early_stop_num = 10 # max number of worse validation scores before early stopping is triggered + num_policies = 5 # fix number of policies + num_sub_policies = 5 # fix number of sub-policies in a policy + iterations = 100 # total iterations, should be more than the number of policies + IsLeNet = request.args["network_selection"] # using LeNet or EasyNet or SimpleNet ->> default + + print(f'@@@@@ dataset is: {ds}, network is :{IsLeNet}') + + # generate random policies at start + policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) + + q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet) + + plt.plot(best_q_values) + + best_q_values = np.array(best_q_values) + save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) + #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True) + + return render_template("progress.html") \ No newline at end of file diff --git a/auto_augmentation/templates/home.html b/auto_augmentation/templates/home.html index a45d6065..7e4fc067 100644 --- a/auto_augmentation/templates/home.html +++ b/auto_augmentation/templates/home.html @@ -12,16 +12,24 @@ <!-- dataset radio button --> Or you can select a dataset from our database: <br> <input type="radio" id="dataset1" - name="dataset_selection" value="MINIST"> - <label for="dataset1">MINIST dataset</label><br> + name="dataset_selection" value="MNIST"> + <label for="dataset1">MNIST dataset</label><br> <input type="radio" id="dataset2" - name="dataset_selection" value="IMGNET"> - <label for="dataset2">IMGNET dataset</label><br> + name="dataset_selection" value="KMNIST"> + <label for="dataset2">KMNIST dataset</label><br> <input type="radio" id="dataset3" - name="dataset_selection" value="dataset3"> - <label for="dataset3">dataset3</label><br><br> + name="dataset_selection" value="FashionMNIST"> + <label for="dataset3">FashionMNIST dataset</label><br> + + <input type="radio" id="dataset4" + name="dataset_selection" value="CIFAR10"> + <label for="dataset4">CIFAR10 dataset</label><br> + + <input type="radio" id="dataset5" + name="dataset_selection" value="CIFAR100"> + <label for="dataset5">CIFAR100 dataset</label><br><br> <!-- --------------------------------------------------------------- --> @@ -35,16 +43,16 @@ <!-- network selection --> Or you can select a dataset from our database: <br> <input type="radio" id="network1" - name="network_selection" value="EasyNet"> - <label for="dataset1">EasyNet</label><br> + name="network_selection" value="LeNet"> + <label for="network1">LeNet</label><br> <input type="radio" id="network2" - name="network_selection" value="LeNet"> - <label for="dataset2">LeNet</label><br> + name="network_selection" value="EasyNet"> + <label for="network2">EasyNet</label><br> <input type="radio" id="network3" - name="network_selection" value="AlexNet"> - <label for="dataset3">AlexNet</label><br><br> + name="network_selection" value="SimpleNet"> + <label for="network3">SimpleNet</label><br><br> -- GitLab