Commit ea081abb authored by Huaqi Qiu's avatar Huaqi Qiu
Browse files

automated PEP-8 code reformat

parent fb9f2b8d
......@@ -5,42 +5,41 @@ import time
import SimpleITK as sitk
from torch.autograd import Variable
from dataset.cardiac_dataset import CARDIAC_Predict_DATASET
from dataset.utils import ReverseCropPad,CropPad
from dataset.utils import ReverseCropPad, CropPad
from dataset.utils import resample_by_ref
from model.unet import UNet
def save_predict(img,root_dir,patient_dir,file_name):
def save_predict(img, root_dir, patient_dir, file_name):
if not os.path.exists(root_dir):
os.makedirs(root_dir)
patient_dir=os.path.join(root_dir,patient_dir)
patient_dir = os.path.join(root_dir, patient_dir)
if not os.path.exists(patient_dir):
os.makedirs(patient_dir)
file_path=os.path.join(patient_dir,file_name)
sitk.WriteImage(img,file_path,True)
file_path = os.path.join(patient_dir, file_name)
sitk.WriteImage(img, file_path, True)
def link_image(origin_path,root_dir,patient_dir):
def link_image(origin_path, root_dir, patient_dir):
if not os.path.exists(root_dir):
os.mkdir(root_dir)
patient_dir=os.path.join(root_dir,patient_dir)
patient_dir = os.path.join(root_dir, patient_dir)
if not os.path.exists(patient_dir):
os.mkdir(patient_dir)
image_name = origin_path.split('/')[-1]
linked_name =image_name
linked_name = image_name
linked_path=os.path.join(patient_dir,linked_name)
print ('link path from {} to {}'.format(origin_path, linked_path))
linked_path = os.path.join(patient_dir, linked_name)
print('link path from {} to {}'.format(origin_path, linked_path))
os.system('ln -s {0} {1}'.format(origin_path, linked_path))
def predict(sequence_name,root_dir,image_format_name,
def predict(sequence_name, root_dir, image_format_name,
readable_frames,
model_path,
save_dir='./test_result',
save_format_name = 'seg_sa_{}.nii.gz',batch_size=4,if_resample=True):
save_format_name='seg_sa_{}.nii.gz', batch_size=4, if_resample=True):
'''
:param sequence_name: LVSA/4CH/VLA/LVOT
......@@ -61,7 +60,7 @@ def predict(sequence_name,root_dir,image_format_name,
'LVOT': 2
}[sequence_name]
save_dir=os.path.join(save_dir,sequence_name)
save_dir = os.path.join(save_dir, sequence_name)
## load model params from the path
model = UNet(input_channel=1, num_classes=number_classes)
......@@ -78,18 +77,17 @@ def predict(sequence_name,root_dir,image_format_name,
torch.cuda.set_device(0)
model.eval()
testset=CARDIAC_Predict_DATASET(root_dir,image_format_name=image_format_name,
if_resample=if_resample,
readable_frames=readable_frames,
new_spacing = [1.25, 1.25, 10],
keep_z_spacing=True)
time_records=[]
n_count=0
testset = CARDIAC_Predict_DATASET(root_dir, image_format_name=image_format_name,
if_resample=if_resample,
readable_frames=readable_frames,
new_spacing=[1.25, 1.25, 10],
keep_z_spacing=True)
time_records = []
n_count = 0
###iterate all images
for i in range(testset.get_size()):
print ('<------Loading data-------->')
for i in range(testset.get_size()):
print('<------Loading data-------->')
## for each patient
patient_data = testset.get_next_patient()
if patient_data is None:
......@@ -102,23 +100,23 @@ def predict(sequence_name,root_dir,image_format_name,
try:
frame_image = patient_data[frame]['image'] ##N*H*W
except:
print ('{} {} is missing'.format(patient_data['patient_id'],frame))
print('{} {} is missing'.format(patient_data['patient_id'], frame))
continue
aft_resample_shape=patient_data[frame]['aft_resample_shape']
print ('image shape after resample:',aft_resample_shape)
soft_prediction=np.zeros((aft_resample_shape[0],number_classes,aft_resample_shape[1],aft_resample_shape[2])) ##N*4*H*W
aft_resample_shape = patient_data[frame]['aft_resample_shape']
print('image shape after resample:', aft_resample_shape)
soft_prediction = np.zeros(
(aft_resample_shape[0], number_classes, aft_resample_shape[1], aft_resample_shape[2])) ##N*4*H*W
## DATA Preprocessing
## pad/crop image to a image size of 16
# X2, Y2 = int(math.ceil(aft_resample_shape[1] / 16.0)) * 16, int(math.ceil(aft_resample_shape[2] / 16.0)) * 16
X2, Y2=256,256
cPader=CropPad(X2,Y2,chw=True) ##central crop
reversecroppad=ReverseCropPad(aft_resample_shape[1],aft_resample_shape[2])
# X2, Y2 = int(math.ceil(aft_resample_shape[1] / 16.0)) * 16, int(math.ceil(aft_resample_shape[2] / 16.0)) * 16
X2, Y2 = 256, 256
cPader = CropPad(X2, Y2, chw=True) ##central crop
reversecroppad = ReverseCropPad(aft_resample_shape[1], aft_resample_shape[2])
## START Predicting
print('Segmenting {} frame: {}'.format(patient_data['patient_id'],frame))
print('Segmenting {} frame: {}'.format(patient_data['patient_id'], frame))
num_slices = aft_resample_shape[0]
n_batch = int(np.round(num_slices / batch_size))
for i in range(n_batch):
......@@ -134,13 +132,12 @@ def predict(sequence_name,root_dir,image_format_name,
input = Variable(input_tensor)
### predict every batch
batch_output= model(input)
batch_output = model(input)
batch_output_npy = batch_output.data.cpu().numpy() ##batch_size*n_cls*H'*W'
temp = reversecroppad(batch_output_npy.squeeze()) ##batch_size*n_cls*H*W
soft_prediction[i * batch_size:(i + 1) * batch_size] = temp
## soft recover predictions to original resolution
## soft recover predictions to original resolution
if if_resample:
stacked_list = []
## for each class prob, recover resolution
......@@ -155,45 +152,46 @@ def predict(sequence_name,root_dir,image_format_name,
predict_result = np.argmax(stacked_prob, axis=0).squeeze()
predict_result = np.uint8(predict_result)
else:
predict_result =np.argmax(soft_prediction, axis=1).squeeze()
predict_result = np.argmax(soft_prediction, axis=1).squeeze()
predict_result = np.uint8(predict_result)
end_process_time = time.time()-start_process_time
end_process_time = time.time() - start_process_time
time_records.append(end_process_time)
n_count+=1
n_count += 1
pid = patient_data['patient_id']
print ('Saving segmentation to {}/{}'.format(save_dir,pid))
if len(predict_result.shape)<len(patient_data[frame]['original_shape']):
predict_result=np.reshape(predict_result,patient_data[frame]['original_shape'])
print('Saving segmentation to {}/{}'.format(save_dir, pid))
if len(predict_result.shape) < len(patient_data[frame]['original_shape']):
predict_result = np.reshape(predict_result, patient_data[frame]['original_shape'])
post_im = sitk.GetImageFromArray(predict_result)
ref_im = patient_data[frame]['origin_itk_image']
post_im.CopyInformation(ref_im)
pred_file_name=save_format_name.format(frame)
save_predict(post_im,save_dir,pid,pred_file_name)
pred_file_name = save_format_name.format(frame)
save_predict(post_im, save_dir, pid, pred_file_name)
## create soft link to the image.
image_path = patient_data[frame]['temp_image_path']
link_image(origin_path=image_path,root_dir=save_dir,patient_dir=pid)
print ('Successfully segmented {:d} subjects '.format(n_count))
print ('Cost {:.3f}s per subjects ( exclude data loading time)'.format(np.mean(time_records)))
link_image(origin_path=image_path, root_dir=save_dir, patient_dir=pid)
print('Successfully segmented {:d} subjects '.format(n_count))
print('Cost {:.3f}s per subjects ( exclude data loading time)'.format(np.mean(time_records)))
if __name__=='__main__':
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Cardiac Seg Prediction Function')
parser.add_argument('--sequence',type=str, default='LVSA', choices=['VLA','LVSA', 'LVOT','4CH'],help='clarify cardiac image sequence/which view')
parser.add_argument('--sequence', type=str, default='LVSA', choices=['VLA', 'LVSA', 'LVOT', '4CH'],
help='clarify cardiac image sequence/which view')
parser.add_argument('--root_dir', default='test_data/', help='test data folder')
parser.add_argument('--image_format', default='', help='test image name format under each patient dir, e.g. if LVSA_img_ED.nii.gz then format is LVSA_img_{}.nii.gz')
parser.add_argument('--image_format', default='',
help='test image name format under each patient dir, e.g. if LVSA_img_ED.nii.gz then format is LVSA_img_{}.nii.gz')
parser.add_argument('--save_folder_path', default='./test_results', help='folder to save predicted masks')
parser.add_argument('--save_name_format', default='seg_sa_{}.nii.gz', help='save mask format. {} for ED/ES')
parser.add_argument('--no_resample', default=False, action='store_true',help='not resample image before prediction')
parser.add_argument('--no_resample', default=False, action='store_true',
help='not resample image before prediction')
parser.add_argument('--batch_size', default=1, help='how many image slices to be sent to predictor each iteration ')
parser.add_argument('--gpu', default=0, help='select GPU by masking shell environment variable CUDA_VISIBLE_DEVICES')
parser.add_argument('--gpu', default=0,
help='select GPU by masking shell environment variable CUDA_VISIBLE_DEVICES')
args = parser.parse_args()
......@@ -202,41 +200,41 @@ if __name__=='__main__':
### DATA CONFIG
root_dir = os.path.abspath(args.root_dir)
sequence_name=args.sequence
sequence_name = args.sequence
if sequence_name == 'LVSA':
## use batch size > 1, may accelerate the prediction process, but be careful of memory limit
batch_size=args.batch_size
batch_size = args.batch_size
else:
## 4CH, LVOT, VLA, only contains one slice
batch_size=1
batch_size = 1
if args.image_format=='':
if args.image_format == '':
##use default setting
image_format_name = {
'LVSA':'LVSA/LVSA_img_{}.nii.gz',
'VLA': 'VLA/VLA_img_{}.nii.gz',
'LVOT': 'VLA/VLA_img_{}.nii.gz',
'4CH': '4CH/4CH_img_{}.nii.gz',
}[sequence_name]
'LVSA': 'LVSA/LVSA_img_{}.nii.gz',
'VLA': 'VLA/VLA_img_{}.nii.gz',
'LVOT': 'VLA/VLA_img_{}.nii.gz',
'4CH': '4CH/4CH_img_{}.nii.gz',
}[sequence_name]
else:
image_format_name=args.image_format
readable_frames=['ED','ES']
image_format_name = args.image_format
readable_frames = ['ED', 'ES']
##MODEL CONFIG
model_path = {'LVSA':'checkpoints/Unet_LVSA_best.pkl',
'VLA':'checkpoints/Unet_VLA_best.pkl',
'LVOT':'checkpoints/Unet_LVOT_best.pkl',
'4CH':'checkpoints/Unet_4CH_best.pkl'
model_path = {'LVSA': 'checkpoints/Unet_LVSA_best.pkl',
'VLA': 'checkpoints/Unet_VLA_best.pkl',
'LVOT': 'checkpoints/Unet_LVOT_best.pkl',
'4CH': 'checkpoints/Unet_4CH_best.pkl'
}[sequence_name]
if_resample = not args.no_resample
## SAVE CONFIG
save_dir= args.save_folder_path
save_dir = args.save_folder_path
save_name_format = args.save_name_format
predict(sequence_name=sequence_name,root_dir=root_dir,readable_frames=readable_frames,image_format_name=image_format_name,model_path=model_path,batch_size=batch_size,if_resample=if_resample,save_dir=save_dir,save_format_name=save_name_format)
predict(sequence_name=sequence_name, root_dir=root_dir, readable_frames=readable_frames,
image_format_name=image_format_name, model_path=model_path, batch_size=batch_size, if_resample=if_resample,
save_dir=save_dir, save_format_name=save_name_format)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment