Skip to content
Snippets Groups Projects
wapp_util.py 6.83 KiB
"""
CONTAINS THE FUNTIONS THAT THE WEBAPP CAN USE TO INTERACT WITH
THE LIBRARY
"""

import numpy as np
import torch
import torchvision
import torchvision.datasets as datasets

# # import agents and its functions
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
from MetaAugment.main import create_toy

import pickle

def parse_users_learner_spec(
            # aalearner type
            auto_aug_learner, 
            # search space settings
            ds, 
            ds_name, 
            exclude_method, 
            num_funcs, 
            num_policies, 
            num_sub_policies, 
            # child network settings
            toy_size, 
            IsLeNet, 
            batch_size, 
            early_stop_num, 
            iterations, 
            learning_rate, 
            max_epochs
            ):
    """
    The website receives user inputs on what they want the aa_learner
    to be. We take those hyperparameters and return an aa_learner

    """
    if auto_aug_learner == 'UCB':
        policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
        q_values, best_q_values = aal.ucb_learner.run_UCB1(
                                                policies,
                                                batch_size, 
                                                learning_rate, 
                                                ds, 
                                                toy_size, 
                                                max_epochs, 
                                                early_stop_num, 
                                                iterations, 
                                                IsLeNet, 
                                                ds_name
                                                )     
        best_q_values = np.array(best_q_values)
    elif auto_aug_learner == 'Evolutionary Learner':
        network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
        child_network = cn.LeNet()
        learner = aal.evo_learner(
                                network=network, 
                                fun_num=num_funcs, 
                                p_bins=1, 
                                mag_bins=1, 
                                sub_num_pol=1, 
                                ds = ds, 
                                ds_name=ds_name, 
                                exclude_method=exclude_method, 
                                child_network=child_network
                                )
        learner.run_instance()
    elif auto_aug_learner == 'Random Searcher':
            # As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent
            # This system makes more sense for the user who is not using the webapp and is instead
            # using the library within their code
        download = True
        if ds == "MNIST":
            train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
            test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
                                                download=download, transform=torchvision.transforms.ToTensor())
        elif ds == "KMNIST":
            train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
            test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
                                                download=download, transform=torchvision.transforms.ToTensor())
        elif ds == "FashionMNIST":
            train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
            test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
                                                download=download, transform=torchvision.transforms.ToTensor())
        elif ds == "CIFAR10":
            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
                                                download=download, transform=torchvision.transforms.ToTensor())
        elif ds == "CIFAR100":
            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
                                                download=download, transform=torchvision.transforms.ToTensor())
        elif ds == 'Other':
            dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
            len_train = int(0.8*len(dataset))
            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])

            # check sizes of images
        img_height = len(train_dataset[0][0][0])
        img_width = len(train_dataset[0][0][0][0])
        img_channels = len(train_dataset[0][0])
            # check output labels
        if ds == 'Other':
            num_labels = len(dataset.class_to_idx)
        elif ds == "CIFAR10" or ds == "CIFAR100":
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
        else:
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
            # create toy dataset from above uploaded data
        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
            # create model
        if IsLeNet == "LeNet":
            model = cn.LeNet(img_height, img_width, num_labels, img_channels)
        elif IsLeNet == "EasyNet":
            model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
        elif IsLeNet == 'SimpleNet':
            model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
        else:
            model = pickle.load(open(f'datasets/childnetwork', "rb"))

            # use an aa_learner. in this case, a rs learner
        agent = aal.randomsearch_learner(batch_size=batch_size,
                                            toy_flag=True,
                                            learning_rate=learning_rate,
                                            toy_size=toy_size,
                                            max_epochs=max_epochs,
                                            early_stop_num=early_stop_num,
                                            )
        agent.learn(train_dataset,
                        test_dataset,
                        child_network_architecture=model,
                        iterations=iterations)
    elif auto_aug_learner == 'Genetic Learner':
        pass