# react_app.py
# Author: Sun Jin Kim
from dataclasses import dataclass
from flask import Flask, request, current_app, render_template
# from flask_cors import CORS
import subprocess
import os
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
torch.manual_seed(0)
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
# # import agents and its functions
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
print('@@@ import successful')
# import agents and its functions
# from ..MetaAugment import UCB1_JC_py as UCB1_JC
# from ..MetaAugment import Evo_learner as Evo
# print('@@@ import successful')
# Flask application that serves the JSON endpoints used by the React front end.
app = Flask(import_name=__name__)
# it is used to collect user input and store them in the app
@app.route('/home', methods=["GET", "POST"])
def get_form_data():
print('@@@ in Flask Home')
# form_data = request.get_json()
# form_data = request.files['ds_upload']
# print('@@@ form_data', form_data)
# form_data = request.form.get('test')
# print('@@@ this is form data', request.get_data())
# required input
# ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
# IsLeNet = form_data["select_network"] # using LeNet or EasyNet or SimpleNet ->> default
# auto_aug_learner = form_data["select_learner"] # augmentation methods to be excluded
# print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)
# # advanced input
# if 'batch_size' in form_data.keys():
# batch_size = form_data['batch_size'] # size of batch the inner NN is trained with
# else:
# batch_size = 1 # this is for demonstration purposes
# if 'learning_rate' in form_data.keys():
# learning_rate = form_data['learning_rate'] # fix learning rate
# else:
# learning_rate = 10-1
# if 'toy_size' in form_data.keys():
# toy_size = form_data['toy_size'] # total propeortion of training and test set we use
# else:
# toy_size = 1 # this is for demonstration purposes
# if 'iterations' in form_data.keys():
# iterations = form_data['iterations'] # total iterations, should be more than the number of policies
# else:
# iterations = 10
# exclude_method = form_data['select_action']
# num_funcs = 14 - len(exclude_method)
# print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method, 'num_funcs', num_funcs)
# # default values
# max_epochs = 10 # max number of epochs that is run if early stopping is not hit
# early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
# num_policies = 5 # fix number of policies
# num_sub_policies = 5 # fix number of sub-policies in a policy
# # if user upload datasets and networks, save them in the database
# if ds == 'Other':
# ds_folder = request.files['ds_upload']
# print('!!!ds_folder', ds_folder)
# ds_name_zip = ds_folder.filename
# ds_name = ds_name_zip.split('.')[0]
# ds_folder.save('./datasets/'+ ds_name_zip)
# with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref:
# zip_ref.extractall('./datasets/upload_dataset/')
# if not current_app.debug:
# os.remove(f'./datasets/{ds_name_zip}')
# else:
# ds_name = None
# # test if uploaded dataset meets the criteria
# for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
# for dirname in dirnames:
# if dirname[0:6] != 'class_':
# return None # neet to change render to a 'failed dataset webpage'
# # save the user uploaded network
# if IsLeNet == 'Other':
# childnetwork = request.files['network_upload']
# childnetwork.save('./child_networks/'+childnetwork.filename)
# network_name = childnetwork.filename
# # generate random policies at start
# current_app.config['AAL'] = auto_aug_learner
# current_app.config['NP'] = num_policies
# current_app.config['NSP'] = num_sub_policies
# current_app.config['BS'] = batch_size
# current_app.config['LR'] = learning_rate
# current_app.config['TS'] = toy_size
# current_app.config['ME'] = max_epochs
# current_app.config['ESN'] = early_stop_num
# current_app.config['IT'] = iterations
# current_app.config['ISLENET'] = IsLeNet
# current_app.config['DSN'] = ds_name
# current_app.config['ds'] = ds
# print("@@@ user input has all stored in the app")
# data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
# 'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
# 'toy_size':toy_size, 'iterations':iterations, }
# print('@@@ all data sent', data)
return {'data': 'show training data'}
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
    """Return the stored search configuration so the user can confirm it.

    Reads the settings that the ``/home`` handler is meant to store in
    ``current_app.config`` and returns them as a JSON object. Any setting
    that has not been stored is returned as ``null``.

    Previously this view built the same summary, discarded it, and always
    returned a hard-coded ``{'batch_size': '12'}``. The real values are now
    returned; the ``batch_size`` key is still present.

    Returns:
        dict with keys ``ds``, ``ds_name``, ``IsLeNet``,
        ``ds_folder.filename``, ``auto_aug_learner``, ``batch_size``,
        ``learning_rate``, ``toy_size`` and ``iterations``.
    """
    print('inside confirm')
    config = current_app.config
    ds_name = config.get('DSN')
    return {
        'ds': config.get('ds'),
        'ds_name': ds_name,
        'IsLeNet': config.get('ISLENET'),
        # Same value as 'ds_name'. The key mirrors the summary dict in the
        # /home draft.
        'ds_folder.filename': ds_name,
        'auto_aug_learner': config.get('AAL'),
        'batch_size': config.get('BS'),
        'learning_rate': config.get('LR'),
        'toy_size': config.get('TS'),
        'iterations': config.get('IT'),
    }
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
def training():
    """Run the auto-augmentation learner chosen in the stored configuration.

    Every setting is read from ``current_app.config``. Only 'UCB' and
    'Evolutionary Learner' do any work. 'Random Searcher', 'Genetic Learner',
    any other value, and a missing setting all fall straight through.
    Training runs synchronously inside the request. Its results are neither
    returned nor stored, and the response is always ``{'status': 'training'}``.
    """
    cfg = current_app.config
    learner_choice = cfg.get('AAL')

    # dataset / search-space settings
    ds = cfg.get('ds')
    ds_name = cfg.get('DSN')
    toy_size = cfg.get('TS')
    num_policies = cfg.get('NP')
    num_sub_policies = cfg.get('NSP')
    # NOTE(review): no visible code stores 'exc_meth' or 'NUMFUN'; both are
    # None unless set elsewhere.
    exclude_method = cfg.get('exc_meth')
    num_funcs = cfg.get('NUMFUN')

    # child network and its training hyperparameters
    is_lenet = cfg.get('ISLENET')
    batch_size = cfg.get('BS')
    learning_rate = cfg.get('LR')
    max_epochs = cfg.get('ME')
    early_stop_num = cfg.get('ESN')
    iterations = cfg.get('IT')

    if learner_choice == 'UCB':
        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
        # run_UCB1 takes its arguments positionally, so keep this order.
        q_values, best_q_values = aal.ucb_learner.run_UCB1(
            policies, batch_size, learning_rate, ds, toy_size,
            max_epochs, early_stop_num, iterations, is_lenet, ds_name,
        )
        # NOTE(review): q_values and best_q_values are never used afterwards.
        best_q_values = np.array(best_q_values)
    elif learner_choice == 'Evolutionary Learner':
        # Keep the controller construction ahead of the child network: both
        # may draw from the torch RNG that is seeded at import time.
        controller = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
        child = cn.LeNet()  # NOTE(review): hard-coded; the ISLENET setting is ignored here
        evo = aal.evo_learner(
            network=controller,
            fun_num=num_funcs,
            p_bins=1,
            mag_bins=1,
            sub_num_pol=1,
            ds=ds,
            ds_name=ds_name,
            exclude_method=exclude_method,
            child_network=child,
        )
        evo.run_instance()
    elif learner_choice in ('Random Searcher', 'Genetic Learner'):
        pass  # TODO: these learners are not wired up yet

    return {'status': 'training'}
# ========================================================================
@app.route('/results')
def show_result():
    """Placeholder results endpoint; always reports status 'results'."""
    response = dict(status='results')
    return response
@app.route('/api')
def index():
    """Simple test endpoint; always returns the same fixed payload."""
    payload = dict(status='api test')
    return payload
if __name__ == '__main__':
    # Start Flask's development server on the default host and port with
    # debug mode enabled (auto-reload plus the interactive debugger). Do not
    # expose it to untrusted networks.
    app.run(host=None, port=None, debug=True)