Skip to content
Snippets Groups Projects
Commit 4b1447e7 authored by Wang, Mia's avatar Wang, Mia
Browse files

Update react_app.py

parent db7c4d34
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass from dataclasses import dataclass
from flask import Flask, request from flask import Flask, request, current_app
# from flask_cors import CORS # from flask_cors import CORS
import subprocess import subprocess
import os import os
import zipfile import zipfile
...@@ -19,83 +18,102 @@ from matplotlib import pyplot as plt ...@@ -19,83 +18,102 @@ from matplotlib import pyplot as plt
from numpy import save, load from numpy import save, load
from tqdm import trange from tqdm import trange
torch.manual_seed(0) torch.manual_seed(0)
# import agents and its functions
# import agents and its functions
from ..library.MetaAugment import UCB1_JC_py as UCB1_JC from ..library.MetaAugment import UCB1_JC_py as UCB1_JC
from ..library.MetaAugment import Evo_learner as Evo from ..library.MetaAugment import Evo_learner as Evo
print('@@@ import successful')
# WSGI application object; the /home and /api routes below register against it.
app = Flask(__name__)
# it is used to collect user input and store them in the app
@app.route('/home', methods=["GET", "POST"])
def home():
    """Collect the user's search configuration and stash it in the app config.

    Reads the required selections (dataset, network, learner) and the
    optional advanced hyperparameters from the JSON request body, unpacks
    any user-uploaded dataset/network onto disk, validates the uploaded
    dataset layout, and stores every setting in ``current_app.config``
    for a later training endpoint to read.

    Returns:
        dict: a small JSON acknowledgement payload on success, or ``None``
        when an uploaded dataset does not follow the ``class_*`` folder
        convention (placeholder until a proper failure page exists).
    """
    print('@@@ in Flask Home')
    # Input arrives as a JSON body; the commented alternatives were earlier
    # experiments with multipart/form submissions.
    form_data = request.get_json()
    # form_data = request.files
    # form_data = request.form.get('test')
    print('@@@ this is form data', form_data)

    # --- required input -------------------------------------------------
    ds = form_data['select_dataset']                 # dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100, or 'Other')
    IsLeNet = form_data["select_network"]            # LeNet / EasyNet / SimpleNet, or 'Other' for an upload
    auto_aug_leanrer = form_data["select_learner"]   # which auto-augment learner to run
    print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:', auto_aug_leanrer)

    # --- advanced (optional) input, with demo-friendly defaults ---------
    batch_size = form_data.get('batch_size', 1)            # size of batch the inner NN is trained with
    # BUGFIX: default was written as `10-1`, which is the integer 9;
    # the intended fix learning rate (see the original default) is 1e-1.
    learning_rate = form_data.get('learning_rate', 1e-1)
    toy_size = form_data.get('toy_size', 1)                # proportion of training/test set we use
    iterations = form_data.get('iterations', 10)           # total iterations; should exceed the number of policies
    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations)

    # --- fixed default values -------------------------------------------
    max_epochs = 10        # max epochs run when early stopping never triggers
    early_stop_num = 10    # worse-validation-score streak that triggers early stopping
    num_policies = 5       # fixed number of policies
    num_sub_policies = 5   # fixed number of sub-policies in a policy

    # If the user uploaded a dataset, save and unpack it under ./datasets/.
    if ds == 'Other':
        # BUGFIX: the committed code assigned `request.files` itself and then
        # read `.filename`, which cannot work on the MultiDict; the commented
        # key indicated 'ds_upload' — TODO(review): confirm this field name
        # against the front-end form.
        ds_folder = request.files['ds_upload']
        print('!!!ds_folder', ds_folder)
        ds_name_zip = ds_folder.filename
        ds_name = ds_name_zip.split('.')[0]
        ds_folder.save('./datasets/' + ds_name_zip)
        with zipfile.ZipFile('./datasets/' + ds_name_zip, 'r') as zip_ref:
            zip_ref.extractall('./datasets/upload_dataset/')
        # Keep the zip around in debug mode so repeated requests can reuse it.
        if not current_app.debug:
            os.remove(f'./datasets/{ds_name_zip}')
    else:
        ds_name = None

    # Uploaded datasets must contain only 'class_*' folders. When ds_name is
    # None the walked path does not exist and os.walk simply yields nothing.
    for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
        for dirname in dirnames:
            if dirname[0:6] != 'class_':
                return None  # TODO: render a proper 'failed dataset' page

    # Save the user-uploaded child network, if any.
    if IsLeNet == 'Other':
        childnetwork = request.files['network_upload']
        childnetwork.save('./child_networks/' + childnetwork.filename)

    # Stash everything in the app config for the training endpoint to read.
    current_app.config['AAL'] = auto_aug_leanrer
    current_app.config['NP'] = num_policies
    current_app.config['NSP'] = num_sub_policies
    current_app.config['BS'] = batch_size
    current_app.config['LR'] = learning_rate
    current_app.config['TS'] = toy_size
    current_app.config['ME'] = max_epochs
    current_app.config['ESN'] = early_stop_num
    current_app.config['IT'] = iterations
    current_app.config['ISLENET'] = IsLeNet
    current_app.config['DSN'] = ds_name
    current_app.config['ds'] = ds

    print("@@@ user input has all stored in the app")
    return {'try': 'Hello'}
@app.route('/api')
def index():
    """Trivial liveness endpoint: always returns a constant JSON payload."""
    payload = {'name': 'Hello'}
    return payload
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment