Skip to content
Snippets Groups Projects
Commit 86e0fc89 authored by Mia Wang's avatar Mia Wang
Browse files

connect ds and childnetwork selection to flask

parent c0d18959
No related branches found
No related tags found
No related merge requests found
from flask import Flask
from auto_augmentation import create_app
import os
app = create_app()
port = int(os.environ.get("PORT", 5000))
# if __name__ == '__main__':
# app.run(host='0.0.0.0',port=port)
if __name__ == '__main__':
app.run(debug=True)
\ No newline at end of file
app.run(host='0.0.0.0',port=port)
\ No newline at end of file
from flask import Blueprint, request, render_template, flash, send_file
import subprocess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
torch.manual_seed(0)
# import agents and its functions
from MetaAugment import UCB1_JC
bp = Blueprint("progress", __name__)
@bp.route("/user_input", methods=["GET", "POST"])
def response():
# hyperparameters to change
batch_size = 32 # size of batch the inner NN is trained with
learning_rate = 1e-1 # fix learning rate
ds = request.args["dataset_selection"] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
toy_size = 0.02 # total propeortion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
iterations = 100 # total iterations, should be more than the number of policies
IsLeNet = request.args["network_selection"] # using LeNet or EasyNet or SimpleNet ->> default
print(f'@@@@@ dataset is: {ds}, network is :{IsLeNet}')
# generate random policies at start
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
plt.plot(best_q_values)
best_q_values = np.array(best_q_values)
save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
return render_template("progress.html")
\ No newline at end of file
......@@ -12,16 +12,24 @@
<!-- dataset radio button -->
Or you can select a dataset from our database: <br>
<input type="radio" id="dataset1"
name="dataset_selection" value="MINIST">
<label for="dataset1">MINIST dataset</label><br>
name="dataset_selection" value="MNIST">
<label for="dataset1">MNIST dataset</label><br>
<input type="radio" id="dataset2"
name="dataset_selection" value="IMGNET">
<label for="dataset2">IMGNET dataset</label><br>
name="dataset_selection" value="KMNIST">
<label for="dataset2">KMNIST dataset</label><br>
<input type="radio" id="dataset3"
name="dataset_selection" value="dataset3">
<label for="dataset3">dataset3</label><br><br>
name="dataset_selection" value="FashionMNIST">
<label for="dataset3">FashionMNIST dataset</label><br>
<input type="radio" id="dataset4"
name="dataset_selection" value="CIFAR10">
<label for="dataset4">CIFAR10 dataset</label><br>
<input type="radio" id="dataset5"
name="dataset_selection" value="CIFAR100">
<label for="dataset5">CIFAR100 dataset</label><br><br>
<!-- --------------------------------------------------------------- -->
......@@ -35,16 +43,16 @@
<!-- network selection -->
Or you can select a dataset from our database: <br>
<input type="radio" id="network1"
name="network_selection" value="EasyNet">
<label for="dataset1">EasyNet</label><br>
name="network_selection" value="LeNet">
<label for="network1">LeNet</label><br>
<input type="radio" id="network2"
name="network_selection" value="LeNet">
<label for="dataset2">LeNet</label><br>
name="network_selection" value="EasyNet">
<label for="network2">EasyNet</label><br>
<input type="radio" id="network3"
name="network_selection" value="AlexNet">
<label for="dataset3">AlexNet</label><br><br>
name="network_selection" value="SimpleNet">
<label for="network3">SimpleNet</label><br><br>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment