Commit e3289d1b authored by cc215

Merge branch 'hq615/Cardiac_Multi_view_segmentation-patch-1'

parents 014bac2c 7d8de9ef
......@@ -47,7 +47,7 @@ Results will be saved under `test_results` by default
- VLA
- VLA_img_ED.nii.gz
- e.g.
- predict LVSA: `python predict.py --sequence LVSA --root_dir 'test_data/' --image_format 'LVSA/LVSA_img_{}.nii.gz' --save_folder_path 'test_results/' --save_format_name 'seg_sa_{}.nii.gz' `
- predict LVSA: `python predict.py --sequence LVSA --root_dir 'test_data/' --image_format 'LVSA/LVSA_img_{}.nii.gz' --save_folder_path 'test_results/' --save_name_format 'seg_sa_{}.nii.gz' `
- please read predict.py for advanced settings (batch size, image resampling).
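- for other views the pattern is the same; e.g. a hypothetical 4CH call (the image format `4CH/4CH_img_{}.nii.gz` is the project default for 4CH, while the output name `seg_4ch_{}.nii.gz` is only illustrative):
  - predict 4CH: `python predict.py --sequence 4CH --root_dir 'test_data/' --image_format '4CH/4CH_img_{}.nii.gz' --save_folder_path 'test_results/' --save_name_format 'seg_4ch_{}.nii.gz'`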
......
......@@ -5,7 +5,7 @@ import torch as th
import numpy as np
from skimage import transform as sktform
import random
from torchsample.utils import th_affine2d
def resample_by_spacing(im, new_spacing, interpolator=sitk.sitkLinear, keep_z_spacing=False):
'''
......@@ -205,92 +205,6 @@ class MySpecialRandomRotate(object):
return outputs
class MyRotate(object):
def __init__(self,
value,
output_size,
interp='bilinear',
lazy=False,crop=False):
"""
Randomly rotate an image between (-degrees, degrees). If the image
has multiple channels, the same rotation will be applied to each channel.
Arguments
---------
rotation_range : integer or float
image will be rotated between (-degrees, degrees) degrees
interp : string in {'bilinear', 'nearest'} or list of strings
type of interpolation to use. You can provide a different
type of interpolation for each input, e.g. if you have two
inputs then you can say `interp=['bilinear','nearest']
lazy : boolean
if true, only create the affine transform matrix and return that
if false, perform the transform on the tensor and return the tensor
"""
self.value = value
self.interp = interp
self.lazy = lazy
self.crop=crop ## remove black artifacts
self.output_size=output_size
def __call__(self, *inputs):
if not isinstance(self.interp, (tuple,list)):
interp = [self.interp]*len(inputs)
else:
interp = self.interp
theta =math.radians(self.value)
rotation_matrix = th.FloatTensor([[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1]])
self.theta=theta
if self.lazy:
return rotation_matrix
else:
outputs = []
new_w=0
new_h=0
for idx, _input in enumerate(inputs):
# lrr_width, lrr_height = _largest_rotated_rect(output_height, output_width,
# math.radians(rotation_degree))
# resized_image = tf.image.central_crop(image, float(lrr_height) / output_height)
# image = tf.image.resize_images(resized_image, [output_height, output_width],
# method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
image_height=_input.size(1)
image_width=_input.size(2)
if not self.theta ==0.:
input_tf = th_affine2d(_input,
rotation_matrix,
mode=interp[idx],
center=True)
print ('size:',input_tf.size())
if self.crop:
if idx == 0:
## find the largest rect to crop ## adapted from: https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
new_w, new_h = largest_rotated_rect(
image_height,
image_width,
theta)
edge=min(new_w,new_h)
# out_edge=max(self.output_size[2],self.output_size[1])
cropper = MySpecialCrop(size=(edge, edge, 1), crop_type=0)
print('here')
output = cropper(input_tf) ## 1*H*W
Resizer = MyResize(size=(self.output_size[1], self.output_size[2]), interp=interp[idx])
output = Resizer(output)
else:
input_tf=_input #
padder=MyPad(size=(1,self.output_size[1],self.output_size[2]))
output = padder(input_tf)
# print (output.size())
outputs.append(output)
return outputs if idx >= 1 else outputs[0]
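The `largest_rotated_rect` helper called in `MyRotate.__call__` above is not part of this hunk. For reference, below is a minimal sketch of one standard largest-axis-aligned-rectangle computation (adapted from the Stack Overflow thread linked in the comment); it is illustrative only and not necessarily the exact implementation or argument order used in this repository:

```python
import math

def largest_rotated_rect_sketch(w, h, angle):
    """Return the width and height of the largest axis-aligned rectangle
    that fits inside a w x h rectangle rotated by `angle` radians."""
    if w <= 0 or h <= 0:
        return 0.0, 0.0
    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)
    # the solution only depends on |sin| and |cos| of the rotation angle
    sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
    if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
        # half-constrained case: two crop corners touch the longer side
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully constrained case: the crop touches all four sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr = (w * cos_a - h * sin_a) / cos_2a
        hr = (h * cos_a - w * sin_a) / cos_2a
    return wr, hr
```

`MyRotate` then takes the smaller of the two returned dimensions as the edge of a square central crop (`MySpecialCrop`) before resizing back to `output_size`.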
class MyResize(object):
"""
resize a 2D numpy array using skimage; supports float type
......@@ -629,4 +543,4 @@ class RandomGamma(object):
def __call__(self, image, origin_space=None):
assert len(image.shape) == 3
\ No newline at end of file
assert len(image.shape) == 3
......@@ -5,42 +5,41 @@ import time
import SimpleITK as sitk
from torch.autograd import Variable
from dataset.cardiac_dataset import CARDIAC_Predict_DATASET
from dataset.utils import ReverseCropPad,CropPad
from dataset.utils import ReverseCropPad, CropPad
from dataset.utils import resample_by_ref
from model.unet import UNet
def save_predict(img,root_dir,patient_dir,file_name):
def save_predict(img, root_dir, patient_dir, file_name):
if not os.path.exists(root_dir):
os.makedirs(root_dir)
patient_dir=os.path.join(root_dir,patient_dir)
patient_dir = os.path.join(root_dir, patient_dir)
if not os.path.exists(patient_dir):
os.makedirs(patient_dir)
file_path=os.path.join(patient_dir,file_name)
sitk.WriteImage(img,file_path,True)
file_path = os.path.join(patient_dir, file_name)
sitk.WriteImage(img, file_path, True)
def link_image(origin_path,root_dir,patient_dir):
def link_image(origin_path, root_dir, patient_dir):
if not os.path.exists(root_dir):
os.mkdir(root_dir)
patient_dir=os.path.join(root_dir,patient_dir)
patient_dir = os.path.join(root_dir, patient_dir)
if not os.path.exists(patient_dir):
os.mkdir(patient_dir)
image_name = origin_path.split('/')[-1]
linked_name =image_name
linked_name = image_name
linked_path=os.path.join(patient_dir,linked_name)
print ('link path from {} to {}'.format(origin_path, linked_path))
linked_path = os.path.join(patient_dir, linked_name)
print('link path from {} to {}'.format(origin_path, linked_path))
os.system('ln -s {0} {1}'.format(origin_path, linked_path))
def predict(sequence_name,root_dir,image_format_name,
def predict(sequence_name, root_dir, image_format_name,
readable_frames,
model_path,
save_dir='./test_result',
save_format_name = 'seg_sa_{}.nii.gz',batch_size=4,if_resample=True):
save_format_name='seg_sa_{}.nii.gz', batch_size=4, if_resample=True):
'''
:param sequence_name: LVSA/4CH/VLA/LVOT
......@@ -61,7 +60,7 @@ def predict(sequence_name,root_dir,image_format_name,
'LVOT': 2
}[sequence_name]
save_dir=os.path.join(save_dir,sequence_name)
save_dir = os.path.join(save_dir, sequence_name)
## load model params from the path
model = UNet(input_channel=1, num_classes=number_classes)
......@@ -78,18 +77,17 @@ def predict(sequence_name,root_dir,image_format_name,
torch.cuda.set_device(0)
model.eval()
testset=CARDIAC_Predict_DATASET(root_dir,image_format_name=image_format_name,
if_resample=if_resample,
readable_frames=readable_frames,
new_spacing = [1.25, 1.25, 10],
keep_z_spacing=True)
time_records=[]
n_count=0
testset = CARDIAC_Predict_DATASET(root_dir, image_format_name=image_format_name,
if_resample=if_resample,
readable_frames=readable_frames,
new_spacing=[1.25, 1.25, 10],
keep_z_spacing=True)
time_records = []
n_count = 0
###iterate all images
for i in range(testset.get_size()):
print ('<------Loading data-------->')
for i in range(testset.get_size()):
print('<------Loading data-------->')
## for each patient
patient_data = testset.get_next_patient()
if patient_data is None:
......@@ -102,23 +100,23 @@ def predict(sequence_name,root_dir,image_format_name,
try:
frame_image = patient_data[frame]['image'] ##N*H*W
except:
print ('{} {} is missing'.format(patient_data['patient_id'],frame))
print('{} {} is missing'.format(patient_data['patient_id'], frame))
continue
aft_resample_shape=patient_data[frame]['aft_resample_shape']
print ('image shape after resample:',aft_resample_shape)
soft_prediction=np.zeros((aft_resample_shape[0],number_classes,aft_resample_shape[1],aft_resample_shape[2])) ##N*4*H*W
aft_resample_shape = patient_data[frame]['aft_resample_shape']
print('image shape after resample:', aft_resample_shape)
soft_prediction = np.zeros(
(aft_resample_shape[0], number_classes, aft_resample_shape[1], aft_resample_shape[2])) ##N*4*H*W
## DATA Preprocessing
## pad/crop image to a fixed size (multiple of 16)
# X2, Y2 = int(math.ceil(aft_resample_shape[1] / 16.0)) * 16, int(math.ceil(aft_resample_shape[2] / 16.0)) * 16
X2, Y2=256,256
cPader=CropPad(X2,Y2,chw=True) ##central crop
reversecroppad=ReverseCropPad(aft_resample_shape[1],aft_resample_shape[2])
# X2, Y2 = int(math.ceil(aft_resample_shape[1] / 16.0)) * 16, int(math.ceil(aft_resample_shape[2] / 16.0)) * 16
X2, Y2 = 256, 256
cPader = CropPad(X2, Y2, chw=True) ##central crop
reversecroppad = ReverseCropPad(aft_resample_shape[1], aft_resample_shape[2])
## START Predicting
print('Segmenting {} frame: {}'.format(patient_data['patient_id'],frame))
print('Segmenting {} frame: {}'.format(patient_data['patient_id'], frame))
num_slices = aft_resample_shape[0]
n_batch = int(np.round(num_slices / batch_size))
for i in range(n_batch):
......@@ -134,13 +132,12 @@ def predict(sequence_name,root_dir,image_format_name,
input = Variable(input_tensor)
### predict every batch
batch_output= model(input)
batch_output = model(input)
batch_output_npy = batch_output.data.cpu().numpy() ##batch_size*n_cls*H'*W'
temp = reversecroppad(batch_output_npy.squeeze()) ##batch_size*n_cls*H*W
soft_prediction[i * batch_size:(i + 1) * batch_size] = temp
## soft recover predictions to original resolution
## soft recover predictions to original resolution
if if_resample:
stacked_list = []
## for each class prob, recover resolution
......@@ -155,84 +152,89 @@ def predict(sequence_name,root_dir,image_format_name,
predict_result = np.argmax(stacked_prob, axis=0).squeeze()
predict_result = np.uint8(predict_result)
else:
predict_result =np.argmax(soft_prediction, axis=1).squeeze()
predict_result = np.argmax(soft_prediction, axis=1).squeeze()
predict_result = np.uint8(predict_result)
end_process_time = time.time()-start_process_time
end_process_time = time.time() - start_process_time
time_records.append(end_process_time)
n_count+=1
n_count += 1
pid = patient_data['patient_id']
print ('Saving segmentation to {}/{}'.format(save_dir,pid))
if len(predict_result.shape)<len(patient_data[frame]['original_shape']):
predict_result=np.reshape(predict_result,patient_data[frame]['original_shape'])
print('Saving segmentation to {}/{}'.format(save_dir, pid))
if len(predict_result.shape) < len(patient_data[frame]['original_shape']):
predict_result = np.reshape(predict_result, patient_data[frame]['original_shape'])
post_im = sitk.GetImageFromArray(predict_result)
ref_im = patient_data[frame]['origin_itk_image']
post_im.CopyInformation(ref_im)
pred_file_name=save_format_name.format(frame)
save_predict(post_im,save_dir,pid,pred_file_name)
pred_file_name = save_format_name.format(frame)
save_predict(post_im, save_dir, pid, pred_file_name)
## create soft link to the image.
image_path = patient_data[frame]['temp_image_path']
link_image(origin_path=image_path,root_dir=save_dir,patient_dir=pid)
print ('Successfully segmented {:d} subjects '.format(n_count))
print ('Cost {:.3f}s per subjects ( exclude data loading time)'.format(np.mean(time_records)))
link_image(origin_path=image_path, root_dir=save_dir, patient_dir=pid)
print('Successfully segmented {:d} subjects'.format(n_count))
print('Cost {:.3f}s per subject (excluding data loading time)'.format(np.mean(time_records)))
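For readers unfamiliar with the `CropPad`/`ReverseCropPad` pair imported from `dataset.utils`, here is a rough, self-contained sketch of a center crop/pad to 256 x 256 and its inverse; the names and the exact padding/channel behaviour of the real utilities may differ:

```python
import numpy as np

def center_crop_or_pad(img, target_h=256, target_w=256):
    """Center-crop or zero-pad a (N, H, W) array to (N, target_h, target_w)."""
    n, h, w = img.shape
    out = np.zeros((n, target_h, target_w), dtype=img.dtype)
    # source/destination offsets and copy extent for height
    sh = max((h - target_h) // 2, 0)
    dh = max((target_h - h) // 2, 0)
    ch = min(h, target_h)
    # source/destination offsets and copy extent for width
    sw = max((w - target_w) // 2, 0)
    dw = max((target_w - w) // 2, 0)
    cw = min(w, target_w)
    out[:, dh:dh + ch, dw:dw + cw] = img[:, sh:sh + ch, sw:sw + cw]
    return out, (sh, dh, ch, sw, dw, cw, h, w)

def reverse_crop_or_pad(pred, meta):
    """Map a (N, C, target_h, target_w) prediction back onto the original H x W grid."""
    sh, dh, ch, sw, dw, cw, h, w = meta
    n, c = pred.shape[0], pred.shape[1]
    out = np.zeros((n, c, h, w), dtype=pred.dtype)
    out[:, :, sh:sh + ch, sw:sw + cw] = pred[:, :, dh:dh + ch, dw:dw + cw]
    return out
```

In `predict()` the forward operation is applied to the image slices before the network and the reverse operation maps the soft class probabilities back to the resampled image grid, as in the loop above.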
if __name__=='__main__':
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Cardiac Seg Prediction Function')
parser.add_argument('--sequence',type=str, default='LVSA', choices=['VLA','LVSA', 'LVOT','4CH'],help='clarify cardiac image sequence/which view')
parser.add_argument('--sequence', type=str, default='LVSA', choices=['VLA', 'LVSA', 'LVOT', '4CH'],
help='specify the cardiac image sequence/view to segment')
parser.add_argument('--root_dir', default='test_data/', help='test data folder')
parser.add_argument('--image_format', default='', help='test image name format under each patient dir, e.g. if LVSA_img_ED.nii.gz then format is LVSA_img_{}.nii.gz')
parser.add_argument('--image_format', default='',
help='test image name format under each patient dir, e.g. if LVSA_img_ED.nii.gz then format is LVSA_img_{}.nii.gz')
parser.add_argument('--save_folder_path', default='./test_results', help='folder to save predicted masks')
parser.add_argument('--save_name_format', default='seg_sa_{}.nii.gz', help='save mask format. {} for ED/ES')
parser.add_argument('--no_resample', default=False, action='store_true',help='not resample image before prediction')
parser.add_argument('--no_resample', default=False, action='store_true',
help='not resample image before prediction')
parser.add_argument('--batch_size', type=int, default=1, help='number of image slices sent to the predictor per iteration')
parser.add_argument('--gpu', default=0,
help='select GPU by masking shell environment variable CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
### GPU CONFIG
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
### DATA CONFIG
root_dir = os.path.abspath(args.root_dir)
sequence_name=args.sequence
sequence_name = args.sequence
if sequence_name == 'LVSA':
## using a batch size > 1 may accelerate prediction, but be careful of the memory limit
batch_size=args.batch_size
batch_size = args.batch_size
else:
## 4CH, LVOT, VLA images only contain one slice
batch_size=1
batch_size = 1
if args.image_format=='':
if args.image_format == '':
##use default setting
image_format_name = {
'LVSA':'LVSA/LVSA_img_{}.nii.gz',
'VLA': 'VLA/VLA_img_{}.nii.gz',
'LVOT': 'VLA/VLA_img_{}.nii.gz',
'4CH': '4CH/4CH_img_{}.nii.gz',
}[sequence_name]
'LVSA': 'LVSA/LVSA_img_{}.nii.gz',
'VLA': 'VLA/VLA_img_{}.nii.gz',
'LVOT': 'VLA/VLA_img_{}.nii.gz',
'4CH': '4CH/4CH_img_{}.nii.gz',
}[sequence_name]
else:
image_format_name=args.image_format
readable_frames=['ED','ES']
image_format_name = args.image_format
readable_frames = ['ED', 'ES']
##MODEL CONFIG
model_path = {'LVSA':'checkpoints/Unet_LVSA_best.pkl',
'VLA':'checkpoints/Unet_VLA_best.pkl',
'LVOT':'checkpoints/Unet_LVOT_best.pkl',
'4CH':'checkpoints/Unet_4CH_best.pkl'
model_path = {'LVSA': 'checkpoints/Unet_LVSA_best.pkl',
'VLA': 'checkpoints/Unet_VLA_best.pkl',
'LVOT': 'checkpoints/Unet_LVOT_best.pkl',
'4CH': 'checkpoints/Unet_4CH_best.pkl'
}[sequence_name]
if_resample = not args.no_resample
## SAVE CONFIG
save_dir= args.save_folder_path
save_dir = args.save_folder_path
save_name_format = args.save_name_format
predict(sequence_name=sequence_name,root_dir=root_dir,readable_frames=readable_frames,image_format_name=image_format_name,model_path=model_path,batch_size=batch_size,if_resample=if_resample,save_dir=save_dir,save_format_name=save_name_format)
predict(sequence_name=sequence_name, root_dir=root_dir, readable_frames=readable_frames,
image_format_name=image_format_name, model_path=model_path, batch_size=batch_size,
if_resample=if_resample, save_dir=save_dir, save_format_name=save_name_format)
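With the default settings above, a hypothetical run on the bundled test data (e.g. subject `10DG02038`) would be expected to produce roughly the following layout under `--save_folder_path`; this is a sketch derived from `save_predict`/`link_image`, and the actual contents depend on which frames are available:

```
test_results/
    LVSA/
        10DG02038/
            seg_sa_ED.nii.gz        # predicted segmentation
            LVSA_img_ED.nii.gz      # symlink to the input image
```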
/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/test_data/10DG02038/LVSA/LVSA_img_ED.nii.gz
\ No newline at end of file
/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/test_data/10DH04454/LVSA/LVSA_img_ED.nii.gz
\ No newline at end of file
/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/test_data/10DH04454/LVSA/LVSA_img_ES.nii.gz
\ No newline at end of file