Commit 4b1447e7 authored by Wang, Mia's avatar Wang, Mia

Update react_app.py

parent db7c4d34
from dataclasses import dataclass
from flask import Flask, request, current_app
# from flask_cors import CORS
import subprocess
import os
import zipfile
import torch
import numpy as np
from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange

torch.manual_seed(0)

# import the agents and their functions
from ..library.MetaAugment import UCB1_JC_py as UCB1_JC
from ..library.MetaAugment import Evo_learner as Evo
print('@@@ import successful')

app = Flask(__name__)

# collect the user input and store it in the app
@app.route('/home', methods=["GET", "POST"])
def home():
    print('@@@ in Flask Home')
    form_data = request.get_json()
    # form_data = request.files
    # form_data = request.form.get('test')
    print('@@@ this is form data', form_data)

    # required input
    ds = form_data['select_dataset']                # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
    IsLeNet = form_data["select_network"]           # use LeNet, EasyNet or SimpleNet ->> default
    auto_aug_leanrer = form_data["select_learner"]  # which auto-augmentation learner to use
    print('@@@ required user input:', 'ds:', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:', auto_aug_leanrer)
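
    # Hypothetical example of the expected JSON payload (values are illustrative
    # assumptions, not taken from the repository); the advanced keys below are optional:
    # {
    #     "select_dataset": "MNIST",
    #     "select_network": "LeNet",
    #     "select_learner": "UCB",
    #     "batch_size": 32,
    #     "learning_rate": 1e-2,
    #     "toy_size": 0.1,
    #     "iterations": 20
    # }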

    # advanced input (fall back to defaults when a field is not supplied)
    if 'batch_size' in form_data.keys():
        batch_size = form_data['batch_size']        # size of batch the inner NN is trained with
    else:
        batch_size = 1                              # this is for demonstration purposes
    if 'learning_rate' in form_data.keys():
        learning_rate = form_data['learning_rate']  # fixed learning rate
    else:
        learning_rate = 1e-1
    if 'toy_size' in form_data.keys():
        toy_size = form_data['toy_size']            # total proportion of the training and test set we use
    else:
        toy_size = 1                                # this is for demonstration purposes
    if 'iterations' in form_data.keys():
        iterations = form_data['iterations']        # total iterations, should be more than the number of policies
    else:
        iterations = 10
    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations)

    # default values
    max_epochs = 10        # max number of epochs that is run if early stopping is not hit
    early_stop_num = 10    # max number of worse validation scores before early stopping is triggered
    num_policies = 5       # fixed number of policies
    num_sub_policies = 5   # fixed number of sub-policies in a policy

    # if the user uploads a dataset, save it and unzip it into the datasets folder
    if ds == 'Other':
        ds_folder = request.files['dataset_upload']
        print('!!! ds_folder:', ds_folder)
        ds_name_zip = ds_folder.filename
        ds_name = ds_name_zip.split('.')[0]
        ds_folder.save('./datasets/' + ds_name_zip)
        with zipfile.ZipFile('./datasets/' + ds_name_zip, 'r') as zip_ref:
            zip_ref.extractall('./datasets/upload_dataset/')
        if not current_app.debug:
            os.remove(f'./datasets/{ds_name_zip}')
    else:
        ds_name = None

    # test whether the uploaded dataset meets the criteria: every class folder must be named 'class_*'
    for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
        for dirname in dirnames:
            if dirname[0:6] != 'class_':
                return None  # need to change this to render a 'failed dataset' webpage
            else:
                pass

    # save the user-uploaded network
    if IsLeNet == 'Other':
        childnetwork = request.files['network_upload']
        childnetwork.save('./child_networks/' + childnetwork.filename)

    # generate random policies at the start and run the selected auto-augmentation learner
    if auto_aug_leanrer == 'UCB':
        policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
        q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)
    elif auto_aug_leanrer == 'Evolutionary Learner':
        # NOTE: num_funcs and exclude_method are not defined in this file; they need to be
        # supplied (e.g. from the form data) before this branch can run
        learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method)
        learner.run_instance()
    elif auto_aug_leanrer == 'Random Searcher':
        pass
    elif auto_aug_leanrer == 'Genetic Learner':
        pass

    # plot the results (q_values and best_q_values are only produced by the UCB learner)
    plt.figure()
    plt.plot(q_values)
    best_q_values = np.array(best_q_values)
    # save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
    # best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)

    # store the user input in the app config so that other routes can access it
    current_app.config['AAL'] = auto_aug_leanrer
    current_app.config['NP'] = num_policies
    current_app.config['NSP'] = num_sub_policies
    current_app.config['BS'] = batch_size
    current_app.config['LR'] = learning_rate
    current_app.config['TS'] = toy_size
    current_app.config['ME'] = max_epochs
    current_app.config['ESN'] = early_stop_num
    current_app.config['IT'] = iterations
    current_app.config['ISLENET'] = IsLeNet
    current_app.config['DSN'] = ds_name
    current_app.config['ds'] = ds

    print("@@@ all user input has been stored in the app")
    return {'try': 'Hello'}
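
# A minimal sketch of how a client might call the /home endpoint, assuming the app is
# served locally on port 5000 (URL, port and payload values are illustrative assumptions):
#
#   import requests
#   resp = requests.post('http://localhost:5000/home',
#                        json={'select_dataset': 'MNIST',
#                              'select_network': 'LeNet',
#                              'select_learner': 'UCB'})
#   print(resp.json())  # -> {'try': 'Hello'} once the search has finished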


@app.route('/api')
def index():
    return {'name': 'Hello'}
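
# A minimal sketch of how this app might be launched for local testing; the module path
# below is a placeholder assumption, since the real entry point is not shown in this file:
#
#   FLASK_APP=path.to.react_app flask run --port 5000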