"""
This file demonstrates how a trained convolution neural net (CNN) model which has been trained using tensorflow
can be used to predict the label in iFind1 videos (and images)

Author: Christian Baumgartner (18. Jan 2016)
"""

import tensorflow as tf
import cv2
import convnet_model as convnet
import os
import time
import numpy as np

# Constants specific for testing
EVALUATE_ON_VIDEOS = 1 # 1 - for videos, 2 - for images
MODEL_DIRECTORY = '.' # where to go looking for the model
PROCESSED_DATA_PATH = '.'
VIDEO_PATH = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_600/simple/vids/iFIND01472.avi'
SAVE_VIDEO = False
DISPLAY_VIDEO = True
# If evaluating on images set this:
IMAGE_FOLDER = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_300/iFind1_simple/stillframes_testtrain/test'

saver = tf.train.Saver()

def restore_model(session):
    ckpt = tf.train.get_checkpoint_state(MODEL_DIRECTORY)
    if ckpt and ckpt.model_checkpoint_path:
        print("Check point details:")
        print(ckpt)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("No Checkpoints found!")

# Make a tensorflow data placeholder
data = tf.placeholder(
                tf.float32,
                shape=(1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, convnet.NUM_CHANNELS))

# Using this place holder define a tensorflow function which predicts the labels of some data
# Note that the output will be a probability vector over all the classes
prediction = tf.nn.softmax(convnet.model(data))

# load labels and label_names dictionary
expected_label_path = os.path.join(PROCESSED_DATA_PATH, 'ifind1_scanplane_labels.npy')
y_train, y_test, label_names, label_numbers = np.load(expected_label_path)

if EVALUATE_ON_VIDEOS == 1:

    with tf.Session() as sess:

        restore_model(sess)

        if SAVE_VIDEO:
            video_filebase = (VIDEO_PATH.split('/')[-1]).split('.')[0]
            out_file = video_filebase + '_classified.avi'

        cap = cv2.VideoCapture(VIDEO_PATH)

        if SAVE_VIDEO:

            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            out = cv2.VideoWriter( out_file,
                                   cv2.VideoWriter_fourcc('D','I', 'V', 'X'),
                                   25.0, # Frame Rate
                                   (width,height) )

        nT = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for tt in xrange(nT):

            start_time = time.time()
            ret, frame = cap.read()

            # Do the same transformations as for the training data
            im_cropped = frame.copy()[80:530,150:600,:]
            im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE))
            im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY)
            im_data = im_gray.astype(np.float32)
            im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1))
            X_white = convnet._whiten_images(im_data)

            # Run the tensorflow prediction function we defined earlier
            prediction_raw = sess.run(
                                prediction,
                                feed_dict={data: X_white})

            # Remove unneccessary dimensions
            predictions = prediction_raw.flatten()

            # The argmax of the prediction is the label number. The label_names dictionary helps us
            # to translate this to a label we can understand.
            label_predicted = label_names[np.argmax(prediction)]
            confidence = max(prediction)
            elapsed_time = time.time() - start_time

            # Print prediction
            prediction_string = "Prediction: %s (%.2f)" % (label_predicted, confidence)
            fps = 1 / ((elapsed_time * 1000)/60)

            # set font and colour
            font = cv2.FONT_HERSHEY_SIMPLEX
            if confidence < 0.9:
                colour = (255,255,255) # white
            else:
                colour = (0, 255, 0) # green (BGR)

            # Print the prediction and framerate on the video frames
            cv2.putText(frame, prediction_string, (10,500), font, 1, colour, 2)
            cv2.putText(frame, "framerate: %2.0f fps (on CPUx8)" % fps, (500,20), font, 0.5, (255, 255, 255), 2)

            # Output the prediction also in the console
            print(prediction_string)

            if DISPLAY_VIDEO:
                cv2.imshow('annotated frame', frame)
                cv2.waitKey(25)

            # WRITE FRAME
            if SAVE_VIDEO:
                out.write(frame)

        if SAVE_VIDEO:
            out.release()

        cap.release()

elif EVALUATE_ON_VIDEOS == 2:  # This means evaluate on images

    with tf.Session() as sess:

        ckpt = tf.train.get_checkpoint_state(MODEL_DIRECTORY)
        if ckpt and ckpt.model_checkpoint_path:
            print("Check point details:")
            print(ckpt)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print("No Checkpoints found!")

        for root, directories, files in os.walk(IMAGE_FOLDER):
            for fn in files:
                path = os.path.join(root, fn)
                gt_label_name = root.split('/')[-1]

                im = cv2.imread(path)

                start_time = time.time()

                im_cropped = im.copy()[106:713,176:783,:]
                im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE))
                im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY)
                im_data = im_gray.astype(np.float32)
                im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1))

                X_white = convnet._whiten_images(im_data)

                prediction = sess.run(
                                prediction,
                                feed_dict={data: X_white})

                prediction = prediction.flatten()

                elapsed_time = time.time() - start_time

                print(np.argmax(prediction))
                print(label_names[np.argmax(prediction)])
                print("Conficence: %.2f" % max(prediction))
                print("Elapsed Time: %.1f ms" % (1000*elapsed_time))

                cv2.imshow('current image', im )
                cv2.waitKey(0)