diff --git a/backend/react_app.py b/backend/react_app.py index 62ed36fb972ad695da2ab9dcc94915f5f3789417..cf06ce72a43493760f2e6b0d56d6c507ce9730d6 100644 --- a/backend/react_app.py +++ b/backend/react_app.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from flask import Flask, request +from flask import Flask, request, current_app # from flask_cors import CORS - import subprocess import os import zipfile @@ -19,83 +18,102 @@ from matplotlib import pyplot as plt from numpy import save, load from tqdm import trange torch.manual_seed(0) -# import agents and its functions +# import agents and its functions from ..library.MetaAugment import UCB1_JC_py as UCB1_JC from ..library.MetaAugment import Evo_learner as Evo +print('@@@ import successful') app = Flask(__name__) +# it is used to collect user input and store them in the app @app.route('/home', methods=["GET", "POST"]) def home(): - print('in flask home') - form_data = request.get_json() - batch_size = 1 # size of batch the inner NN is trained with - learning_rate = 1e-1 # fix learning rate - ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) - toy_size = form_data['toy_size'] # total propeortion of training and test set we use + print('@@@ in Flask Home') + form_data = request.get_json() + # form_data = request.files + # form_data = request.form.get('test') + print('@@@ this is form data', form_data) + + # required input + ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100) + IsLeNet = form_data["select_network"] # using LeNet or EasyNet or SimpleNet ->> default + auto_aug_leanrer = form_data["select_learner"] # augmentation methods to be excluded + + print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_leanrer) + # advanced input + if 'batch_size' in form_data.keys(): + batch_size = form_data['batch_size'] # size of batch the inner NN is trained with + else: + batch_size = 1 # this is for demonstration purposes + if 'learning_rate' in form_data.keys(): + learning_rate = form_data['learning_rate'] # fix learning rate + else: + learning_rate = 10-1 + if 'toy_size' in form_data.keys(): + toy_size = form_data['toy_size'] # total propeortion of training and test set we use + else: + toy_size = 1 # this is for demonstration purposes + if 'iterations' in form_data.keys(): + iterations = form_data['iterations'] # total iterations, should be more than the number of policies + else: + iterations = 10 + print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations) + + # default values max_epochs = 10 # max number of epochs that is run if early stopping is not hit early_stop_num = 10 # max number of worse validation scores before early stopping is triggered num_policies = 5 # fix number of policies num_sub_policies = 5 # fix number of sub-policies in a policy - iterations = 5 # total iterations, should be more than the number of policies - IsLeNet = form_data["network_selection"] # using LeNet or EasyNet or SimpleNet ->> default + # if user upload datasets and networks, save them in the database if ds == 'Other': - ds_folder = request.files['dataset_upload'] + ds_folder = request.files #['ds_upload'] + print('!!!ds_folder', ds_folder) ds_name_zip = ds_folder.filename ds_name = ds_name_zip.split('.')[0] - ds_folder.save('./MetaAugment/datasets/'+ ds_name_zip) - with zipfile.ZipFile('./MetaAugment/datasets/'+ ds_name_zip, 'r') as zip_ref: - zip_ref.extractall('./MetaAugment/datasets/upload_dataset/') - os.remove(f'./MetaAugment/datasets/{ds_name_zip}') - + ds_folder.save('./datasets/'+ ds_name_zip) + with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref: + zip_ref.extractall('./datasets/upload_dataset/') + if not current_app.debug: + os.remove(f'./datasets/{ds_name_zip}') else: ds_name = None - for (dirpath, dirnames, filenames) in os.walk(f'./MetaAugment/datasets/upload_dataset/{ds_name}/'): + # test if uploaded dataset meets the criteria + for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'): for dirname in dirnames: if dirname[0:6] != 'class_': - return render_template("fail_dataset.html") - else: - pass - + return None # neet to change render to a 'failed dataset webpage' + # save the user uploaded network if IsLeNet == 'Other': childnetwork = request.files['network_upload'] - childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename) + childnetwork.save('./child_networks/'+childnetwork.filename) # generate random policies at start - auto_aug_leanrer = request.form.get("auto_aug_selection") - - if auto_aug_leanrer == 'UCB': - policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) - q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) - elif auto_aug_leanrer == 'Evolutionary Learner': - learner = Evo.Evolutionary_learner(fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds_name=ds_name, exclude_method=exclude_method) - learner.run_instance() - elif auto_aug_leanrer == 'Random Searcher': - pass - elif auto_aug_leanrer == 'Genetic Learner': - pass - - plt.figure() - plt.plot(q_values) - - - best_q_values = np.array(best_q_values) - - # save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values) - # best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True) + + current_app.config['AAL'] = auto_aug_leanrer + current_app.config['NP'] = num_policies + current_app.config['NSP'] = num_sub_policies + current_app.config['BS'] = batch_size + current_app.config['LR'] = learning_rate + current_app.config['TS'] = toy_size + current_app.config['ME'] = max_epochs + current_app.config['ESN'] = early_stop_num + current_app.config['IT'] = iterations + current_app.config['ISLENET'] = IsLeNet + current_app.config['DSN'] = ds_name + current_app.config['ds'] = ds - print("DONE") + print("@@@ user input has all stored in the app") - return None + return {'try': 'Hello'} @app.route('/api') def index(): - return {'name': 'Hello'} \ No newline at end of file + return {'name': 'Hello'}