from flask import Blueprint, request, render_template, flash, send_file, current_app, g, session
import subprocess
import os
import zipfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets

from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
torch.manual_seed(0)

# import agents and their functions
from MetaAugment.autoaugment_learners import ucb_learner
from MetaAugment import Evo_learner as Evo
import MetaAugment.autoaugment_learners as aal
from MetaAugment.main import create_toy
import MetaAugment.child_networks as cn
import pickle

bp = Blueprint("progress", __name__)

# Built-in torchvision datasets selectable from the form. Each one is stored
# under ./MetaAugment/datasets/<lower-cased name>/{train,test}.
_BUILTIN_DATASETS = {
    "MNIST": datasets.MNIST,
    "KMNIST": datasets.KMNIST,
    "FashionMNIST": datasets.FashionMNIST,
    "CIFAR10": datasets.CIFAR10,
    "CIFAR100": datasets.CIFAR100,
}


def _uploaded_filename(upload):
    """Return a directory-free file name for an upload, or '' if nothing was sent.

    The name comes from the client, so any directory components (with either
    slash style) are stripped before it is used to build a path on disk.
    """
    if upload is None or not upload.filename:
        return ''
    return os.path.basename(upload.filename.replace('\\', '/'))


def _save_uploaded_dataset(upload, zip_name):
    """Save the uploaded archive, extract it, and return the dataset name.

    The dataset name is the archive file name up to its first '.'. The
    archive itself is deleted after extraction unless the app is in debug mode.
    """
    ds_name = zip_name.split('.')[0]
    zip_path = './MetaAugment/datasets/' + zip_name
    upload.save(zip_path)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall('./MetaAugment/datasets/upload_dataset/')
    if not current_app.debug:
        os.remove(zip_path)
    return ds_name


def _dataset_layout_ok(ds_name):
    """Return True if every directory below the extracted dataset starts with 'class_'."""
    root = f'./MetaAugment/datasets/upload_dataset/{ds_name}/'
    return all(
        dirname.startswith('class_')
        for _dirpath, dirnames, _filenames in os.walk(root)
        for dirname in dirnames
    )


def _load_datasets(ds, ds_name):
    """Return ``(train_dataset, test_dataset, num_labels)`` for the chosen dataset.

    Raises:
        ValueError: if ``ds`` is neither 'Other' nor a known built-in dataset.
    """
    if ds == 'Other':
        # NOTE(review): uploads are extracted under .../datasets/upload_dataset/
        # but read from .../datasets/<ds_name> here -- confirm the intended path.
        dataset = datasets.ImageFolder('./MetaAugment/datasets/' + ds_name)
        len_train = int(0.8 * len(dataset))
        train_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [len_train, len(dataset) - len_train]
        )
        return train_dataset, test_dataset, len(dataset.class_to_idx)

    try:
        dataset_cls = _BUILTIN_DATASETS[ds]
    except KeyError:
        raise ValueError(f"Unsupported dataset selection: {ds!r}") from None

    root = f'./MetaAugment/datasets/{ds.lower()}'
    train_dataset = dataset_cls(root=f'{root}/train', train=True, download=True)
    test_dataset = dataset_cls(
        root=f'{root}/test',
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    # `targets` is a tensor for the MNIST family and a list for CIFAR;
    # int() turns either result into a plain Python int.
    targets = train_dataset.targets
    num_labels = int(max(targets) - min(targets) + 1)
    return train_dataset, test_dataset, num_labels


def _image_dimensions(dataset):
    """Return ``(height, width, channels)`` of the first image in ``dataset``.

    The training datasets are created without a transform, so their samples
    are not tensors; such samples are converted with ``ToTensor`` first
    (yielding a channels x height x width tensor) instead of being indexed.
    """
    sample = dataset[0][0]
    if not torch.is_tensor(sample):
        sample = torchvision.transforms.ToTensor()(sample)
    img_channels, img_height, img_width = sample.shape
    return img_height, img_width, img_channels


def _build_child_network(name, img_height, img_width, num_labels, img_channels):
    """Instantiate the selected child network, or load a pickled one."""
    architectures = {
        "LeNet": cn.LeNet,
        "EasyNet": cn.EasyNet,
        'SimpleNet': cn.SimpleNet,
    }
    if name in architectures:
        return architectures[name](img_height, img_width, num_labels, img_channels)

    # SECURITY: unpickling runs arbitrary code. Only safe if this file is
    # trusted; a user-uploaded network must not be loaded this way.
    # NOTE(review): uploaded networks are saved under
    # ./MetaAugment/child_networks/<filename>, not at this path -- confirm.
    with open('datasets/childnetwork', "rb") as network_file:
        return pickle.load(network_file)


@bp.route("/user_input", methods=["GET", "POST"])
def response():
    """Handle the user-input form, run the selected learner and render the result.

    On POST the view reads the learner, dataset, child network and excluded
    augmentation methods from the form, stores any uploaded dataset or
    network on disk, runs the chosen auto-augment learner and records the
    settings in ``current_app.config``.

    Returns:
        The rendered "training.html" template, or "fail_dataset.html" when an
        uploaded dataset contains a directory not prefixed with 'class_'.
    """
    # Defaults so that a GET request (allowed by the route) can still render.
    auto_aug_learner = None
    exclude_method = []

    if request.method == 'POST':
        auto_aug_learner = request.form.get("auto_aug_selection")

        # search space & problem setting
        # dataset choice: MNIST, KMNIST, FashionMNIST, CIFAR10 or CIFAR100
        ds = request.form.get("dataset_selection")
        # .get() avoids an automatic 400 when the file field is absent
        ds_up = request.files.get('dataset_upload')
        exclude_method = request.form.getlist("action_space")
        num_funcs = 14 - len(exclude_method)
        num_policies = 5        # fixed number of policies
        num_sub_policies = 5    # fixed number of sub-policies in a policy
        toy_size = 1            # total proportion of training and test set we use

        # child network: LeNet, EasyNet or SimpleNet, otherwise an upload
        IsLeNet = request.form.get("network_selection")
        nw_up = request.files.get('network_upload')

        # child network training hyperparameters
        batch_size = 1          # size of batch the inner NN is trained with
        early_stop_num = 10     # max number of worse validation scores before early stopping
        iterations = 5          # total iterations, should be more than the number of policies
        learning_rate = 1e-1    # fixed learning rate
        max_epochs = 10         # max number of epochs run if early stopping is not hit

        # if the user uploaded a dataset or network, save it on disk
        ds_name = None
        dataset_zip_name = _uploaded_filename(ds_up)
        if ds is None and dataset_zip_name:
            ds = 'Other'
            ds_name = _save_uploaded_dataset(ds_up, dataset_zip_name)

        if ds_name is not None and not _dataset_layout_ok(ds_name):
            return render_template("fail_dataset.html")

        network_name = _uploaded_filename(nw_up)
        if IsLeNet is None and network_name:
            nw_up.save('./MetaAugment/child_networks/' + network_name)

        # Only the UCB learner produces q-values to plot.
        q_values = None

        if auto_aug_learner == 'UCB':
            policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
            q_values, _best_q_values = ucb_learner.run_UCB1(
                policies,
                batch_size,
                learning_rate,
                ds,
                toy_size,
                max_epochs,
                early_stop_num,
                iterations,
                IsLeNet,
                ds_name
            )
        elif auto_aug_learner == 'Evolutionary Learner':
            learner = Evo.Evolutionary_learner(
                fun_num=num_funcs,
                p_bins=1,
                mag_bins=1,
                sub_num_pol=1,
                ds_name=ds_name,
                exclude_method=exclude_method
            )
            learner.run_instance()
        elif auto_aug_learner == 'Random Searcher':
            # Unlike the UCB path, `ds` and `IsLeNet` are resolved here rather
            # than inside the agent, which suits users calling the library
            # directly instead of going through the webapp.
            train_dataset, test_dataset, num_labels = _load_datasets(ds, ds_name)
            img_height, img_width, img_channels = _image_dimensions(train_dataset)

            # create toy dataset from the data above
            # NOTE(review): the loaders are not used below; the call is kept
            # in case create_toy has side effects the learner relies on.
            train_loader, test_loader = create_toy(
                train_dataset, test_dataset, batch_size, toy_size
            )

            model = _build_child_network(
                IsLeNet, img_height, img_width, num_labels, img_channels
            )

            # use an aa_learner; in this case, a random-search learner
            agent = aal.randomsearch_learner(
                batch_size=batch_size,
                toy_flag=True,
                learning_rate=learning_rate,
                toy_size=toy_size,
                max_epochs=max_epochs,
                early_stop_num=early_stop_num,
            )
            agent.learn(
                train_dataset,
                test_dataset,
                child_network_architecture=model,
                iterations=iterations,
            )
        elif auto_aug_learner == 'Genetic Learner':
            pass

        if q_values is not None:
            plt.figure()
            plt.plot(q_values)

        # NOTE(review): app config is shared by all requests, so concurrent
        # users overwrite each other's settings here.
        current_app.config.update({
            'AAL': auto_aug_learner,
            'NP': num_policies,
            'NSP': num_sub_policies,
            'BS': batch_size,
            'LR': learning_rate,
            'TS': toy_size,
            'ME': max_epochs,
            'ESN': early_stop_num,
            'IT': iterations,
            'ISLENET': IsLeNet,
            'DSN': ds_name,
            'NUMFUN': num_funcs,
            'ds': ds,
            'exc_meth': exclude_method,
        })

    return render_template(
        "training.html",
        exclude_method=exclude_method,
        auto_aug_learner=auto_aug_learner,
    )