Skip to content
Snippets Groups Projects
test_GruLearner.py 1.73 KiB
Newer Older
  • Learn to ignore specific revisions
  • import autoaug.autoaugment_learners as aal
    import autoaug.child_networks as cn
    
    import torch
    import torchvision
    import torchvision.datasets as datasets
    
    import random
    
    
    def test__generate_new_policy():
        """
        Make sure GruLearner._generate_new_policy() is robust.

        Runs 40 trials. Each trial draws random values for sp_num, p_bins
        and m_bins, builds a GruLearner from them (with cont_mb_size=2),
        generates four policies, and asserts that the first element of each
        generated policy is a list.

        NOTE(review): the original description also listed fun_num among the
        varied values, but no fun_num argument is passed here -- confirm
        whether it should be.
        """
        for _ in range(40):
            sp_num = random.randint(1, 20)
            p_bins = random.randint(2, 15)
            m_bins = random.randint(2, 15)

            agent = aal.GruLearner(
                sp_num=sp_num,
                p_bins=p_bins,
                m_bins=m_bins,
                cont_mb_size=2,
            )

            # Sample several policies from the same learner instance.
            for _ in range(4):
                new_policy = agent._generate_new_policy()
                # The policy itself is used as the failure message so a bad
                # sample is visible in the test report.
                assert isinstance(new_policy[0], list), new_policy
    
    
    def test_learn():
        """
        Exercise the GruLearner.learn() method end to end.

        Builds the FashionMNIST train and test datasets (download=True; the
        train split gets no transform, the test split gets ToTensor), creates
        a GruLearner, and calls learn() for two iterations with cn.lenet as
        the child network architecture. There are no assertions: the test
        passes if nothing raises.
        """
        train_dataset = datasets.FashionMNIST(
            root='./datasets/fashionmnist/train',
            train=True,
            download=True,
            transform=None,
        )
        to_tensor = torchvision.transforms.ToTensor()
        test_dataset = datasets.FashionMNIST(
            root='./datasets/fashionmnist/test',
            train=False,
            download=True,
            transform=to_tensor,
        )

        # cn.lenet is passed as-is, not called here.
        architecture = cn.lenet

        learner_settings = {
            'sp_num': 7,
            'toy_size': 0.001,
            'batch_size': 32,
            'learning_rate': 0.05,
            'max_epochs': 100,
            'early_stop_num': 10,
        }
        agent = aal.GruLearner(**learner_settings)

        agent.learn(
            train_dataset,
            test_dataset,
            child_network_architecture=architecture,
            iterations=2,
        )