Skip to content
Snippets Groups Projects
wapp_util.py 6.83 KiB
Newer Older
  • Learn to ignore specific revisions
  • """
    Contains the functions that the webapp can use to interact with
    the MetaAugment library.
    """
    
    import numpy as np
    import torch
    import torchvision
    import torchvision.datasets as datasets
    
    # Import the augmentation learners (agents), controller networks and child networks.
    import MetaAugment.autoaugment_learners as aal
    import MetaAugment.controller_networks as cont_n
    import MetaAugment.child_networks as cn
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    from MetaAugment.main import create_toy
    
    import pickle
    
    
    def _load_datasets(ds, ds_name, download=True):
        """Load the train/test datasets chosen in the webapp.

        Built-in datasets (``MNIST``, ``KMNIST``, ``FashionMNIST``, ``CIFAR10``,
        ``CIFAR100``) live under ``./MetaAugment/datasets/<lower-cased name>/``.
        ``'Other'`` reads an ImageFolder at ``./MetaAugment/datasets/<ds_name>``
        and splits it 80/20 into train/test.

        Returns:
            ``(train_dataset, test_dataset, num_labels)``.

        Raises:
            ValueError: if ``ds`` is not a supported dataset name.
        """
        if ds == 'Other':
            dataset = datasets.ImageFolder('./MetaAugment/datasets/' + ds_name)
            len_train = int(0.8 * len(dataset))
            train_dataset, test_dataset = torch.utils.data.random_split(
                dataset, [len_train, len(dataset) - len_train])
            # NOTE(review): unlike the built-in test splits, this test split gets no
            # ToTensor transform (both splits share one ImageFolder) — confirm intended.
            return train_dataset, test_dataset, len(dataset.class_to_idx)

        builtin_names = ("MNIST", "KMNIST", "FashionMNIST", "CIFAR10", "CIFAR100")
        if ds not in builtin_names:
            raise ValueError(f"Unsupported dataset: {ds!r}")

        # The webapp's names are exactly the torchvision class names, and each
        # on-disk directory is the lower-cased name (e.g. 'FashionMNIST' -> 'fashionmnist').
        dataset_cls = getattr(datasets, ds)
        root = './MetaAugment/datasets/' + ds.lower()
        # The train split deliberately has no transform here; only the test split
        # is converted to tensors.
        train_dataset = dataset_cls(root=root + '/train', train=True, download=download)
        test_dataset = dataset_cls(root=root + '/test', train=False, download=download,
                                   transform=torchvision.transforms.ToTensor())

        num_labels = max(train_dataset.targets) - min(train_dataset.targets) + 1
        # Tensor targets yield a 0-d tensor here; plain-int targets are already an int.
        if isinstance(num_labels, torch.Tensor):
            num_labels = num_labels.item()
        return train_dataset, test_dataset, num_labels


    def _image_shape(dataset):
        """Return ``(channels, height, width)`` of the first image in ``dataset``.

        Datasets created without a ``transform`` yield PIL images, which cannot
        be indexed like tensors. Those are converted with ``ToTensor`` first, so
        the shape is always read from a (C, H, W) tensor.
        """
        image = dataset[0][0]
        if not isinstance(image, torch.Tensor):
            image = torchvision.transforms.ToTensor()(image)
        channels, height, width = image.shape
        return channels, height, width


    def _build_child_network(IsLeNet, img_height, img_width, num_labels, img_channels):
        """Instantiate the child network selected in the webapp.

        ``IsLeNet`` names a built-in architecture (``LeNet``, ``EasyNet`` or
        ``SimpleNet``). Any other value means a pickled network is loaded from
        ``datasets/childnetwork``.
        """
        architectures = {
            "LeNet": cn.LeNet,
            "EasyNet": cn.EasyNet,
            "SimpleNet": cn.SimpleNet,
        }
        if IsLeNet in architectures:
            return architectures[IsLeNet](img_height, img_width, num_labels, img_channels)
        # SECURITY: pickle.load can execute arbitrary code. If this file is uploaded
        # by a webapp user it is untrusted input — consider a safer serialization.
        # NOTE(review): this path is relative to the CWD, unlike the
        # './MetaAugment/datasets/...' dataset paths — confirm which is intended.
        with open('datasets/childnetwork', "rb") as f:
            return pickle.load(f)


    def parse_users_learner_spec(
            # aa_learner type
            auto_aug_learner,
            # search space settings
            ds,
            ds_name,
            exclude_method,
            num_funcs,
            num_policies,
            num_sub_policies,
            # child network settings
            toy_size,
            IsLeNet,
            batch_size,
            early_stop_num,
            iterations,
            learning_rate,
            max_epochs,
            ):
        """Build and run the aa_learner described by the webapp's user inputs.

        The website collects hyperparameters for the learner. This function
        builds the matching learner and runs its search/training as a side effect.

        Args:
            auto_aug_learner: one of ``'UCB'``, ``'Evolutionary Learner'``,
                ``'Random Searcher'`` or ``'Genetic Learner'``.
            ds: dataset name (``'MNIST'``, ``'KMNIST'``, ``'FashionMNIST'``,
                ``'CIFAR10'``, ``'CIFAR100'`` or ``'Other'``).
            ds_name: folder name of the user's dataset when ``ds == 'Other'``.
            exclude_method, num_funcs: search-space settings forwarded to the
                evolutionary learner.
            num_policies, num_sub_policies: policy counts for the UCB learner.
            toy_size, IsLeNet, batch_size, early_stop_num, iterations,
            learning_rate, max_epochs: child-network training settings.

        Returns:
            For ``'UCB'``: ``(q_values, best_q_values)``, where
            ``best_q_values`` is a numpy array.
            For ``'Evolutionary Learner'`` / ``'Random Searcher'``: the learner
            object after it has run.
            ``None`` for ``'Genetic Learner'`` (not implemented) and for
            unrecognised learner names.

        Raises:
            ValueError: for ``'Random Searcher'`` with an unsupported ``ds``.
        """
        if auto_aug_learner == 'UCB':
            policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
            q_values, best_q_values = aal.ucb_learner.run_UCB1(
                policies,
                batch_size,
                learning_rate,
                ds,
                toy_size,
                max_epochs,
                early_stop_num,
                iterations,
                IsLeNet,
                ds_name,
                )
            return q_values, np.array(best_q_values)

        if auto_aug_learner == 'Evolutionary Learner':
            network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
            # NOTE(review): always LeNet() here, regardless of IsLeNet — confirm intended.
            child_network = cn.LeNet()
            learner = aal.evo_learner(
                network=network,
                fun_num=num_funcs,
                p_bins=1,
                mag_bins=1,
                sub_num_pol=1,
                ds=ds,
                ds_name=ds_name,
                exclude_method=exclude_method,
                child_network=child_network,
                )
            learner.run_instance()
            return learner

        if auto_aug_learner == 'Random Searcher':
            # Unlike the UCB branch, the dataset and child network are resolved
            # here rather than inside the agent. That mirrors how a library user
            # (not the webapp) would call the agent from their own code.
            train_dataset, test_dataset, num_labels = _load_datasets(ds, ds_name)
            img_channels, img_height, img_width = _image_shape(train_dataset)
            # Create a toy dataset from the data loaded above.
            # NOTE(review): the returned loaders are unused here — confirm this
            # call is still needed.
            train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
            model = _build_child_network(IsLeNet, img_height, img_width, num_labels, img_channels)

            agent = aal.randomsearch_learner(
                batch_size=batch_size,
                toy_flag=True,
                learning_rate=learning_rate,
                toy_size=toy_size,
                max_epochs=max_epochs,
                early_stop_num=early_stop_num,
                )
            agent.learn(train_dataset,
                        test_dataset,
                        child_network_architecture=model,
                        iterations=iterations)
            return agent

        if auto_aug_learner == 'Genetic Learner':
            # TODO: the genetic learner is not implemented yet.
            return None

        # Unrecognised learner names are ignored, as before.
        return None