# 04_22_evo.py
import torchvision.datasets as datasets
import torchvision
import torch

import autoaug.child_networks as cn
import autoaug.autoaugment_learners as aal


# Controller network handed to the learner through ``config['controller']``.
# NOTE(review): num_labels=16*10 is a hard-coded output size whose meaning is
# defined by the learner that consumes this controller — confirm there.
controller = cn.EasyNet(
    img_height=32,
    img_width=32,
    num_labels=16 * 10,
    img_channels=3,
)

# Keyword arguments used to construct the learner (``agent_arch(**config)``).
config = dict(
    sp_num=5,
    learning_rate=1e-1,
    batch_size=32,
    max_epochs=100,
    early_stop_num=10,
    controller=controller,
    num_solutions=10,
)




import torch

import autoaug.autoaugment_learners as aal

import pprint

"""
testing EvoLearner on

  cifar10 with lenet

"""
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')






def run_benchmark(
    save_file,
    train_dataset,
    test_dataset,
    child_network_architecture,
    agent_arch,
    config,
    total_iter=150,
    ):
    """
    Train (or resume training) an AutoAugment agent until its history holds
    ``total_iter`` entries, checkpointing the agent after every learn call.

    Args:
        save_file (str): path of the ``torch.save``d agent. If it exists the
            agent is resumed from it, otherwise a fresh agent is built.
        train_dataset: passed through to ``agent.learn``.
        test_dataset: passed through to ``agent.learn``.
        child_network_architecture: passed through to ``agent.learn``.
        agent_arch (callable): agent class/factory, called as
            ``agent_arch(**config)`` when no checkpoint exists yet.
        config (dict): keyword arguments for ``agent_arch``.
        total_iter (int): target length of ``agent.history``.
    """
    import os

    try:
        # resume from an existing checkpoint if there is one
        with open(save_file, 'rb') as f:
            # NOTE(review): on torch >= 2.6 ``torch.load`` defaults to
            # ``weights_only=True`` and may refuse a fully pickled agent;
            # pass ``weights_only=False`` there (only for trusted files).
            agent = torch.load(f, map_location=device)
    except FileNotFoundError:
        # no checkpoint yet: build a fresh agent
        agent = agent_arch(**config)

    # make sure the checkpoint directory exists before the first save
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    tmp_file = save_file + '.tmp'

    print("agent history: ", agent.history)
    # keep learning until total_iter entries have been recorded in the history
    while len(agent.history) < total_iter:
        print(f'{len(agent.history)} / {total_iter}')
        agent.learn(
                    train_dataset=train_dataset,
                    test_dataset=test_dataset,
                    child_network_architecture=child_network_architecture,
                    iterations=5
                    )
        # checkpoint after every learn() call: write to a temp file first and
        # then swap it in with os.replace (atomic on POSIX), so an interrupted
        # save cannot corrupt the previous checkpoint
        with open(tmp_file, 'wb') as f:
            torch.save(agent, f)
        os.replace(tmp_file, save_file)

    print('run_benchmark closing')


def get_mega_policy(history, n):
    """
    Concatenate the ``n`` best policies in an agent's history into one
    "mega policy".

    Args:
        history (list[tuple]): ``(policy, val_accuracy)`` pairs, where each
            policy is an iterable of subpolicies.
        n (int): how many of the highest-accuracy policies to combine.

    Returns:
        tuple[list, list[float]]: ``(megapolicy, orig_accs)`` — the
        subpolicies of the ``n`` best policies concatenated in descending
        accuracy order, and those policies' original validation accuracies
        in the same order.

    Raises:
        ValueError: if ``n`` is negative or larger than ``len(history)``.
    """
    if not 0 <= n <= len(history):
        raise ValueError(
            f'n must be between 0 and len(history)={len(history)}, got {n}')

    # best accuracy first; sorted() stays stable with reverse=True, so ties
    # keep their original history order
    best_history = sorted(history, key=lambda entry: entry[1], reverse=True)[:n]

    megapolicy = []
    # keep the individual accuracies too: combining the best policies may
    # turn out better or worse than any one of them on its own
    orig_accs = []
    for policy, acc in best_history:
        megapolicy.extend(policy)
        orig_accs.append(acc)

    return megapolicy, orig_accs


def rerun_best_policy(
    agent_pickle,
    accs_txt,
    train_dataset,
    test_dataset,
    child_network_architecture,
    config,
    repeat_num,
    num_policies=3,
    ):
    """
    Re-evaluate the mega policy built from a saved agent's best policies.

    Loads the agent stored at ``agent_pickle``, concatenates its
    ``num_policies`` best policies with ``get_mega_policy`` and tests that
    mega policy ``repeat_num`` times, each time on a fresh
    ``aal.AaLearner(**config)``. After every run the results so far are
    rewritten to ``accs_txt``, so partial results survive an interruption.

    Args:
        agent_pickle (str): path of a ``torch.save``d agent with a
            ``history`` attribute.
        accs_txt (str): path of the text file the results are written to.
        train_dataset: passed to ``_test_autoaugment_policy``.
        test_dataset: passed to ``_test_autoaugment_policy``.
        child_network_architecture: passed to ``_test_autoaugment_policy``.
        config (dict): keyword arguments for ``aal.AaLearner``.
        repeat_num (int): how many times to test the mega policy.
        num_policies (int): how many of the best policies to combine
            (default 3, the value that used to be hard-coded).
    """
    with open(agent_pickle, 'rb') as f:
        # map_location lets a checkpoint saved on a GPU machine load on a
        # CPU-only one (matches how run_benchmark loads checkpoints)
        agent = torch.load(f, map_location=device)

    megapol, orig_accs = get_mega_policy(agent.history, num_policies)
    print('mega policy to be tested:')
    pprint.pprint(megapol)
    print(orig_accs)

    accs = []
    for run_idx in range(repeat_num):
        print(f'{run_idx}/{repeat_num}')
        # NOTE(review): config is forwarded unchanged to AaLearner; confirm it
        # accepts every key present in config.
        temp_agent = aal.AaLearner(**config)
        accs.append(
                temp_agent._test_autoaugment_policy(megapol,
                                    child_network_architecture,
                                    train_dataset,
                                    test_dataset,
                                    logging=False)
                    )
        # rewrite the results file after every run; the newlines keep the
        # three sections from running together on one line
        with open(accs_txt, 'w') as f:
            f.write(pprint.pformat(megapol) + '\n')
            f.write(str(accs) + '\n')
            f.write(f'original small policys accuracies: {orig_accs}\n')




# CIFAR10 with LeNet
# Both splits share one download root. The train split gets no transform here.
# NOTE(review): presumably the learner applies its own augmentation/tensor
# conversion to the train split — confirm.
train_dataset = datasets.CIFAR10(
    root='./datasets/cifar10/train',
    train=True,
    download=True,
    transform=None,
)
test_dataset = datasets.CIFAR10(
    root='./datasets/cifar10/train',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
child_network_architecture = cn.LeNet(
    img_height=32, img_width=32, num_labels=10, img_channels=3,
)

# save_dir is a path prefix: '.pkl' is appended for the agent checkpoint.
save_dir = './benchmark/pickles/04_22_cf_ln_rssad'
run_benchmark(
    save_file=f'{save_dir}.pkl',
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    child_network_architecture=child_network_architecture,
    agent_arch=aal.EvoLearner,
    config=config,
)

# # rerun_best_policy(
# #     agent_pickle=save_dir+'.pkl',
# #     accs_txt=save_dir+'.txt',
# #     train_dataset=train_dataset,
# #     test_dataset=test_dataset,
# #     child_network_architecture=child_network_architecture,
# #     config=config,
# #     repeat_num=5
# #     )



# Hard-coded mega policy: a list of sub-policies, each a pair of
# (name, float, int) tuples — presumably (operation, probability, magnitude);
# TODO confirm against the learner. Note that the 3rd and 4th entries are
# identical.
megapol = [
    (('ShearY', 0.5, 5), ('Posterize', 0.6, 5)),
    (('Color', 1.0, 9), ('Contrast', 1.0, 9)),
    (('TranslateX', 0.5, 5), ('Posterize', 0.5, 5)),
    (('TranslateX', 0.5, 5), ('Posterize', 0.5, 5)),
    (('Color', 0.5, 5), ('Posterize', 0.5, 5)),
]


# accs=[]
# for _ in range(10):
#     print(f'{_}/{10}')
#     temp_agent = aal.EvoLearner(**config)
#     accs.append(
#             temp_agent._test_autoaugment_policy(megapol,
#                                 child_network_architecture,
#                                 train_dataset,
#                                 test_dataset,
#                                 logging=False)
#                 )
# print("CIPHAR10 accs: ", accs)