    # 04_22_evo.py
    import torch
    import torchvision
    import torchvision.datasets as datasets

    import pprint

    import autoaug.child_networks as cn
    import autoaug.autoaugment_learners as aal
    
    controller = cn.EasyNet(img_height=32, img_width=32, num_labels=16*10, img_channels=3)
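    # EasyNet here acts as the controller network for the evolutionary learner;
    # the num_labels=16*10 output size is presumably the dimensionality of the
    # policy vector it emits (an assumption, not verified against the autoaug
    # source).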
    
    config = {
            'sp_num': 5,
            'learning_rate': 1e-1,
            'batch_size': 32,
            'max_epochs': 100,
            'early_stop_num': 10,
            'controller': controller,
            'num_solutions': 10,
            }
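    # (Interpretation of the config keys is an assumption, not confirmed here:
    # 'sp_num' = subpolicies per policy, 'num_solutions' = evolutionary
    # population size, 'early_stop_num' = early-stopping patience while
    # training each child network.)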
    
    
    
    
    """
    Benchmarking autoaugment learners (EvoLearner in this script) on
    CIFAR-10 with LeNet.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


    def run_benchmark(
        save_file,
        train_dataset,
        test_dataset,
        child_network_architecture,
        agent_arch,
        config,
        total_iter=150,
        ):
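        """Run (or resume) a benchmark.

        Loads the agent saved at `save_file` if it exists, otherwise
        initializes a fresh one; tests new policies until the agent's
        history holds `total_iter` entries, re-saving the agent after
        every iteration so the run can be interrupted and resumed.
        """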
        try:
            # try to load agent
            with open(save_file, 'rb') as f:
                agent = torch.load(f, map_location=device)
        except FileNotFoundError:
            # if agent hasn't been saved yet, initialize the agent
            agent = agent_arch(**config)
    
    
        # if the history does not have total_iter entries yet (i.e. total_iter
        # different policies have not been tested yet), keep running
        print("agent history: ", agent.history)
        while len(agent.history) < total_iter:
            print(f'{len(agent.history)} / {total_iter}')
            # run 1 iteration (test one new policy and update the agent)
            agent.learn(
                        train_dataset=train_dataset,
                        test_dataset=test_dataset,
                        child_network_architecture=child_network_architecture,
                        iterations=5
                        )
            # save agent every iteration
            with open(save_file, 'wb+') as f:
                torch.save(agent, f)
    
        print('run_benchmark closing')
    
    
    def get_mega_policy(history, n):
            """
            we get the best n policies from an agent's history,
            concatenate them to form our best mega policy
    
            Args:
                history (list[tuple])
                n (int)
            
            Returns:
                list[float]: validation accuracies
            """
            assert len(history) >= n
    
            # agent.history is a list of (policy(list), val_accuracy(float)) tuples 
            sorted_history = sorted(history, key=lambda x:x[1], reverse=True) # sort wrt acc
    
            best_history = sorted_history[:n]
    
            megapolicy = []
            # we also want to keep track of how good the best policies were
            # maybe if we add them all up, they'll become worse! Hopefully better tho
            orig_accs = []
    
            for policy,acc in best_history:
                for subpolicy in policy:
                    megapolicy.append(subpolicy)
                orig_accs.append(acc)
            
            return megapolicy, orig_accs
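    # Illustrative example (hypothetical policies and accuracies): given
    #   history = [(pol_a, 0.91), (pol_b, 0.85), (pol_c, 0.88)]
    # get_mega_policy(history, 2) concatenates the subpolicies of pol_a and
    # pol_c (the two most accurate) and returns them with [0.91, 0.88].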
    
    
    def rerun_best_policy(
        agent_pickle,
        accs_txt,
        train_dataset,
        test_dataset,
        child_network_architecture,
        config,
        repeat_num
        ):
    
        with open(agent_pickle, 'rb') as f:
            agent = torch.load(f, map_location=device)

        megapol, orig_accs = get_mega_policy(agent.history, 3)
        print('mega policy to be tested:')
        pprint.pprint(megapol)
        print(orig_accs)
    
        accs = []
        for i in range(repeat_num):
            print(f'{i}/{repeat_num}')
            temp_agent = aal.AaLearner(**config)
            accs.append(
                    temp_agent._test_autoaugment_policy(megapol,
                                        child_network_architecture,
                                        train_dataset,
                                        test_dataset,
                                        logging=False)
                        )
            with open(accs_txt, 'w') as f:
                f.write(pprint.pformat(megapol))
                f.write('\n')
                f.write(str(accs))
                f.write('\n')
                f.write(f"original small policies' accuracies: {orig_accs}\n")
    
    
    
    
    # CIFAR10 with LeNet
    # (both splits use the same root; torchvision downloads the archive there
    # once and selects the split via the `train` flag)
    train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
                            train=True, download=True, transform=None)
    test_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
                            train=False, download=True,
                            transform=torchvision.transforms.ToTensor())
    child_network_architecture = cn.LeNet(
                                        img_height=32,
                                        img_width=32,
                                        num_labels=10,
                                        img_channels=3
                                        )
    
    # save_dir='./benchmark/pickles/04_22_cf_ln_rssad'
    
    # # evo
    # run_benchmark(
    #     save_file=save_dir+'.pkl',
    #     train_dataset=train_dataset,
    #     test_dataset=test_dataset,
    #     child_network_architecture=child_network_architecture,
    #     agent_arch=aal.EvoLearner,
    #     config=config,
    #     )
    
    # # rerun_best_policy(
    # #     agent_pickle=save_dir+'.pkl',
    # #     accs_txt=save_dir+'.txt',
    # #     train_dataset=train_dataset,
    # #     test_dataset=test_dataset,
    # #     child_network_architecture=child_network_architecture,
    # #     config=config,
    # #     repeat_num=5
    # #     )
    
    
    
    megapol = [
        (('ShearY', 0.5, 5), ('Posterize', 0.6, 5)),
        (('Color', 1.0, 9), ('Contrast', 1.0, 9)),
        (('TranslateX', 0.5, 5), ('Posterize', 0.5, 5)),
        (('TranslateX', 0.5, 5), ('Posterize', 0.5, 5)),
        (('Color', 0.5, 5), ('Posterize', 0.5, 5)),
        ]
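    # Each subpolicy is a pair of (operation, probability, magnitude) triples,
    # which appears to follow the AutoAugment convention (probability in
    # [0, 1], magnitude on a discrete 0-9 scale).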
    
    
    accs = []
    for i in range(10):
        print(f'{i}/{10}')
        temp_agent = aal.EvoLearner(**config)
        accs.append(
                temp_agent._test_autoaugment_policy(megapol,
                                    child_network_architecture,
                                    train_dataset,
                                    test_dataset,
                                    logging=False)
                    )
    
    print("CIPHAR10 accs: ", accs)