""" This file demonstrates how a trained convolution neural net (CNN) model which has been trained using tensorflow can be used to predict the label in iFind1 videos (and images) Author: Christian Baumgartner (18. Jan 2016) """ import tensorflow as tf import cv2 import convnet_model as convnet import os import time import numpy as np # Constants specific for testing EVALUATE_ON_VIDEOS = 1 # 1 - for videos, 2 - for images MODEL_DIRECTORY = '.' # where to go looking for the model PROCESSED_DATA_PATH = '.' VIDEO_PATH = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_600/simple/vids/iFIND01472.avi' SAVE_VIDEO = False DISPLAY_VIDEO = True # If evaluating on images set this: IMAGE_FOLDER = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_300/iFind1_simple/stillframes_testtrain/test' saver = tf.train.Saver() def restore_model(session): ckpt = tf.train.get_checkpoint_state(MODEL_DIRECTORY) if ckpt and ckpt.model_checkpoint_path: print("Check point details:") print(ckpt) saver.restore(sess, ckpt.model_checkpoint_path) else: print("No Checkpoints found!") # Make a tensorflow data placeholder data = tf.placeholder( tf.float32, shape=(1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, convnet.NUM_CHANNELS)) # Using this place holder define a tensorflow function which predicts the labels of some data # Note that the output will be a probability vector over all the classes prediction = tf.nn.softmax(convnet.model(data)) # load labels and label_names dictionary expected_label_path = os.path.join(PROCESSED_DATA_PATH, 'ifind1_scanplane_labels.npy') y_train, y_test, label_names, label_numbers = np.load(expected_label_path) if EVALUATE_ON_VIDEOS == 1: with tf.Session() as sess: restore_model(sess) if SAVE_VIDEO: video_filebase = (VIDEO_PATH.split('/')[-1]).split('.')[0] out_file = video_filebase + '_classified.avi' cap = cv2.VideoCapture(VIDEO_PATH) if SAVE_VIDEO: height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) out = cv2.VideoWriter( out_file, cv2.VideoWriter_fourcc('D','I', 'V', 'X'), 25.0, # Frame Rate (width,height) ) nT = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) for tt in xrange(nT): start_time = time.time() ret, frame = cap.read() # Do the same transformations as for the training data im_cropped = frame.copy()[80:530,150:600,:] im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE)) im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY) im_data = im_gray.astype(np.float32) im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1)) X_white = convnet._whiten_images(im_data) # Run the tensorflow prediction function we defined earlier prediction_raw = sess.run( prediction, feed_dict={data: X_white}) # Remove unneccessary dimensions predictions = prediction_raw.flatten() # The argmax of the prediction is the label number. The label_names dictionary helps us # to translate this to a label we can understand. 
            label_predicted = label_names[np.argmax(predictions)]
            confidence = predictions.max()

            elapsed_time = time.time() - start_time

            # Format the prediction and compute the achieved framerate
            prediction_string = "Prediction: %s (%.2f)" % (label_predicted, confidence)
            fps = 1.0 / elapsed_time

            # set font and colour
            font = cv2.FONT_HERSHEY_SIMPLEX
            if confidence < 0.9:
                colour = (255, 255, 255)  # white
            else:
                colour = (0, 255, 0)  # green (BGR)

            # Print the prediction and framerate on the video frames
            cv2.putText(frame, prediction_string, (10, 500), font, 1, colour, 2)
            cv2.putText(frame, "framerate: %2.0f fps (on CPUx8)" % fps,
                        (500, 20), font, 0.5, (255, 255, 255), 2)

            # Output the prediction also in the console
            print(prediction_string)

            if DISPLAY_VIDEO:
                cv2.imshow('annotated frame', frame)
                cv2.waitKey(25)

            # Write the annotated frame to the output video
            if SAVE_VIDEO:
                out.write(frame)

        if SAVE_VIDEO:
            out.release()
        cap.release()

elif EVALUATE_MODE == 2:  # evaluate on images

    with tf.Session() as sess:

        restore_model(sess)

        for root, directories, files in os.walk(IMAGE_FOLDER):
            for fn in files:

                path = os.path.join(root, fn)

                # The ground-truth label is the name of the containing folder
                gt_label_name = root.split('/')[-1]

                im = cv2.imread(path)

                start_time = time.time()

                # Do the same transformations as for the training data
                im_cropped = im.copy()[106:713, 176:783, :]
                im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE))
                im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY)
                im_data = im_gray.astype(np.float32)
                im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1))
                X_white = convnet._whiten_images(im_data)

                # Run the prediction op; store the result under a new name so
                # the `prediction` tensor is not overwritten between iterations
                prediction_raw = sess.run(prediction, feed_dict={data: X_white})
                predictions = prediction_raw.flatten()

                elapsed_time = time.time() - start_time

                print(np.argmax(predictions))
                print(label_names[np.argmax(predictions)])
                print("Confidence: %.2f" % predictions.max())
                print("Elapsed time: %.1f ms" % (1000 * elapsed_time))

                cv2.imshow('current image', im)
                cv2.waitKey(0)
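
# Note on preprocessing: `convnet._whiten_images` lives in convnet_model.py
# and is not shown in this file. A minimal sketch, assuming the helper
# performs the usual per-image whitening (zero mean, unit variance over each
# image in the batch), would look like this:
#
#     def _whiten_images(X):
#         # Normalise each image in the batch independently
#         mean = X.mean(axis=(1, 2, 3), keepdims=True)
#         std = X.std(axis=(1, 2, 3), keepdims=True)
#         return (X - mean) / np.maximum(std, 1e-8)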