Newer
Older
from flask import Flask, request, current_app, send_file, send_from_directory, redirect, url_for, session
from flask_cors import CORS, cross_origin
import os
import zipfile
import torch
from numpy import int0, save, load
from react_backend.wapp_util import parse_users_learner_spec
import pprint
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
# Seed torch's global RNG once, at import time, so every run starts from the same state.
torch.manual_seed(0)
print('@@@ import successful')

# The Flask application object that every route below registers on.
# To have Flask also serve the built React frontend, construct it with
# static_folder='react_frontend/build', static_url_path='/'.
app = Flask(__name__)

# Enable flask-cors on every route (its default) so that a frontend served
# from a different origin can call this API.
CORS(app)
def _store_error(error_type, error):
    """Record a rejected submission in the app config and return it.

    The error replaces ``current_app.config['data']``, which is what ``/home``
    (GET) and ``/training`` read back.
    """
    data = {'error_type': error_type, 'error': error}
    current_app.config['data'] = data
    return data


def _form_value(form_data, key, cast, default):
    """Return ``cast(form_data[key])``, or *default* when the field was left blank.

    The form submits the string 'undefined' (or '') for fields the user did
    not fill in.
    """
    raw = form_data[key]
    return default if raw in ('undefined', '') else cast(raw)


def _dataset_folder_is_valid(root):
    """Return True if *root* exists and contains at least one sub-folder.

    Presumably each sub-folder holds one class of images -- TODO confirm
    against the dataset loader. Only the top level is inspected.
    """
    top = next(os.walk(root), None)  # None when root is missing / not a directory
    return top is not None and bool(top[1])


# it is used to collect user input and store them in the app
@app.route('/home', methods=["GET", "POST"])
def home():
    """Collect the user's search settings from the submission form.

    POST: read the form (plus any uploaded dataset zip / network pickle),
    validate the uploads and store the settings -- or an error description --
    in ``current_app.config['data']`` for ``/training`` to pick up. The slot
    is app-wide, so concurrent users overwrite each other's settings.

    GET: return whatever the last POST stored, or a 'no data' error.

    Returns:
        dict: the stored settings, or a dict with ``error_type`` and ``error``
        keys explaining why the submission was rejected (Flask sends it as JSON).
    """
    if request.method != 'POST':
        print('it is GET method')
        if 'data' in current_app.config:
            return current_app.config['data']
        return {'error': "We didn't received any data from you submission form. Please go back to the home page",
                'error_type': 'no data'}

    print('@@@ in Flask Home')
    form_data = request.form
    print('@@@ this is form data', form_data)

    # required input
    ds = form_data['select_dataset']  # dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) or 'Other' for an upload
    IsLeNet = form_data["select_network"]  # LeNet, EasyNet or SimpleNet (default), or 'Other' for an upload
    auto_aug_learner = form_data["select_learner"]  # which auto-augmentation learner to run
    print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:', auto_aug_learner)

    # advanced input: blank fields fall back to small, demo-friendly defaults
    batch_size = _form_value(form_data, 'batch_size', int, 16)  # size of batch the inner NN is trained with
    learning_rate = _form_value(form_data, 'learning_rate', float, 1e-2)
    toy_size = _form_value(form_data, 'toy_size', float, 0.01)  # proportion of the train/test sets actually used
    # total iterations; should exceed the number of policies (/training uses 5).
    # NOTE(review): no default for this field was visible -- 10 is a placeholder, confirm.
    iterations = _form_value(form_data, 'iterations', int, 10)
    exclude_method = form_data['select_action']  # augmentation methods to be excluded
    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size,
          'iterations:', iterations, 'exclude_method', exclude_method)

    # if the user uploaded a dataset, unpack it and check its layout
    ds_name_zip = None
    if ds == 'Other':
        ds_folder = request.files['ds_upload']
        # basename() drops any directory parts a client puts into the filename
        ds_name_zip = os.path.basename(ds_folder.filename)
        ds_name, extension = os.path.splitext(ds_name_zip)
        if extension.lower() != '.zip':
            return _store_error('not a zip file', "We found that your uplaoded dataset is not a zip file...")
        zip_path = f'./react_backend/datasets/{ds_name_zip}'
        ds_folder.save(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                # extractall() strips absolute paths and '..' from member names,
                # but the archive content itself is still untrusted
                zip_ref.extractall('./react_backend/datasets/upload_dataset/')
        except zipfile.BadZipFile:
            return _store_error('not a zip file', "We found that your uplaoded dataset is not a zip file...")
        finally:
            os.remove(zip_path)
        # the archive must unpack into a folder named after it that contains sub-folders
        if not _dataset_folder_is_valid(f'./react_backend/datasets/upload_dataset/{ds_name}/'):
            return _store_error('incorret dataset',
                                "We found that your uplaoded dataset doesn't have the correct format that we are looking for.")
        print('@@@ correct dataset folder!')

    # save the user uploaded network
    network_name = None
    if IsLeNet == 'Other':
        childnetwork = request.files['network_upload']
        network_name = os.path.basename(childnetwork.filename)
        if os.path.splitext(network_name)[1].lower() != '.pkl':
            # return now so the success data below cannot overwrite this error
            return _store_error('incorrect network', "We found that your uploaded network is not a pickle file")
        # NOTE(review): if this pickle is loaded later, an uploaded file can execute arbitrary code
        childnetwork.save(f'./child_networks/{network_name}')

    print("@@@ user input has all stored in the app")
    data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
            'auto_aug_learner': auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
            'toy_size': toy_size, 'iterations': iterations, 'exclude_method': exclude_method, }
    current_app.config['data'] = data
    print('@@@ all data sent', current_app.config['data'])
    return data
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
def training():
# Run the auto-augmentation search with the settings that /home stored in
# current_app.config['data'].
# NOTE(review): agent.learn() runs inline, so the HTTP request stays open
# for the whole search.
# default values
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
data = current_app.config.get('data')
# NOTE(review): get() returns None if /home was never POSTed, and `**data`
# below then raises TypeError. If /home stored an error dict instead, its
# 'error_type'/'error' keys are forwarded as keyword arguments as well --
# presumably rejected by parse_users_learner_spec; confirm.
# parse the settings given by the user to obtain tools we need
train_dataset, test_dataset, child_archi, agent = parse_users_learner_spec(
max_epochs=max_epochs,
early_stop_num=early_stop_num,
num_policies=num_policies,
num_sub_policies=num_sub_policies,
**data
)
# train the autoaugment learner for number of `iterations`
agent.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_archi,
iterations=data['iterations']
# NOTE(review): this copy of the file is damaged. The call above has no
# closing parenthesis, and the bare numbers below are not code: they look
# like leftovers from a folded diff view (original lines 175-206 are
# missing). Restore them from version control before running.
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
print('the history of all the policies the agent has tested:')
pprint.pprint(agent.history)
# get acc graph and best acc graph
acc_list = [acc for (policy,acc) in agent.history]
best_acc_list = []
best_til_now = 0
for acc in acc_list:
if acc>best_til_now:
best_til_now=acc
best_acc_list.append(best_til_now)
# plot both here
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(acc_list)
ax.plot(best_acc_list)
ax.set_xlabel('Number of Iterations')
ax.set_ylabel('Accuracy')
ax.set_title('Auto-augmentation Learner Performance Curve')
with open("./react_frontend/src/pages/output.png", 'wb') as f:
fig.savefig(f)
print("best policies:")
best_policy = agent.get_mega_policy(number_policies=4)
print(best_policy)
with open("./react_backend/policy.txt", 'w') as f:
# save the best_policy in pretty_print string format
f.write(pprint.pformat(best_policy, indent=4))
print('')
return {'status': 'Training is done!'}
# ========================================================================
@app.route('/result')
def result():
    """Send the policy file written by the last ``/training`` run as a download.

    Returns:
        The policy file as an attachment, or a JSON error with status 404 when
        no training run has written it yet.
    """
    # NOTE(review): file_path was undefined in this view; this is the path the
    # /training route writes its best policy to -- confirm it is the intended file.
    # Flask's send_file() resolves relative paths against app.root_path, not
    # the working directory that path is written relative to, so make it absolute.
    file_path = os.path.abspath('./react_backend/policy.txt')
    if not os.path.isfile(file_path):
        return {'error_type': 'no result',
                'error': 'No policy has been generated yet. Please run the training first.'}, 404
    # max_age replaces cache_timeout (deprecated in Flask 2.0, removed in 2.2);
    # 0 tells clients not to cache a file that changes after every run
    return send_file(file_path, as_attachment=True, max_age=0)
# @app.route('/')
# def serve():
# return send_from_directory(app.static_folder, 'index.html')