Skip to content
Snippets Groups Projects
ttest_gen.py 1.91 KiB
Newer Older
  • Learn to ignore specific revisions
  • Max Ramsay King's avatar
    Max Ramsay King committed
    # The train_dataset is initialized with its transform set to None.
    # Later on, this object's transform attribute is changed to the policy
    # that we want to test.
    import torchvision.datasets as datasets
    import torchvision
    
    
    Sun Jin Kim's avatar
    Sun Jin Kim committed
    import autoaug.child_networks as cn
    from autoaug.autoaugment_learners.AaLearner import AaLearner
    
    from autoaug.autoaugment_learners.GenLearner import Genetic_learner
    
    Max Ramsay King's avatar
    Max Ramsay King committed
    
    import random
    
    Max Ramsay King's avatar
    Max Ramsay King committed
        
    # Alternative dataset, kept for reference: swap these in for MNIST.
    # train_dataset = datasets.MNIST(root='./datasets/mnist/train',
    #                                 train=True, download=True, transform=None)
    # test_dataset = datasets.MNIST(root='./datasets/mnist/test',
    #                         train=False, download=True, transform=torchvision.transforms.ToTensor())

    # FashionMNIST training split: loaded with no transform so that a policy
    # transform can be attached to this object later.
    train_dataset = datasets.FashionMNIST(
        root='./datasets/fashionmnist/train',
        train=True,
        download=True,
        transform=None,
    )

    # FashionMNIST test split: images are converted to tensors.
    to_tensor = torchvision.transforms.ToTensor()
    test_dataset = datasets.FashionMNIST(
        root='./datasets/fashionmnist/test',
        train=False,
        download=True,
        transform=to_tensor,
    )

    # The callable `cn.lenet` itself is passed on, not an instance of it
    # (presumably the learner builds fresh networks from it - confirm in GenLearner).
    # child_network_architecture = cn.lenet()
    child_network_architecture = cn.lenet
    
    # Keyword arguments handed to Genetic_learner, gathered in one place.
    learner_config = dict(
        sp_num=2,
        toy_size=0.01,
        batch_size=4,
        learning_rate=0.05,
        # No epoch cap here; presumably early_stop_num bounds training
        # instead - confirm in GenLearner.
        max_epochs=float('inf'),
        early_stop_num=10,
        num_offspring=10,
    )
    agent = Genetic_learner(**learner_config)
    
    
    # Run the learner on the train/test splits with the chosen child network.
    search_iterations = 100
    agent.learn(
        train_dataset,
        test_dataset,
        child_network_architecture=child_network_architecture,
        iterations=search_iterations,
    )
    
    Max Ramsay King's avatar
    Max Ramsay King committed
    
    
    # `pickle` is not imported anywhere else in this script; without this
    # import the dump below raises NameError after the search has finished.
    import pickle

    # Persist the learner's raw history to disk.
    with open('genetic_logs.pkl', 'wb') as log_file:
        pickle.dump(agent.history, log_file)

    # History entries are (policy, accuracy) pairs; show them best-accuracy first.
    print(sorted(agent.history, key=lambda entry: entry[1], reverse=True))

    print("ACCURACIES IN TIME: ")

    # Then list every entry in the order it is stored in agent.history.
    for pol, acc in agent.history:
        print("pol: ", pol)
        print("acc: ", acc)