# test_ucb_learner.py
  • import MetaAugment.autoaugment_learners as aal
    
    import MetaAugment.child_networks as cn
    import torchvision
    import torchvision.datasets as datasets
    from pprint import pprint
    
        child_network_architecture = cn.SimpleNet
        train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                                train=True, download=True, transform=None)
        test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', 
                                train=False, download=True,
                                transform=torchvision.transforms.ToTensor())
    
    
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
        learner = aal.ucb_learner(
            # parameters that define the search space
                    sp_num=5,
                    p_bins=11,
                    m_bins=10,
                    discrete_p_m=True,
                    # hyperparameters for when training the child_network
                    batch_size=8,
    
                    toy_flag=True,
                    toy_size=0.001,
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
                    learning_rate=1e-1,
                    max_epochs=float('inf'),
                    early_stop_num=30,
                    # ucb_learner specific hyperparameter
    
                    num_policies=3
    
        pprint(learner.policies)
        assert len(learner.policies)==len(learner.avg_accs), \
                    (len(learner.policies), (len(learner.avg_accs)))
    
        # learn on the 3 policies we generated
        learner.learn(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            child_network_architecture=child_network_architecture,
            iterations=5
            )
        
        # let's say we want to explore more policies:
        # we generate more new policies
        learner.make_more_policies(n=4)
    
        # and let's explore how good those are as well
        learner.learn(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            child_network_architecture=child_network_architecture,
            iterations=7
            )
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    
    if __name__=="__main__":
    
        test_ucb_learner()