import autoaug.autoaugment_learners as aal
import autoaug.child_networks as cn
import torchvision
import torchvision.datasets as datasets
from pprint import pprint

def test_ucb_learner():
    """Smoke-test ``aal.UcbLearner`` end to end on FashionMNIST.

    Steps:
      1. Build a learner with 3 initial policies and check that
         ``policies`` and ``avg_accs`` have the same length.
      2. Run ``learn`` for 5 iterations.
      3. Add 4 more policies with ``make_more_policies`` and run ``learn``
         for 7 more iterations.
      4. Print the results of ``get_mega_policy`` for 50 and for 3 policies.

    FashionMNIST is downloaded into ``./datasets/fashionmnist/train`` and
    ``./datasets/fashionmnist/test``. Those paths are relative to the
    current working directory.
    """
    child_net_cls = cn.SimpleNet

    # The training split gets no transform here. Presumably the learner
    # attaches its own augmentation and tensor conversion; TODO confirm.
    train_data = datasets.FashionMNIST(
        root='./datasets/fashionmnist/train',
        train=True,
        download=True,
        transform=None,
    )
    test_data = datasets.FashionMNIST(
        root='./datasets/fashionmnist/test',
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Parameters that define the augmentation search space.
    search_space = dict(sp_num=5, p_bins=11, m_bins=10, discrete_p_m=True)
    # Hyperparameters used when training each child network.
    # A tiny toy_size keeps the run short despite max_epochs=inf;
    # early_stop_num presumably bounds training (TODO confirm).
    child_training = dict(
        batch_size=8,
        toy_size=0.001,
        learning_rate=1e-1,
        max_epochs=float('inf'),
        early_stop_num=30,
    )
    learner = aal.UcbLearner(
        **search_space,
        **child_training,
        num_policies=3,  # UcbLearner-specific: number of initial policies
    )

    pprint(learner.policies)
    n_policies, n_accs = len(learner.policies), len(learner.avg_accs)
    assert n_policies == n_accs, (n_policies, n_accs)

    # Both learn() calls share the same data and child architecture.
    data_and_model = dict(
        train_dataset=train_data,
        test_dataset=test_data,
        child_network_architecture=child_net_cls,
    )

    # First round: explore the 3 initial policies.
    learner.learn(**data_and_model, iterations=5)

    # Widen the search with 4 extra policies, then keep exploring.
    learner.make_more_policies(n=4)
    learner.learn(**data_and_model, iterations=7)

    for n_mega in (50, 3):
        print(learner.get_mega_policy(number_policies=n_mega))

if __name__ == "__main__":
    # Lets the smoke test be run directly as a script, without pytest.
    test_ucb_learner()