Newer
Older
from dataclasses import dataclass
from flask import Flask, request, current_app, render_template
import subprocess
import os
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
torch.manual_seed(0)
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
# # import agents and its functions
from MetaAugment.autoaugment_learners import ucb_learner as UCB1_JC
from MetaAugment.autoaugment_learners import evo_learner
import MetaAugment.controller_networks as cn
import MetaAugment.autoaugment_learners as aal
# import agents and its functions
# from ..MetaAugment import UCB1_JC_py as UCB1_JC
# from ..MetaAugment import Evo_learner as Evo
# print('@@@ import successful')
# The Flask application that the routes in this module register on.
# NOTE(review): `Flask` was imported but no `app` object existed before the
# first `@app.route`, which raises NameError at import time.
app = Flask(__name__)


def _form_value(form, key, default, cast):
    """Read an optional advanced-search field from a submitted form.

    Args:
        form: The submitted form mapping (``request.form``).
        key: Field name to read.
        default: Value returned when the field is missing, empty, or the
            placeholder string ``'undefined'``.
        cast: Callable converting the raw string value (e.g. ``int``).

    Returns:
        ``cast(form[key])``, or *default* when the field was not provided.

    Raises:
        Whatever *cast* raises on unparsable input (``ValueError`` for
        ``int``/``float``).
    """
    raw = form.get(key, 'undefined')
    if raw in ('undefined', ''):
        return default
    return cast(raw)


@app.route('/home', methods=["GET", "POST"])
def get_form_data():
    """Collect the user's search settings and store them on the app.

    Reads the submitted form plus the optional dataset (zip) and child-network
    uploads. Uploads are saved under ``./datasets/`` and ``./child_networks/``.
    The collected settings dict is stored in ``current_app.config['data']``.

    Returns:
        ``{'data': 'all stored'}`` on success. Returns an error dict with HTTP
        status 400 when an advanced field cannot be parsed, or when an upload
        is missing, is not a zip, or fails the ``class_*`` folder check.
    """
    form_data = request.form
    print('@@@ this is form data', form_data)
    ds = form_data['select_dataset']  # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
    IsLeNet = form_data["select_network"]  # using LeNet or EasyNet or SimpleNet ->> default
    auto_aug_learner = form_data["select_learner"]  # which auto-augment learner to run
    print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)

    # Advanced input. Form values arrive as strings, so convert them to the
    # numeric types of the defaults used when a field is left 'undefined'.
    try:
        batch_size = _form_value(form_data, 'batch_size', 1, int)  # size of batch the inner NN is trained with
        # The old default was written `10-1`, which evaluates to 9; 1e-1 is the evident intent.
        learning_rate = _form_value(form_data, 'learning_rate', 1e-1, float)  # fix learning rate
        toy_size = _form_value(form_data, 'toy_size', 1, float)  # proportion of training and test set we use
        iterations = _form_value(form_data, 'iterations', 10, int)  # should be more than the number of policies
    except ValueError as err:
        return {'data': f'invalid advanced input: {err}'}, 400
    exclude_method = form_data['select_action']  # augmentation methods to be excluded
    print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method)

    # If the user uploaded a dataset, save it, unpack it and validate its layout.
    ds_name_zip = None
    ds_name = None
    if ds == 'Other':
        ds_folder = request.files['ds_upload']
        print('!!!ds_folder', ds_folder)
        # basename() drops any directory components in the client-supplied
        # filename, so the upload cannot be written outside ./datasets/.
        ds_name_zip = os.path.basename(ds_folder.filename or '')
        if not ds_name_zip:
            return {'data': 'no dataset file uploaded'}, 400
        ds_name = os.path.splitext(ds_name_zip)[0]
        zip_path = os.path.join('./datasets', ds_name_zip)
        ds_folder.save(zip_path)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall('./datasets/upload_dataset/')
        except zipfile.BadZipFile:
            return {'data': 'uploaded dataset is not a valid zip archive'}, 400
        finally:
            # Keep the archive around only when debugging.
            if not current_app.debug:
                os.remove(zip_path)

        # Every directory inside the dataset must be named class_*.
        # NOTE(review): if the archive's top-level folder is not named after
        # the zip file, this walk finds nothing and the check passes; confirm.
        for _dirpath, dirnames, _filenames in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
            if any(not dirname.startswith('class_') for dirname in dirnames):
                # TODO: render a dedicated 'failed dataset' page instead.
                return {'data': 'invalid dataset: every folder must be named class_*'}, 400

    # Save the user-uploaded child network, if any.
    network_name = None
    if IsLeNet == 'Other':
        childnetwork = request.files['network_upload']
        network_name = os.path.basename(childnetwork.filename or '')
        if not network_name:
            return {'data': 'no network file uploaded'}, 400
        childnetwork.save(os.path.join('./child_networks', network_name))

    print("@@@ user input has all stored in the app")
    data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
            'auto_aug_learner': auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
            'toy_size': toy_size, 'iterations': iterations, 'exclude_method': exclude_method, }
    current_app.config['data'] = data
    print('@@@ all data sent', current_app.config['data'])
    return {'data': 'all stored'}
# ========================================================================
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
    """Echo back the settings dict that the /home route stored on the app.

    Raises KeyError if nothing has been stored under ``'data'`` yet.
    """
    print('inside confirm page')
    return current_app.config['data']
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
def training():
    """Set up the auto-augment learner chosen on the /home form.

    Reads the settings dict stored in ``current_app.config['data']`` (the keys
    are the ones the /home route writes) and builds the selected learner.

    Returns:
        ``{'status': 'training done!'}`` for the UCB and Evolutionary learners.
        Returns HTTP 400 if no settings are stored or the learner name is
        unknown, and HTTP 501 for learners that are not wired up yet.
    """
    data = current_app.config.get('data')
    if data is None:
        return {'status': 'no user input stored; submit the /home form first'}, 400

    # search space & problem setting
    auto_aug_learner = data['auto_aug_learner']
    ds = data['ds']
    ds_name = data['ds_name']  # as stored by /home: the uploaded zip's filename, or None
    exclude_method = data['exclude_method']
    toy_size = data['toy_size']
    # child network
    is_lenet = data['IsLeNet']
    # child network training hyperparameters
    batch_size = data['batch_size']
    learning_rate = data['learning_rate']
    iterations = data['iterations']
    # NOTE(review): nothing visible stores 'NUMFUN' in the app config, so this
    # is None unless it is set elsewhere. Confirm where the number of
    # augmentation functions should come from.
    num_funcs = current_app.config.get('NUMFUN')

    # default values
    max_epochs = 10  # max number of epochs that is run if early stopping is not hit
    early_stop_num = 10  # max number of worse validation scores before early stopping is triggered
    num_policies = 5  # fix number of policies
    num_sub_policies = 5  # fix number of sub-policies in a policy

    if auto_aug_learner == 'UCB':
        policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
        # TODO: the original listed batch_size, learning_rate, ds, toy_size,
        # iterations, IsLeNet and ds_name here without any call around them.
        # Pass `policies` and those values (plus max_epochs/early_stop_num)
        # to the UCB training entry point once its signature is confirmed.
    elif auto_aug_learner == 'Evolutionary Learner':
        network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
        # NOTE(review): always LeNet regardless of `is_lenet`; also confirm
        # that `aal.evo` exists (the file imports `evo_learner` explicitly).
        child_network = aal.evo.LeNet()
        learner = aal.evo.evo_learner(
            network=network,
            fun_num=num_funcs,
            p_bins=1,
            mag_bins=1,
            sub_num_pol=1,
            ds=ds,
            ds_name=ds_name,
            exclude_method=exclude_method,
            child_network=child_network,
        )
        # TODO: `learner` is constructed but never trained here.
    elif auto_aug_learner in ('Random Searcher', 'Genetic Learner'):
        return {'status': f'{auto_aug_learner} is not implemented yet'}, 501
    else:
        return {'status': f'unknown learner: {auto_aug_learner}'}, 400
    return {'status': 'training done!'}
# ========================================================================
@app.route('/results')
def show_result():
    """Placeholder results endpoint; always reports a fixed 'results' status."""
    payload = {'status': 'results'}
    return payload
@app.route('/api')
def index():