Skip to content
Snippets Groups Projects
test_aa_learner.py 5.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • Sun Jin Kim's avatar
    Sun Jin Kim committed
    import MetaAugment.autoaugment_learners as aal
    import MetaAugment.child_networks as cn
    import torch
    import torchvision
    import torchvision.datasets as datasets
    
    import random
    
    
    
    def test__translate_operation_tensor():
        """
        Smoke-test aa_learner's _translate_operation_tensor method.

        Many valid operation tensors are generated and fed through
        _translate_operation_tensor; the test passes as long as no call
        raises. Return values are not inspected.

        Two configurations are exercised, 2000 iterations each:

        * discrete_p_m=True: a (fun_num+p_bins+m_bins,) tensor whose function,
          probability and magnitude segments are each softmaxed separately.
        * discrete_p_m=False: a (fun_num+2,) tensor whose function segment is
          softmaxed, whose probability entry goes through a sigmoid, and whose
          magnitude entry is a sigmoid scaled by (m_bins-1).
        """
        # Size of the image-function segment of the operation tensor. It must
        # be bound before the first loop: both loops read it, and binding it
        # only inside the second loop makes the first loop raise
        # UnboundLocalError.
        # NOTE(review): 14 presumably matches aa_learner's default set of
        # image functions -- confirm against the learner.
        fun_num = 14

        # Stateless modules, so one instance of each is shared by all iterations.
        softmax = torch.nn.Softmax(dim=0)
        sigmoid = torch.nn.Sigmoid()

        # discrete_p_m=True
        for i in range(2000):
            p_bins = random.randint(2, 15)
            m_bins = random.randint(2, 15)

            agent = aal.aa_learner(
                    sp_num=5,
                    p_bins=p_bins,
                    m_bins=m_bins,
                    discrete_p_m=True
                    )

            # alpha grows from 0 to just under 2, so the softmax inputs range
            # from all-equal (uniform output) to progressively more spread out.
            alpha = i / 1000
            vector = torch.rand(fun_num + p_bins + m_bins)
            fun_t, prob_t, mag_t = vector.split([fun_num, p_bins, m_bins])
            fun_t = softmax(fun_t * alpha)
            prob_t = softmax(prob_t * alpha)
            mag_t = softmax(mag_t * alpha)
            softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))

            agent._translate_operation_tensor(softmaxed_vector)

        # discrete_p_m=False
        for i in range(2000):
            p_bins = random.randint(1, 15)
            m_bins = random.randint(1, 15)

            agent = aal.aa_learner(
                    sp_num=5,
                    p_bins=p_bins,
                    m_bins=m_bins,
                    discrete_p_m=False
                    )

            alpha = i / 1000
            # One scalar each for probability and magnitude in this mode.
            vector = torch.rand(fun_num + 2)
            fun_t, prob_t, mag_t = vector.split([fun_num, 1, 1])
            fun_t = softmax(fun_t * alpha)
            prob_t = sigmoid(prob_t)
            # Scale the sigmoid output from (0, 1) to (0, m_bins-1).
            mag_t = sigmoid(mag_t) * (m_bins - 1)

            softmaxed_vector = torch.cat((fun_t, prob_t, mag_t))

            agent._translate_operation_tensor(softmaxed_vector)
    
    def test__test_autoaugment_policy():
        """
        Check that aa_learner._test_autoaugment_policy returns a float.

        A fixed ten-subpolicy policy is evaluated with the SimpleNet child
        network on FashionMNIST (downloaded on demand), with logging disabled.
        """
        # Each sub-policy is a pair of (function name, probability, magnitude)
        # operations; magnitude is None for some operations.
        fixed_policy = [
                (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
                (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
                (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
                (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
                (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5)),
                ]

        learner = aal.aa_learner(
            sp_num=5,
            p_bins=11,
            m_bins=10,
            discrete_p_m=True,
            toy_size=0.002,
            max_epochs=20,
            early_stop_num=10,
        )

        network_cls = cn.SimpleNet

        # The training set is loaded without a transform; the test set is
        # converted to tensors.
        train_set = datasets.FashionMNIST(
            root='./datasets/fashionmnist/train',
            train=True,
            download=True,
            transform=None,
        )
        to_tensor = torchvision.transforms.ToTensor()
        test_set = datasets.FashionMNIST(
            root='./datasets/fashionmnist/test',
            train=False,
            download=True,
            transform=to_tensor,
        )

        accuracy = learner._test_autoaugment_policy(
            fixed_policy,
            network_cls,
            train_set,
            test_set,
            logging=False,
        )

        assert isinstance(accuracy, float)
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    
    
    def test_exclude_method():
        """
        Verify that the exclude_method parameter of the learners is honoured.

        For gru_learner and randomsearch_learner, 200 policies are generated
        each, and no operation in any sub-policy may use an excluded image
        function.
        """
        exclude_method = ['ShearX', 'Color', 'Brightness', 'Contrast']

        def _check_policy(candidate):
            # Print the policy, then check both operations of every sub-policy;
            # element 0 of an operation is its image-function name.
            print(candidate)
            for first_op, second_op in candidate:
                assert first_op[0] not in exclude_method
                assert second_op[0] not in exclude_method

        # gru_learner: _generate_new_policy returns a pair; only the policy
        # (first element) is checked.
        gru_agent = aal.gru_learner(exclude_method=exclude_method)
        for _ in range(200):
            generated, _ = gru_agent._generate_new_policy()
            _check_policy(generated)

        # randomsearch_learner: _generate_new_policy returns the policy itself.
        rs_agent = aal.randomsearch_learner(exclude_method=exclude_method)
        for _ in range(200):
            generated = rs_agent._generate_new_policy()
            _check_policy(generated)
    
    Max Ramsay King's avatar
    Max Ramsay King committed
        
    
    def test_get_mega_policy():
        """
        Run randomsearch_learner.learn and then call get_mega_policy.

        The learner is trained with SimpleNet on FashionMNIST (downloaded on
        demand), after which get_mega_policy is called with number_policies of
        30, 3 and 1. The result of the last call is printed. The test passes
        if nothing raises.
        """
        learner = aal.randomsearch_learner(
            sp_num=5,
            p_bins=11,
            m_bins=10,
            discrete_p_m=True,
            toy_size=0.002,
            max_epochs=20,
            early_stop_num=10,
        )

        network_cls = cn.SimpleNet
        train_set = datasets.FashionMNIST(
            root='./datasets/fashionmnist/train',
            train=True,
            download=True,
            transform=None,
        )
        test_set = datasets.FashionMNIST(
            root='./datasets/fashionmnist/test',
            train=False,
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )

        learner.learn(train_set, test_set, network_cls, 10)

        # Request progressively fewer policies; only the final result is kept.
        for requested in (30, 3, 1):
            mega_pol = learner.get_mega_policy(number_policies=requested)

        print("megapol: ", mega_pol)
    
    
    # When executed directly as a script, run only the mega-policy test.
    if "__main__" == __name__:
        test_get_mega_policy()