from dataclasses import dataclass
from flask import Flask, request, current_app, send_file, send_from_directory, redirect, url_for, session
from flask_cors import CORS, cross_origin
import os
import zipfile
import torch
from react_backend.wapp_util import parse_users_learner_spec
import pprint
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, os.path.abspath('..'))
torch.manual_seed(0)  # fix the global random seed so training runs are reproducible

print('@@@ import successful')

# app = Flask(__name__, static_folder='react_frontend/build', static_url_path='/')
app = Flask(__name__)
CORS(app)  # allow cross-origin requests (e.g. from the React frontend served on a different port)

# Collect the user's settings from the submission form and store them in the app config.
@app.route('/home', methods=["GET", "POST"])
# @cross_origin()
def get_form_data():
    
    if request.method == 'POST':
        print('@@@ in Flask Home')
        
        form_data = request.form
        print('@@@ this is form data', form_data)

        # required input
        ds = form_data['select_dataset']                # dataset to use (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100, or 'Other' for an upload)
        IsLeNet = form_data["select_network"]           # child network to train (LeNet, EasyNet, SimpleNet, or 'Other' for an upload)
        auto_aug_learner = form_data["select_learner"]  # which auto-augmentation learner to use

        print('@@@ required user input:', 'ds:', ds, 'IsLeNet:', IsLeNet, 'auto_aug_learner:', auto_aug_learner)
        # advanced input (fall back to demonstration defaults when a field is left empty)
        if form_data['batch_size'] not in ['undefined', ""]:
            batch_size = int(form_data['batch_size'])            # batch size the child network is trained with
        else:
            batch_size = 16                                      # demonstration default
        if form_data['learning_rate'] not in ['undefined', ""]:
            learning_rate = float(form_data['learning_rate'])    # learning rate used to train the child network
        else:
            learning_rate = 1e-2
        if form_data['toy_size'] not in ['undefined', ""]:
            toy_size = float(form_data['toy_size'])              # proportion of the training and test sets we use
        else:
            toy_size = 0.01                                      # demonstration default
        if form_data['iterations'] not in ['undefined', ""]:
            iterations = int(form_data['iterations'])            # total iterations; should exceed the number of policies
        else:
            iterations = 2
        exclude_method = form_data['select_action']   # augmentation methods to exclude from the search
        print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method:', exclude_method)
        
        # if the user uploaded their own dataset, save and extract it
        if ds == 'Other':
            ds_folder = request.files['ds_upload']
            ds_name_zip = ds_folder.filename
            # check that the uploaded dataset is a zip archive
            if not ds_name_zip.endswith('.zip'):
                data = {'error_type': 'not a zip file', 'error': "We found that your uploaded dataset is not a zip file..."}
                current_app.config['data'] = data
                return data
            ds_name = os.path.splitext(ds_name_zip)[0]
            ds_folder.save('./react_backend/datasets/' + ds_name_zip)
            with zipfile.ZipFile('./react_backend/datasets/' + ds_name_zip, 'r') as zip_ref:
                zip_ref.extractall('./react_backend/datasets/upload_dataset/')
            # remove the archive once extracted (kept around in debug mode)
            if not current_app.debug:
                os.remove(f'./react_backend/datasets/{ds_name_zip}')
        else:
            ds_name_zip = None
            ds_name = None
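
        # Illustrative sketch of the layout the check below expects from an extracted
        # upload (folder and file names here are hypothetical): one sub-folder per class
        # at the top level, with the image files directly inside each class folder, e.g.
        #
        #   react_backend/datasets/upload_dataset/<ds_name>/
        #       class_a/  img_0.png, img_1.png, ...
        #       class_b/  img_2.png, img_3.png, ...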

        # test whether the uploaded dataset meets the criteria above
        if ds == 'Other':
            i = -1
            folders = []
            for (dirpath, dirnames, filenames) in os.walk(f'./react_backend/datasets/upload_dataset/{ds_name}/'):
                i += 1
                if i == 0:
                    folders = dirnames              # the class sub-folders at the top level
                    if not dirnames:                # no class sub-folders at all
                        data = {'error_type': 'incorrect dataset',
                                'error': "We found that your uploaded dataset doesn't have the correct format that we are looking for."}
                        current_app.config['data'] = data
                        return data
            if len(folders) != i:                   # the class folders contain further nested folders
                data = {'error_type': 'incorrect dataset',
                        'error': "We found that your uploaded dataset doesn't have the correct format that we are looking for."}
                current_app.config['data'] = data
                return data
            print('@@@ correct dataset folder!')
        
        # save the user-uploaded network
        if IsLeNet == 'Other':
            childnetwork = request.files['network_upload']
            network_name = childnetwork.filename
            # the uploaded child network must be a pickle file
            if not network_name.endswith('.pkl'):
                data = {'error_type': 'incorrect network',
                        'error': "We found that your uploaded network is not a pickle file"}
                current_app.config['data'] = data
                return data
            else:
                childnetwork.save('./child_networks/' + childnetwork.filename)
        else:
            network_name = None

        print("@@@ user input has all stored in the app")

        data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
                'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate, 
                'toy_size':toy_size, 'iterations':iterations, 'exclude_method': exclude_method, }
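        # Illustrative shape of the `data` payload above, reused by /training
        # (the values shown are hypothetical examples, not defaults):
        #   {'ds': 'MNIST', 'ds_name': None, 'IsLeNet': 'LeNet', 'network_name': None,
        #    'auto_aug_learner': '<learner>', 'batch_size': 16, 'learning_rate': 0.01,
        #    'toy_size': 0.01, 'iterations': 2, 'exclude_method': <excluded augmentations>}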

        current_app.config['data'] = data
        
        print('@@@ all data sent', current_app.config['data'])


    elif request.method == 'GET':
        print('@@@ received a GET request')

        if 'data' in current_app.config:
            data = current_app.config['data']
        else:
            data = {'error': "We didn't receive any data from your submission form. Please go back to the home page.",
                    'error_type': 'no data'}

    return data
    # return redirect(url_for('confirm', data=data))



# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
@cross_origin()
def training():

    # default values
    max_epochs = 10       # maximum number of epochs run if early stopping is not triggered
    early_stop_num = 10   # maximum number of worse validation scores before early stopping is triggered
    num_policies = 5      # fixed number of policies
    num_sub_policies = 5  # fixed number of sub-policies in a policy
    data = current_app.config.get('data')
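
    # guard against reaching /training before /home has stored any settings;
    # without this the ** expansion below would fail when `data` is None
    if data is None:
        return {'error_type': 'no data',
                'error': "We didn't receive any settings. Please submit the home page form first."}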

    # parse the settings given by the user to obtain tools we need
    train_dataset, test_dataset, child_archi, agent = parse_users_learner_spec(
                                            max_epochs=max_epochs,
                                            early_stop_num=early_stop_num,
                                            num_policies=num_policies,
                                            num_sub_policies=num_sub_policies,
                                            **data
                                        )

    # train the auto-augmentation learner for `iterations` iterations
    agent.learn(
        train_dataset=train_dataset, 
        test_dataset=test_dataset, 
        child_network_architecture=child_archi,
        iterations=data['iterations']
        ) 
    
    print('the history of all the policies the agent has tested:')
    pprint.pprint(agent.history)
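
    # agent.history is assumed to hold (policy, accuracy) pairs, which is how the
    # accuracies are extracted for the performance curves below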

    # build the accuracy curve and the best-accuracy-so-far curve
    acc_list = [acc for (policy, acc) in agent.history]
    best_acc_list = []
    best_til_now = 0
    for acc in acc_list:
        if acc > best_til_now:
            best_til_now = acc
        best_acc_list.append(best_til_now)
    
    # plot both curves and save the figure where the React frontend can display it
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(acc_list, label='accuracy per iteration')
    ax.plot(best_acc_list, label='best accuracy so far')
    ax.legend()
    ax.set_xlabel('Number of Iterations')
    ax.set_ylabel('Accuracy')
    ax.set_title('Auto-augmentation Learner Performance Curve')
    fig.savefig("./react_frontend/src/pages/output.png")
    plt.close(fig)   # free the figure so repeated /training calls don't accumulate open figures

    print("best policies:")
    best_policy = agent.get_mega_policy(number_policies=4)
    print(best_policy)
    with open("./react_backend/policy.txt", 'w') as f:
        # save the best_policy in pretty_print string format
        f.write(pprint.pformat(best_policy, indent=4))

    print('')

    return {'status': 'Training is done!'}


# ========================================================================
@app.route('/result')
@cross_origin()
def show_result():
    # send the saved best policy back to the frontend as a download
    file_path = "./react_backend/policy.txt"
    # note: Flask 2.0+ renames send_file's `cache_timeout` keyword to `max_age`
    return send_file(file_path, as_attachment=True, cache_timeout=0)



# @app.route('/')
# def serve():
#     return send_from_directory(app.static_folder, 'index.html')



if __name__ == '__main__':
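    # Start the Flask development server; by default it listens on http://127.0.0.1:5000/.
    # (Assuming this module is saved as e.g. app.py, run it with `python app.py`.)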
    app.run(debug=False, use_reloader=False)