Skip to content
Snippets Groups Projects
react_app.py 9.33 KiB
Newer Older
  • Learn to ignore specific revisions
  • from dataclasses import dataclass
    
    Mia Wang's avatar
    Mia Wang committed
    from flask import Flask, request, current_app, render_template
    
    Mia Wang's avatar
    Mia Wang committed
    # from flask_cors import CORS
    
    import subprocess
    import os
    import zipfile
    
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torch.utils.data as data_utils
    import torchvision
    import torchvision.datasets as datasets
    
    from matplotlib import pyplot as plt
    from numpy import save, load
    from tqdm import trange
    torch.manual_seed(0)
    
    
    Mia Wang's avatar
    Mia Wang committed
    import os
    import sys
    # Put the parent of this file's directory on the import path so the
    # `MetaAugment` package imported below can be found (assumes that package
    # lives there). Anchoring on __file__ rather than on the current working
    # directory means the server also starts when launched from elsewhere;
    # when launched from this file's directory the result is the same path.
    sys.path.insert(
        0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
    )
    
    
    # import the agents and their functions
    
    Max Ramsay King's avatar
    Max Ramsay King committed
    import MetaAugment.autoaugment_learners as aal
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    import MetaAugment.controller_networks as cont_n
    import MetaAugment.child_networks as cn
    
    Wang, Mia's avatar
    Wang, Mia committed
    print('@@@ import successful')
    
    Mia Wang's avatar
    Mia Wang committed
    # import agents and its functions
    # from ..MetaAugment import UCB1_JC_py as UCB1_JC
    # from ..MetaAugment import Evo_learner as Evo
    # print('@@@ import successful')
    
    
    # The single Flask application object; every route below registers on it.
    app = Flask(import_name=__name__)
    
    Mia Wang's avatar
    Mia Wang committed
    
    
    Wang, Mia's avatar
    Wang, Mia committed
    # Endpoint meant to collect the user's form input and store it in the app config (not implemented yet).
    
    @app.route('/home', methods=["GET", "POST"])
    def get_form_data():
        """Handle the home-page form endpoint.

        Currently a placeholder: it logs that it was reached and returns a fixed
        payload. It does not read ``request`` or write anything into
        ``current_app.config`` yet, so ``/confirm`` and ``/training`` will find
        no settings stored by this view (their ``config.get`` calls yield None).

        TODO: the earlier draft of this view (see version control) read these
        form fields and stored them under these config keys:

            select_dataset  -> 'ds'       (MNIST, KMNIST, FashionMNIST, CIFAR10,
                                           CIFAR100, or 'Other' for an upload)
            select_network  -> 'ISLENET'
            select_learner  -> 'AAL'
            batch_size      -> 'BS'       (draft default 1)
            learning_rate   -> 'LR'       (draft default was written ``10-1``,
                                           which evaluates to 9; presumably
                                           0.1 was intended)
            toy_size        -> 'TS'       (draft default 1)
            iterations      -> 'IT'       (draft default 10)
            uploaded dataset name -> 'DSN' (None unless ds == 'Other')

        plus fixed values 'NP' = 5, 'NSP' = 5, 'ME' = 10, 'ESN' = 10.
        ``/training`` also reads 'exc_meth' and 'NUMFUN'. The draft computed
        ``exclude_method`` (from 'select_action') and
        ``num_funcs = 14 - len(exclude_method)`` but never stored either value.

        When re-enabling the upload path, note three problems in the draft:
        - it built save paths from the client-supplied filename;
        - it ran ``ZipFile.extractall`` on the uploaded archive;
        - it returned ``None`` from the view for an invalid dataset, which Flask
          does not accept as a response.

        Returns:
            dict: ``{'data': 'show training data'}``, sent as JSON by Flask.
        """
        print('@@@ in Flask Home')
        payload = {'data': 'show training data'}
        return payload
    
    
    Mia Wang's avatar
    Mia Wang committed
    @app.route('/confirm', methods=['POST', 'GET'])
    def confirm():
        """Gather the run settings stored in the app config for confirmation.

        Every setting is read with ``config.get``, so a key that has not been
        stored yet comes back as ``None``.

        NOTE(review): the summary dict assembled below is never sent back. The
        view still returns the hard-coded stub ``{'batch_size': '12'}``.
        Presumably the summary is meant to become the response once the front
        end consumes it. The remaining settings are read but currently unused.

        Returns:
            dict: ``{'batch_size': '12'}`` (placeholder response).
        """
        print('inside confirm')
        cfg = current_app.config

        # Local setting name -> config key it is stored under.
        settings = {
            name: cfg.get(key)
            for name, key in (
                ('auto_aug_learner', 'AAL'),
                # search space & problem setting
                ('ds', 'ds'),
                ('ds_name', 'DSN'),
                ('exclude_method', 'exc_meth'),
                ('num_policies', 'NP'),
                ('num_sub_policies', 'NSP'),
                ('num_funcs', 'NUMFUN'),
                ('toy_size', 'TS'),
                # child network
                ('IsLeNet', 'ISLENET'),
                # child network training hyperparameters
                ('batch_size', 'BS'),
                ('early_stop_num', 'ESN'),
                ('iterations', 'IT'),
                ('learning_rate', 'LR'),
                ('max_epochs', 'ME'),
            )
        }

        summary = {
            'ds': settings['ds'],
            'ds_name': settings['ds_name'],
            'IsLeNet': settings['IsLeNet'],
            'ds_folder.filename': settings['ds_name'],
            'auto_aug_learner': settings['auto_aug_learner'],
            'batch_size': settings['batch_size'],
            'learning_rate': settings['learning_rate'],
            'toy_size': settings['toy_size'],
            'iterations': settings['iterations'],
        }

        return {'batch_size': '12'}
    
    Mia Wang's avatar
    Mia Wang committed
    
    # ========================================================================
    @app.route('/training', methods=['POST', 'GET'])
    def training():
        """Run the auto-augmentation learner selected in the app config.

        All settings come from ``current_app.config``. A key that was never
        stored reads as ``None``. The chosen learner runs synchronously inside
        the request handler:

        - 'UCB': generates policies and runs ``run_UCB1``.
        - 'Evolutionary Learner': builds an evo controller and learner, then
          calls ``run_instance()``.
        - 'Random Searcher' / 'Genetic Learner': accepted but not implemented.
        - any other value: nothing runs.

        Returns:
            dict: ``{'status': 'training'}`` after the selected learner (if any)
            has returned.
        """
        cfg = current_app.config
        learner_name = cfg.get('AAL')

        # search space & problem setting
        ds = cfg.get('ds')
        ds_name = cfg.get('DSN')
        exclude_method = cfg.get('exc_meth')
        num_funcs = cfg.get('NUMFUN')
        num_policies = cfg.get('NP')
        num_sub_policies = cfg.get('NSP')
        toy_size = cfg.get('TS')

        # child network and its training hyperparameters
        IsLeNet = cfg.get('ISLENET')
        batch_size = cfg.get('BS')
        early_stop_num = cfg.get('ESN')
        iterations = cfg.get('IT')
        learning_rate = cfg.get('LR')
        max_epochs = cfg.get('ME')

        if learner_name == 'UCB':
            policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
            # run_UCB1 is called positionally; this argument order matters.
            q_values, best_q_values = aal.ucb_learner.run_UCB1(
                policies,
                batch_size,
                learning_rate,
                ds,
                toy_size,
                max_epochs,
                early_stop_num,
                iterations,
                IsLeNet,
                ds_name,
            )
            best_q_values = np.array(best_q_values)
        elif learner_name == 'Evolutionary Learner':
            # The controller spells its bin argument `m_bins`, the learner
            # spells it `mag_bins`; both receive the same fixed value.
            n_bins = 1
            controller = cont_n.evo_controller(
                fun_num=num_funcs, p_bins=n_bins, m_bins=n_bins, sub_num_pol=1
            )
            # NOTE(review): LeNet is always the child network here, whatever
            # the 'ISLENET' setting says; confirm this is intended.
            learner = aal.evo_learner(
                network=controller,
                fun_num=num_funcs,
                p_bins=n_bins,
                mag_bins=n_bins,
                sub_num_pol=1,
                ds=ds,
                ds_name=ds_name,
                exclude_method=exclude_method,
                child_network=cn.LeNet(),
            )
            learner.run_instance()
        elif learner_name in ('Random Searcher', 'Genetic Learner'):
            pass  # not implemented yet

        return {'status': 'training'}
    
    
    
    # ========================================================================
    @app.route('/results')
    def show_result():
        """Results endpoint; currently returns a fixed status payload."""
        status = 'results'
        return {'status': status}
    
    Mia Wang's avatar
    Mia Wang committed
    
    
    @app.route('/api')
    def index():
        """Test endpoint returning a fixed payload.

        Presumably used by the front end to check that the API is reachable.
        """
        response = dict(status='api test')
        return response
    
    
    if __name__ == '__main__':
        # Flask's debug mode enables the interactive Werkzeug debugger.
        # Flask's documentation warns it allows executing arbitrary Python code
        # and must never be used on a production or otherwise exposed server.
        # Debug stays on by default (the previous hard-coded behaviour).
        # Set FLASK_DEBUG to 0/false/no to turn it off; this mirrors how Flask
        # itself interprets that variable.
        debug = os.environ.get('FLASK_DEBUG', '1').lower() not in ('0', 'false', 'no')
        app.run(debug=debug)