# %%
import numpy as np
import matplotlib.pyplot as plt 
from itertools import count

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader


from collections import namedtuple, deque
import math
import random

from MetaAugment.main import *


batch_size = 128

test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print('test_loader', len(test_loader))
print('train_loader',len(train_loader))

def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # shuffle and take first n_samples %age of training dataset
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
    # shuffle and take first n_samples %age of test dataset
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoader
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)

    return train_loader, test_loader

# train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 10)


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

# %% [markdown]
# ## collect reward

# %%

def collect_reward(train_loader, test_loader, max_epochs=100, early_stop_num=10):
    child_network = LeNet() 
    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
    cost = nn.CrossEntropyLoss()
    best_acc=0
    early_stop_cnt = 0
    
    # train child_network and check validation accuracy each epoch
    print('max_epochs', max_epochs)
    for _epoch in range(max_epochs):
        print('_epoch', _epoch)
        # train child_network
        child_network.train()
        for t, (train_x, train_label) in enumerate(train_loader):
            label_np = np.zeros((train_label.shape[0], 10))
            sgd.zero_grad()
            predict_y = child_network(train_x.float())
            loss = cost(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # check validation accuracy on validation set
        correct = 0
        _sum = 0
        child_network.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = child_network(test_x.float()).detach()
            predict_ys = np.argmax(predict_y, axis=-1)
            label_np = test_label.numpy()
            _ = predict_ys == test_label
            correct += np.sum(_.numpy(), axis=-1)
            _sum += _.shape[0]
        
        # update best validation accuracy if it was higher, otherwise increase early stop count
        acc = correct / _sum

        if acc > best_acc :
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        # exit if validation gets worse over 10 runs
        if early_stop_cnt >= early_stop_num:
            break

        # if _epoch%30 == 0:
        #     print('child_network accuracy: ', best_acc)
        
    return best_acc


# %%
for t, (train_x, train_label) in enumerate(test_loader):
    print(train_x.shape)
    print(train_label)
    break
len(test_loader)

# %%
collect_reward(train_loader, test_loader)


# %% [markdown]
# ## Policy network

# %%
class Policy(nn.Module):
    """
    implements both actor and critic in one model
    """
    def __init__(self):
        super(Policy, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5 , stride=2)
        self.conv2 = nn.Conv2d(6, 12, 5, stride=2)
        self.maxpool = nn.MaxPool2d(4)

        # actor's layer
        self.action_head = nn.Linear(12, 2)

        # critic's layer
        self.value_head = nn.Linear(12, 1)

        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """
        forward of both actor and critic
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        # print('x', x.shape)

        # actor: choses action to take from state s_t 
        # by returning probability of each action
        # print('self.action_head(x)', self.action_head(x).shape)
        action_prob = F.softmax(self.action_head(x), dim=-1)
        # print('action_prob', action_prob.shape)

        # critic: evaluates being in the state s_t
        state_values = self.value_head(x)

        # return values for both actor and critic as a tuple of 2 values:
        # 1. a list with the probability of each action over the action space
        # 2. the value from state s_t 
        return action_prob, state_values


# %%
test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

policy_model = Policy()
# for t, (x, y) in enumerate(train_loader):
#     # print(x.shape)
#     policy_model(x)

# %% [markdown]
# ## select action

# %%
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

def select_action(train_loader, policy_model):
    probs_list = []
    value_list = []
    for t, (x, y) in enumerate(train_loader):
        probs_i, state_value_i = policy_model(x)
        probs_list += [probs_i]
        value_list += [state_value_i]

    probs = torch.mean(torch.cat(probs_list), dim=0)
    state_value = torch.mean(torch.cat(value_list))
    # print('probs_i', probs_i)
    # print('probs', probs)
    # create a categorical distribution over the list of probabilities of actions
    m = Categorical(probs)
    # print('m', m)
    # and sample an action using the distribution
    action = m.sample()
    # print('action', action)

    # save to action buffer
    policy_model.saved_actions.append(SavedAction(m.log_prob(action), state_value))

    # the action to take (left or right)
    return action.item()


# %%
torch.tensor([1, 2, 3])

# %% [markdown]
# ## take action

# %%
def take_action(action_idx):
    # Define actions (data augmentation policy) --- can be improved
    action_list = [
    torchvision.transforms.Compose([torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ToTensor()]),
    torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor()]),
    torchvision.transforms.Compose([torchvision.transforms.RandomGrayscale(),
        torchvision.transforms.ToTensor()]),
    torchvision.transforms.Compose([torchvision.transforms.RandomAffine(30),
        torchvision.transforms.ToTensor()])]

    # transform   
    transform = action_list[action_idx]
    test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=transform)
    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=transform)
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples=0.0002)
    return train_loader, test_loader


# %% [markdown]
# ## finish episode

# %%
policy_model = Policy()
optimizer = optim.Adam(policy_model.parameters(), lr=3e-2)
eps = np.finfo(np.float32).eps.item()
gamma = 0.9
def finish_episode():
    """
    Training code. Calculates actor and critic loss and performs backprop.
    """
    R = 0
    saved_actions = policy_model.saved_actions
    policy_losses = [] # list to save actor (policy) loss
    value_losses = [] # list to save critic (value) loss
    returns = [] # list to save the true values

    # calculate the true value using rewards returned from the environment
    for r in policy_model.rewards[::-1]:
        # calculate the discounted value
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + eps)

    for (log_prob, value), R in zip(saved_actions, returns):
        advantage = R - value.item()

        # calculate actor (policy) loss 
        policy_losses.append(-log_prob * advantage)

        # calculate critic (value) loss using L1 smooth loss
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))

    # reset gradients
    optimizer.zero_grad()

    # sum up all the values of policy_losses and value_losses
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

    # perform backprop
    loss.backward()
    optimizer.step()

    # reset rewards and action buffer
    del policy_model.rewards[:]
    del policy_model.saved_actions[:]

# %% [markdown]
# ## run

# %%

running_reward = 10
episodes_num = 100
policy_model = Policy()
for i_episode in range(episodes_num) :
    # initiate a new state
    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
    # train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
    train_loader_state = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # select action from policy
    action_idx = select_action(train_loader, policy_model)
    print('>>> action_idx', action_idx)

    # take the action -> apply data augmentation
    train_loader, test_loader = take_action(action_idx)
    reward = collect_reward(train_loader, test_loader)
    print('>>> reward', reward)

    # if args.render:
    #     env.render()

    policy_model.rewards.append(reward)

    # perform backprop
    finish_episode()

    # # log result
    if i_episode % 10 == 0:
        print('Episode {}\tLast reward (val accuracy): {:.2f}'.format(i_episode, reward))

# %%