import os
import sys
import time
import zipfile
from dataclasses import dataclass

import torch
from flask import Flask, request, current_app, send_file
from numpy import save, load

# Make the parent directory importable *before* importing project-local
# packages such as temp_util; inserting it afterwards has no effect on them.
sys.path.insert(0, os.path.abspath('..'))

import temp_util.wapp_util as wapp_util

# The Flask application every route below registers itself on.
app = Flask(__name__)
def _form_value(form, key, cast, default):
    """Return ``cast(form[key])``, or ``default`` when the field is unset.

    A field counts as unset when it is absent, empty, or the literal string
    ``'undefined'`` (presumably what the JS frontend sends for blank advanced
    options). Raises ValueError if ``cast`` cannot convert the submitted value.
    """
    raw = form.get(key, 'undefined')
    if raw in ('undefined', ''):
        return default
    return cast(raw)


# Collect the user's search configuration from the '/home' form and store it
# in the app config so later requests ('/confirm', '/training') can read it.
@app.route('/home', methods=["GET", "POST"])
def get_form_data():
    """Parse the submitted configuration into ``current_app.config['data']``.

    Required form fields: ``select_dataset``, ``select_network``,
    ``select_learner`` and ``select_action``. The advanced fields
    ``batch_size`` (int), ``learning_rate`` (float), ``toy_size`` (float) and
    ``iterations`` (int) fall back to demonstration defaults when unset.

    When ``select_dataset == 'Other'``, the ``ds_upload`` zip file is saved to
    ``./datasets/`` and extracted into ``./datasets/upload_dataset/``. Every
    extracted sub-directory must be named ``class_*``. When
    ``select_network == 'Other'``, the ``network_upload`` file is saved under
    ``./child_networks/``.

    Returns:
        ``{'data': 'all stored'}`` on success, or ``({'error': ...}, 400)``
        when an advanced option is malformed or an upload is missing/invalid.
    """
    form_data = request.form
    print('@@@ this is form data', form_data)
    ds = form_data['select_dataset']  # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) or 'Other'
    IsLeNet = form_data["select_network"]  # using LeNet or EasyNet or SimpleNet ->> default, or 'Other'
    auto_aug_learner = form_data["select_learner"]  # which auto-augmentation learner to use
    print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:', auto_aug_learner)

    # Advanced input. The defaults are for demonstration purposes. Values are
    # converted to numbers so they have the same type as the defaults.
    try:
        batch_size = _form_value(form_data, 'batch_size', int, 1)  # size of batch the inner NN is trained with
        learning_rate = _form_value(form_data, 'learning_rate', float, 1e-1)  # fixed learning rate (was 10-1 == 9)
        toy_size = _form_value(form_data, 'toy_size', float, 1)  # proportion of training and test set we use
        iterations = _form_value(form_data, 'iterations', int, 10)  # should be more than the number of policies
    except ValueError as err:
        return {'error': f'invalid advanced option: {err}'}, 400
    exclude_method = form_data['select_action']  # augmentation methods to be excluded
    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method)

    # If the user uploaded a dataset, save, extract and validate it.
    ds_name_zip = None
    if ds == 'Other':
        ds_folder = request.files['ds_upload']
        print('!!!ds_folder', ds_folder)
        # basename() drops any directory components a client puts in the
        # filename, so the file cannot be written outside ./datasets/.
        ds_name_zip = os.path.basename(ds_folder.filename or '')
        if not ds_name_zip:
            return {'error': 'no dataset file uploaded'}, 400
        ds_name = ds_name_zip.split('.')[0]
        zip_path = os.path.join('./datasets', ds_name_zip)
        ds_folder.save(zip_path)
        # NOTE(review): the archive is untrusted user input; consider checking
        # its members/size before extraction (zip bombs).
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall('./datasets/upload_dataset/')
        except zipfile.BadZipFile:
            return {'error': 'uploaded dataset is not a valid zip file'}, 400
        if not current_app.debug:
            os.remove(zip_path)
        # Every directory inside the extracted dataset must be named 'class_*'.
        for _dirpath, dirnames, _filenames in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
            if any(not dirname.startswith('class_') for dirname in dirnames):
                return {'error': "dataset sub-folders must be named 'class_*'"}, 400

    # If the user uploaded a child network, save it.
    network_name = None
    if IsLeNet == 'Other':
        childnetwork = request.files['network_upload']
        network_name = os.path.basename(childnetwork.filename or '')
        if not network_name:
            return {'error': 'no network file uploaded'}, 400
        childnetwork.save(os.path.join('./child_networks', network_name))

    print("@@@ user input has all stored in the app")
    data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
            'auto_aug_learner': auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
            'toy_size': toy_size, 'iterations': iterations, 'exclude_method': exclude_method, }
    current_app.config['data'] = data
    print('@@@ all data sent', current_app.config['data'])
    return {'data': 'all stored'}
# ========================================================================
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
    """Return the configuration previously stored in ``current_app.config['data']``.

    If nothing has been stored yet, this responds with a 404 JSON error
    instead of raising KeyError, which Flask would turn into an HTTP 500.
    """
    print('inside confirm page')
    data = current_app.config.get('data')
    if data is None:
        return {'error': 'no configuration has been submitted yet'}, 404
    return data
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
def training():
    """Simulate training, then send ``./policy.txt`` to the client as a download.

    Training is currently faked: three "epochs" are printed, each after a
    3-second sleep. The stored form configuration is not used yet.

    Returns:
        The policy file as an attachment, or ``({'error': ...}, 404)`` if the
        file does not exist.
    """
    # TODO: run the real search with current_app.config['data'] and the
    # defaults previously hard-coded here (max_epochs=10, early_stop_num=10,
    # num_policies=5, num_sub_policies=5).
    print('pretend it is training')
    for epoch in range(1, 4):
        time.sleep(3)
        print(f'epoch: {epoch}')
    print('it has finished training')

    file_path = "./policy.txt"
    # send_file opens the file itself, so no separate (unclosed) open() is needed.
    if not os.path.isfile(file_path):
        return {'error': 'policy file not found'}, 404
    return send_file(file_path, as_attachment=True)
@app.route('/api')
def index():
    """Simple test endpoint that returns a fixed status payload."""
    payload = {'status': 'api test'}
    return payload
if __name__ == '__main__':