import torch
import torch.nn as nn
import torch.optim as optim
from MetaAugment.main import train_child_network, create_toy
from MetaAugment.autoaugment_learners.autoaugment import AutoAugment

import torchvision.transforms as transforms

import copy
import types




class aa_learner:
    """
    The parent class for all aa_learner subclasses.
    
    Attributes:
        op_tensor_length (int): the dimension of the tensor that represents
                            each 'operation' (which is made up of fun_name,
                            prob, and mag).
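
    Example (a minimal sketch of how a concrete subclass might look;
    ``my_learner`` and its policy-generation logic are purely illustrative):

    .. code-block::

        class my_learner(aa_learner):
            def _generate_new_policy(self):
                # return a policy in the format documented in
                # aa_learner._generate_new_policy
                ...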
    """
    def __init__(self, 
                # parameters that define the search space
                sp_num=5,
                p_bins=11,
                m_bins=10,
                discrete_p_m=False,
                # hyperparameters for when training the child_network
                batch_size=32,
                toy_size=1,
                learning_rate=1e-1,
                max_epochs=float('inf'),
                early_stop_num=20,
                exclude_method = [],
                ):
        """
        Args:
            sp_num (int, optional): number of subpolicies per policy. Defaults to 5.
            p_bins (int, optional): number of bins we divide the interval [0,1] for 
                            probabilities. Defaults to 11.
            m_bins (int, optional): number of bins we divide the magnitude space.
                            Defaults to 10.
            discrete_p_m (bool, optional):
                            Whether or not the agent should represent the probability and 
                            magnitude as discrete variables in the output of the 
                            controller (a controller can be a neural network, a genetic
                            algorithm, etc.). Defaults to False.
            
            batch_size (int, optional): child_network training parameter. Defaults to 32.
            toy_size (float, optional): child_network training parameter. Ratio of the
                                original dataset used in the toy dataset. Defaults to 1.
            learning_rate (float, optional): child_network training parameter. Defaults to 1e-1.
            max_epochs (Union[int, float], optional): child_network training parameter. 
                                Defaults to float('inf').
            early_stop_num (int, optional): child_network training parameter. Defaults to 20.
            exclude_method (list, optional): names of image functions to exclude from the
                                search space. Defaults to [].
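
        Example (a sketch; ``my_learner`` stands in for any concrete subclass
        of aa_learner, and the chosen values are purely illustrative):

        .. code-block::

            learner = my_learner(sp_num=3,
                                 toy_size=0.5,
                                 max_epochs=20,
                                 exclude_method=['Invert', 'Solarize'])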
        """
        # related to defining the search space
        self.sp_num = sp_num
        self.p_bins = p_bins
        self.m_bins = m_bins
        self.discrete_p_m = discrete_p_m

        # related to training of the child_network
        self.batch_size = batch_size
        self.toy_size = toy_size
        self.learning_rate = learning_rate

        self.max_epochs = max_epochs
        self.early_stop_num = early_stop_num

        # TODO: We should probably use a different way to store results than self.history
        self.history = []

        # this is the full augmentation space. We take out some image functions
        # if the user specifies so in the exclude_method parameter
        augmentation_space = [
            # (function_name, do_we_need_to_specify_magnitude)
            ("ShearX", True),
            ("ShearY", True),
            ("TranslateX", True),
            ("TranslateY", True),
            ("Rotate", True),
            ("Brightness", True),
            ("Color", True),
            ("Contrast", True),
            ("Sharpness", True),
            ("Posterize", True),
            ("Solarize", True),
            ("AutoContrast", False),
            ("Equalize", False),
            ("Invert", False),
        ]
        self.exclude_method = exclude_method
        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]

        self.fun_num = len(self.augmentation_space)
        # length of the tensor that represents a single operation:
        # (function section + p_bins + m_bins) when discrete, else (function section + 2 scalars)
        self.op_tensor_length = (self.fun_num + p_bins + m_bins if discrete_p_m
                                 else self.fun_num + 2)
        self.num_pols_tested = 0
        self.policy_record = {}




    def _translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
        """
        Takes in a tensor representing an operation and returns an actual operation which
        is in the form of:
            ("Invert", 0.8, None)
            or
            ("Contrast", 0.2, 6)

        Args:
            operation_tensor (tensor): 
                                We expect this tensor to already have been softmaxed.
                                Furthermore,
                                - If self.discrete_p_m is True, we expect to take in a tensor with
                                dimension (self.fun_num + self.p_bins + self.m_bins)
                                - If self.discrete_p_m is False, we expect to take in a tensor with
                                dimension (self.fun_num + 1 + 1)

            return_log_prob (boolean): 
                                When this is on, we return which indices (of fun, prob, mag) were
                                chosen (either randomly or deterministically, depending on argmax).
                                This is used, for example, in the gru_learner to calculate the
                                probability of the chosen actions, which is then logged and
                                differentiated.

            argmax (boolean): 
                            Whether we take the argmax of the softmaxed tensors. 
                            If this is False, we treat the softmaxed outputs as multinomial
                            pmf's and sample from them.

        Returns:
            operation (tuple):
                                An operation in the format that can be directly put into an
                                AutoAugment object.
            log_prob (float):
                            Used in reinforcement learning updates, such as proximal policy update
                            in the gru_learner.
                            Can only be used when self.discrete_p_m is True.
                            We add the logged values of the indices of the image_function,
                            probability, and magnitude chosen.
                            This corresponds to multiplying the non-logged values, then logging
                            it.                  
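
        Example (a sketch, assuming a learner built with discrete_p_m=False
        and the full 14-function augmentation space, so op_tensor_length == 16):

        .. code-block::

            fun_part = torch.softmax(torch.randn(14), dim=0)
            op_tensor = torch.cat([fun_part, torch.tensor([0.8, 6.0])])
            operation = learner._translate_operation_tensor(op_tensor, argmax=True)
            # e.g. ("Contrast", 0.8, 6), or ("Invert", 0.8, None) if the chosen
            # function does not take a magnitude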
        """

        if (not self.discrete_p_m) and return_log_prob:
            raise ValueError("You are not supposed to use return_log_prob=True when the "
                             "agent's self.discrete_p_m is False!")

        # make sure shape is correct
        assert operation_tensor.shape==(self.op_tensor_length, ), operation_tensor.shape

        # if probability and magnitude are represented as discrete variables
        if self.discrete_p_m:
            fun_t, prob_t, mag_t = operation_tensor.split([self.fun_num, self.p_bins, self.m_bins])

            # make sure they are of right size
            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'
            assert prob_t.shape==(self.p_bins,), f'{prob_t.shape} != {self.p_bins}'
            assert mag_t.shape==(self.m_bins,), f'{mag_t.shape} != {self.m_bins}'


            if argmax:
                fun_idx = torch.argmax(fun_t).item() # 0 <= fun_idx <= fun_num-1
                prob_idx = torch.argmax(prob_t).item() # 0 <= prob_idx <= p_bins-1
                mag = torch.argmax(mag_t).item() # 0 <= mag <= m_bins-1
            else:
                # each section must sum to 1 to be a valid multinomial pmf
                assert torch.sum(fun_t).isclose(torch.ones(1)), torch.sum(fun_t)
                assert torch.sum(prob_t).isclose(torch.ones(1)), torch.sum(prob_t)
                assert torch.sum(mag_t).isclose(torch.ones(1)), torch.sum(mag_t)

                fun_idx = torch.multinomial(fun_t, 1).item() # 0 <= fun_idx <= fun_num-1
                prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= prob_idx <= p_bins-1
                mag = torch.multinomial(mag_t, 1).item() # 0 <= mag <= m_bins-1

            function = self.augmentation_space[fun_idx][0]
            prob = prob_idx/(self.p_bins-1)

            indices = (fun_idx, prob_idx, mag)

            # log probability is the sum of the log of the softmax values of the indices 
            # (of fun_t, prob_t, mag_t) that we have chosen
            log_prob = torch.log(fun_t[fun_idx]) + torch.log(prob_t[prob_idx]) + torch.log(mag_t[mag])


        # if probability and magnitude are represented as continuous variables
        else:
            fun_t, prob, mag = operation_tensor.split([self.fun_num, 1, 1])
            prob = prob.item() # 0 <= prob <= 1
            mag = mag.item() # 0 <= mag <= m_bins-1

            # make sure the shape is correct
            assert fun_t.shape==(self.fun_num,), f'{fun_t.shape} != {self.fun_num}'

            if argmax:
                fun_idx = torch.argmax(fun_t).item()
            else:
                assert torch.sum(fun_t).isclose(torch.ones(1))
                fun_idx = torch.multinomial(fun_t, 1).item()
            prob = round(prob, 1) # round to the nearest first decimal digit
            mag = round(mag) # round to the nearest integer
            
        function = self.augmentation_space[fun_idx][0]

        assert 0 <= prob <= 1, prob
        assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins)
        
        # if the image function does not require a magnitude, we set the magnitude to None
        if self.augmentation_space[fun_idx][1]: # if the image function has a magnitude
            operation = (function, prob, mag)
        else:
            operation = (function, prob, None)
        
        if return_log_prob:
            return operation, log_prob
        else:
            return operation
        

    def _generate_new_policy(self):
        """
        Generate a new policy which can be fed into an AutoAugment object 
        by calling:
            AutoAugment.subpolicies = policy
        
        Args:
            none
        
        Returns:
            new_policy (list[tuple]):
                        A new policy generated by the controller. It
                        has the form of:
                            [
                            (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
                            (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
                            (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                            (("ShearY", 0.5, 8), ("Invert", 0.7, None)),
                            ]
                        This object can be fed into an AutoAugment object
                        by calling: AutoAugment.subpolicies = policy
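
        A sketch of how the returned policy is typically consumed by a concrete
        subclass (this mirrors what _test_autoaugment_policy does internally):

        .. code-block::

            policy = learner._generate_new_policy()
            aa_transform = AutoAugment()
            aa_transform.subpolicies = policy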
        """

        raise NotImplementedError('_generate_new_policy not implemented in aa_learner')


    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations=15):
        """
        Runs the main loop (of finding a good policy for the given child network,
        training dataset, and test (validation) dataset).

        Does the loop which is seen in Figure 1 in the AutoAugment paper
        which is:
            1. <generate a random policy>
            2. <see how good that policy is>
            3. <save how good the policy is in a list/dictionary and 
                (if applicable,) update the controller (e.g. RL agent)>
        
        Args:
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
            child_network_architecture (Union[function, nn.Module]):
                                NOTE This can be both, for example,
                                    MyNetworkArchitecture
                                    and
                                    MyNetworkArchitecture()
            iterations (int): how many different policies do you want to test
        Returns:
            none
        
        
        If child_network_architecture is a <function>, then we make an 
        instance of it. If it is an <nn.Module>, we make a copy.deepcopy
        of it. We make a copy because we want to keep an untrained 
        (initialized but not trained) version of the child network
        architecture, since we need to train it from scratch once for
        each policy. Keeping child_network_architecture as a `function` is
        potentially better than keeping it as an nn.Module, because every
        time we make a new instance the weights are initialized differently,
        which makes our results less biased
        (https://en.wikipedia.org/wiki/Bias_(statistics)).
        

        Example code:

        .. code-block::
            :caption: This is an example dummy code which tests out 15 
                      different policies
            
            for _ in range(15):
                policy = self._generate_new_policy()

                pprint(policy)
                reward = self._test_autoaugment_policy(policy,
                                        child_network_architecture,
                                        train_dataset,
                                        test_dataset)

                self.history.append((policy, reward))
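
        A typical call from user code might look like the sketch below
        (``LeNet`` here is an assumed child-network class, not something this
        module provides):

        .. code-block::

            learner.learn(train_dataset,
                          test_dataset,
                          child_network_architecture=LeNet,
                          iterations=100)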
        """

    

    def _test_autoaugment_policy(self,
                                policy,
                                child_network_architecture,
                                train_dataset,
                                test_dataset,
                                logging=False,
                                print_every_epoch=True):
        """
        Given a policy (using AutoAugment paper terminology), we train a child network
        using the policy and return the accuracy (how good the policy is for the dataset and 
        child network).

        Args: 
            policy (list[tuple]): A list of tuples representing a policy.
            child_network_architecture (Union[function, nn.Module]):
                                If this is a :code:`function`, then we make
                                an instance of it. If this is a 
                                :code:`nn.Module`, we make a :code:`copy.deepcopy`
                                of it.
            train_dataset (torchvision.datasets.vision.VisionDataset)
            test_dataset (torchvision.datasets.vision.VisionDataset)
            logging (boolean): Whether we want to save logs
            print_every_epoch (boolean): Whether to print progress at every epoch while
                                training the child network
        
        Returns:
            accuracy (float): the best accuracy reached in any epoch while training
                            the child network with the given policy
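
        Example (a sketch; the policy below is illustrative and ``LeNet`` is an
        assumed child-network class):

        .. code-block::

            policy = [(("Invert", 0.8, None), ("Contrast", 0.2, 6)),
                      (("Rotate", 0.7, 2), ("Invert", 0.8, None))]
            accuracy = learner._test_autoaugment_policy(policy,
                                                        LeNet,
                                                        train_dataset,
                                                        test_dataset)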
        """



        # we create an instance of the child network that we're going
        # to train. The method of creation depends on the type of 
        # input we got for child_network_architecture
        if isinstance(child_network_architecture, types.FunctionType):
            child_network = child_network_architecture()
        elif isinstance(child_network_architecture, type):
            child_network = child_network_architecture()
        elif isinstance(child_network_architecture, torch.nn.Module):
            child_network = copy.deepcopy(child_network_architecture)
        else:
            raise ValueError('child_network_architecture must be either a <function>, a class, '
                             f'or a <torch.nn.Module>, but got {child_network_architecture} '
                             f'of type {type(child_network_architecture)}')

        # We need to define an object aa_transform which takes in the image and 
        # transforms it with the policy (specified in its .subpolicies attribute)
        # in its forward pass
        aa_transform = AutoAugment()
        aa_transform.subpolicies = policy
        train_transform = transforms.Compose([
                                                aa_transform,
                                                transforms.ToTensor()
                                            ])
        
        # We feed the transformation into the Dataset object
        train_dataset.transform = train_transform

        # create Dataloader objects out of the Dataset objects
        train_loader, test_loader = create_toy(train_dataset,
                                            test_dataset,
                                            batch_size=self.batch_size,
                                            n_samples=self.toy_size,
                                            seed=100)
        
        # train the child network with the dataloaders equipped with our specific policy
        accuracy = train_child_network(child_network, 
                                    train_loader, 
                                    test_loader, 
                                    sgd = optim.SGD(child_network.parameters(),
                                                    lr=self.learning_rate),
                                    # sgd = optim.Adadelta(
                                    #               child_network.parameters(),
                                    #               lr=self.learning_rate),
                                    cost = nn.CrossEntropyLoss(),
                                    max_epochs = self.max_epochs, 
                                    early_stop_num = self.early_stop_num, 
                                    logging = logging,
                                    print_every_epoch=print_every_epoch)


        # turn policy into dictionary format and add it into self.policy_record
        curr_pol = f'pol{self.num_pols_tested}'
        pol_dict = {}
        for subpol in policy:
            first_trans, first_prob, first_mag = subpol[0]
            second_trans, second_prob, second_mag = subpol[1]
            components = (first_prob, first_mag, second_prob, second_mag)
            if first_trans in pol_dict:
                if second_trans in pol_dict[first_trans]:
                    pol_dict[first_trans][second_trans].append(components)
                else:
                    pol_dict[first_trans][second_trans] = [components]
            else:
                pol_dict[first_trans]= {second_trans: [components]}
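        # e.g. the subpolicy (("Invert", 0.8, None), ("Contrast", 0.2, 6)) is
        # recorded as {'Invert': {'Contrast': [(0.8, None, 0.2, 6)]}}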
        self.policy_record[curr_pol] = (pol_dict, accuracy)

        self.num_pols_tested += 1
        return accuracy
    

    # def demo_plot(self, train_dataset, test_dataset, child_network_architecture, n=5):
    #     """
    #     I made this to plot a couple of accuracy graphs to help manually tune my gradient 
    #     optimizer hyperparameters.

    #     Saves a plot of `n` training accuracy graphs overlapped.
    #     """
        
    #     acc_lists = []

    #     # This is dummy code
    #     # test out `n` random policies
    #     for _ in range(n):
    #         policy = self._generate_new_policy()

    #         pprint(policy)
    #         reward, acc_list = self._test_autoaugment_policy(policy,
    #                                             child_network_architecture,
    #                                             train_dataset,
    #                                             test_dataset,
    #                                             logging=True)

    #         self.history.append((policy, reward))
    #         acc_lists.append(acc_list)

    #     for acc_list in acc_lists:
    #         plt.plot(acc_list)
    #     plt.title('I ran 5 random policies to see if there is any sign of \
    #                 catastrophic failure during training. If there are \
    #                 any lines which reach significantly lower (>10%) \
    #                 accuracies, you might want to tune the hyperparameters')
    #     plt.xlabel('epoch')
    #     plt.ylabel('accuracy')
    #     plt.show()
    #     plt.savefig('training_graphs_without_policies')