convnet_testonvideos.py
"""
This file demonstrates how a trained convolution neural net (CNN) model which has been trained using tensorflow
can be used to predict the label in iFind1 videos (and images)
Author: Christian Baumgartner (18. Jan 2016)
"""
import tensorflow as tf
import cv2
import convnet_model as convnet
import os
import time
import numpy as np
# Constants specific for testing
EVALUATE_ON_VIDEOS = 1 # 1 - for videos, 2 - for images
MODEL_DIRECTORY = '.' # where to go looking for the model
PROCESSED_DATA_PATH = '.'
VIDEO_PATH = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_600/simple/vids/iFIND01472.avi'
SAVE_VIDEO = False
DISPLAY_VIDEO = True
# If evaluating on images set this:
IMAGE_FOLDER = '/vol/medic01/users/cbaumgar/data/iFind1/iFind1_300/iFind1_simple/stillframes_testtrain/test'
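
# NOTE: creating the Saver here assumes convnet_model creates its tf.Variables
# at import time; if the variables were only created inside convnet.model(),
# this line would have to move below the `prediction = ...` definition.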
saver = tf.train.Saver()
def restore_model(session):
    ckpt = tf.train.get_checkpoint_state(MODEL_DIRECTORY)
    if ckpt and ckpt.model_checkpoint_path:
        print("Checkpoint details:")
        print(ckpt)
        # Bug fix: restore into the session that was passed in, not a global `sess`
        saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("No checkpoints found!")
# Make a tensorflow data placeholder
data = tf.placeholder(
    tf.float32,
    shape=(1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, convnet.NUM_CHANNELS))
# Using this placeholder, define a tensorflow op which predicts the labels of some data.
# Note that the output will be a probability vector over all the classes.
prediction = tf.nn.softmax(convnet.model(data))
# load labels and label_names dictionary
expected_label_path = os.path.join(PROCESSED_DATA_PATH, 'ifind1_scanplane_labels.npy')
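# NOTE: on NumPy >= 1.16.3 this pickled .npy file must be loaded with
# np.load(expected_label_path, allow_pickle=True)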
y_train, y_test, label_names, label_numbers = np.load(expected_label_path)
if EVALUATE_ON_VIDEOS == 1:

    with tf.Session() as sess:

        restore_model(sess)

        if SAVE_VIDEO:
            video_filebase = (VIDEO_PATH.split('/')[-1]).split('.')[0]
            out_file = video_filebase + '_classified.avi'

        cap = cv2.VideoCapture(VIDEO_PATH)

        if SAVE_VIDEO:
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            out = cv2.VideoWriter(out_file,
                                  cv2.VideoWriter_fourcc('D', 'I', 'V', 'X'),
                                  25.0,  # frame rate
                                  (width, height))
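            # NOTE: the 'DIVX' fourcc assumes that codec is available on the
            # machine; 'XVID' or 'MJPG' are common fallbacks if the writer
            # fails to open.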
        nT = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for tt in range(nT):

            start_time = time.time()

            ret, frame = cap.read()
            if not ret:  # stop early if the video ends before nT frames
                break

            # Do the same transformations as for the training data
            im_cropped = frame.copy()[80:530, 150:600, :]
            im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE))
            im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY)
            im_data = im_gray.astype(np.float32)
            im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1))
            X_white = convnet._whiten_images(im_data)
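            # _whiten_images is assumed to apply the same per-image
            # normalisation that was used at training time.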
            # Run the tensorflow prediction function we defined earlier
            prediction_raw = sess.run(
                prediction,
                feed_dict={data: X_white})

            # Remove unnecessary dimensions
            predictions = prediction_raw.flatten()

            # The argmax of the prediction is the label number. The label_names
            # dictionary helps us translate this to a label we can understand.
            # (Bug fix: index with the flattened `predictions`, not the tf op.)
            label_predicted = label_names[np.argmax(predictions)]
            confidence = max(predictions)

            elapsed_time = time.time() - start_time

            # Print prediction
            prediction_string = "Prediction: %s (%.2f)" % (label_predicted, confidence)
            fps = 1.0 / elapsed_time
            # Set font and colour
            font = cv2.FONT_HERSHEY_SIMPLEX
            if confidence < 0.9:
                colour = (255, 255, 255)  # white
            else:
                colour = (0, 255, 0)  # green (BGR)

            # Print the prediction and framerate on the video frames
            cv2.putText(frame, prediction_string, (10, 500), font, 1, colour, 2)
            cv2.putText(frame, "framerate: %2.0f fps (on CPUx8)" % fps, (500, 20), font, 0.5, (255, 255, 255), 2)

            # Output the prediction also in the console
            print(prediction_string)

            if DISPLAY_VIDEO:
                cv2.imshow('annotated frame', frame)
                cv2.waitKey(25)
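                # waitKey(25) pauses ~25 ms between frames and lets the HighGUI
                # window process events; without it imshow would not refresh.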
            # Write the annotated frame
            if SAVE_VIDEO:
                out.write(frame)

        if SAVE_VIDEO:
            out.release()
        cap.release()
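        if DISPLAY_VIDEO:
            cv2.destroyAllWindows()  # close the preview window when done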
elif EVALUATE_ON_VIDEOS == 2: # This means evaluate on images
    with tf.Session() as sess:

        restore_model(sess)
        for root, directories, files in os.walk(IMAGE_FOLDER):
            for fn in files:

                path = os.path.join(root, fn)
                # The ground-truth label is the name of the containing folder
                gt_label_name = root.split('/')[-1]

                im = cv2.imread(path)

                start_time = time.time()

                # Do the same transformations as for the training data
                im_cropped = im.copy()[106:713, 176:783, :]
                im_cropped = cv2.resize(im_cropped, (convnet.IMAGE_SIZE, convnet.IMAGE_SIZE))
                im_gray = cv2.cvtColor(im_cropped, cv2.COLOR_BGR2GRAY)
                im_data = im_gray.astype(np.float32)
                im_data = np.reshape(im_data, (1, convnet.IMAGE_SIZE, convnet.IMAGE_SIZE, 1))
                X_white = convnet._whiten_images(im_data)
                # Bug fix: don't rebind `prediction` (the tf op) here, or the
                # second iteration of the loop would fail; use a new variable.
                predictions = sess.run(
                    prediction,
                    feed_dict={data: X_white})
                predictions = predictions.flatten()

                elapsed_time = time.time() - start_time

                print(np.argmax(predictions))
                print(label_names[np.argmax(predictions)])
                print("Confidence: %.2f" % max(predictions))
                print("Elapsed time: %.1f ms" % (1000 * elapsed_time))

                cv2.imshow('current image', im)
                cv2.waitKey(0)  # blocks until a keypress: step through images one at a time
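
        cv2.destroyAllWindows()  # close the display window when finished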