Skip to content
Snippets Groups Projects
progress.py 11.1 KiB
Newer Older
  • Learn to ignore specific revisions
import os
import pickle
import subprocess
import zipfile

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
from flask import Blueprint, request, render_template, flash, send_file, current_app, g, session
from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange

torch.manual_seed(0)

# import agents and its functions
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
from MetaAugment import Evo_learner as Evo
from MetaAugment.autoaugment_learners import ucb_learner
from MetaAugment.main import create_toy
    
    Mia Wang's avatar
    Mia Wang committed
# Blueprint serving the training/progress pages of the web app.
bp = Blueprint("progress", __name__)
    
    
    Mia Wang's avatar
    Mia Wang committed
    @bp.route("/user_input", methods=["GET", "POST"])
    def response():
    
    
        # hyperparameters to change
    
        if request.method == 'POST':
    
            # generate random policies at start
            auto_aug_learner = request.form.get("auto_aug_selection")
            
            # search space & problem setting
            ds = request.form.get("dataset_selection")      # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
            ds_up = request.files['dataset_upload']
    
            exclude_method = request.form.getlist("action_space")
    
            num_funcs = 14 - len(exclude_method)
    
            num_policies = 5      # fix number of policies
            num_sub_policies = 5  # fix number of sub-policies in a policy
            toy_size = 1      # total propeortion of training and test set we use
    
            # child network
            IsLeNet = request.form.get("network_selection")   # using LeNet or EasyNet or SimpleNet ->> default 
    
            nw_up = childnetwork = request.files['network_upload']
    
    
            # child network training hyperparameters
            batch_size = 1       # size of batch the inner NN is trained with
    
            early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
    
            iterations = 5      # total iterations, should be more than the number of policies
    
            learning_rate = 1e-1  # fix learning rate
            max_epochs = 10      # max number of epochs that is run if early stopping is not hit
    
            # if user upload datasets and networks, save them in the database
    
    
            if ds == None and ds_up != None:
                ds = 'Other'
    
                ds_folder = request.files['dataset_upload']
                ds_name_zip = ds_folder.filename
    
                ds_folder.save('./MetaAugment/datasets/'+ ds_name_zip)
                with zipfile.ZipFile('./MetaAugment/datasets/'+ ds_name_zip, 'r') as zip_ref:
    
                    zip_ref.extractall('./MetaAugment/datasets/upload_dataset/')
    
                if not current_app.debug:
                    os.remove(f'./MetaAugment/datasets/{ds_name_zip}')
    
    
            for (dirpath, dirnames, filenames) in os.walk(f'./MetaAugment/datasets/upload_dataset/{ds_name}/'):
                for dirname in dirnames:
                    if dirname[0:6] != 'class_':
                        return render_template("fail_dataset.html")
                    else:
                        pass
    
    
    
            if IsLeNet == None and nw_up != None:
    
                childnetwork = request.files['network_upload']
                childnetwork.save('./MetaAugment/child_networks/'+childnetwork.filename)
            
    
    
            if auto_aug_learner == 'UCB':
                policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                q_values, best_q_values = ucb_learner.run_UCB1(
                                                            policies, 
                                                            batch_size, 
                                                            learning_rate, 
                                                            ds, 
                                                            toy_size, 
                                                            max_epochs, 
                                                            early_stop_num, 
                                                            iterations, 
                                                            IsLeNet, 
                                                            ds_name
                                                            )    
                best_q_values = np.array(best_q_values)
    
            elif auto_aug_learner == 'Evolutionary Learner':
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                learner = Evo.Evolutionary_learner(
                                                fun_num=num_funcs, 
                                                p_bins=1, 
                                                mag_bins=1, 
                                                sub_num_pol=1, 
                                                ds_name=ds_name, 
                                                exclude_method=exclude_method
                                                )
    
                learner.run_instance()
    
            elif auto_aug_learner == 'Random Searcher':
    
                # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
                # This system makes more sense for the user who is not using the webapp and is instead
                # using the library within their code
                download = True
                if ds == "MNIST":
                    train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
                    test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
                                                    download=download, transform=torchvision.transforms.ToTensor())
                elif ds == "KMNIST":
                    train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
                    test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
                                                    download=download, transform=torchvision.transforms.ToTensor())
                elif ds == "FashionMNIST":
                    train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
                    test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
                                                    download=download, transform=torchvision.transforms.ToTensor())
                elif ds == "CIFAR10":
    
                    train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
                    test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
    
                                                    download=download, transform=torchvision.transforms.ToTensor())
                elif ds == "CIFAR100":
    
                    train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
                    test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
    
                                                    download=download, transform=torchvision.transforms.ToTensor())
                elif ds == 'Other':
                    dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
                    len_train = int(0.8*len(dataset))
                    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
    
                # check sizes of images
                img_height = len(train_dataset[0][0][0])
                img_width = len(train_dataset[0][0][0][0])
                img_channels = len(train_dataset[0][0])
                # check output labels
                if ds == 'Other':
                    num_labels = len(dataset.class_to_idx)
                elif ds == "CIFAR10" or ds == "CIFAR100":
                    num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
                else:
                    num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
                # create toy dataset from above uploaded data
                train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
                # create model
                if IsLeNet == "LeNet":
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                    model = cn.LeNet(img_height, img_width, num_labels, img_channels)
    
                elif IsLeNet == "EasyNet":
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                    model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
    
                elif IsLeNet == 'SimpleNet':
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                    model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
    
                else:
                    model = pickle.load(open(f'datasets/childnetwork', "rb"))
    
                # use an aa_learner. in this case, a rs learner
                agent = aal.randomsearch_learner(batch_size=batch_size,
                                                toy_flag=True,
                                                learning_rate=learning_rate,
                                                toy_size=toy_size,
                                                max_epochs=max_epochs,
                                                early_stop_num=early_stop_num,
                                                )
                agent.learn(train_dataset,
                            test_dataset,
                            child_network_architecture=model,
                            iterations=iterations)
    
            elif auto_aug_learner == 'Genetic Learner':
    
    
            plt.figure()
            plt.plot(q_values)
    
            # if auto_aug_learner == 'UCB':
    
            #     policies = ucb_learner.generate_policies(num_policies, num_sub_policies)
            #     q_values, best_q_values = ucb_learner.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name)     
    
            #     # plt.figure()
            #     # plt.plot(q_values)
            #     best_q_values = np.array(best_q_values)
    
            # elif auto_aug_learner == 'Evolutionary Learner':
            #     network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
            #     child_network = Evo.LeNet()
            #     learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network)
            #     learner.run_instance()
            # elif auto_aug_learner == 'Random Searcher':
            #     pass 
            # elif auto_aug_learner == 'Genetic Learner':
            #     pass
    
        current_app.config['AAL'] = auto_aug_learner
        current_app.config['NP'] = num_policies
        current_app.config['NSP'] = num_sub_policies
        current_app.config['BS'] = batch_size
        current_app.config['LR'] = learning_rate
        current_app.config['TS'] = toy_size
        current_app.config['ME'] = max_epochs
        current_app.config['ESN'] = early_stop_num
        current_app.config['IT'] = iterations
        current_app.config['ISLENET'] = IsLeNet
        current_app.config['DSN'] = ds_name
        current_app.config['NUMFUN'] = num_funcs
        current_app.config['ds'] = ds
        current_app.config['exc_meth'] = exclude_method
    
        # return render_template("progress.html", exclude_method = exclude_method, auto_aug_learner=auto_aug_learner)
        return render_template("training.html", exclude_method = exclude_method, auto_aug_learner=auto_aug_learner)