Commit 63b1f905 authored by cc215's avatar cc215 💬
Browse files

clean data

parent 73fa4709
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv2",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.2
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
......@@ -90,9 +90,10 @@ class CARDIAC_Predict_DATASET(data.Dataset):
### read image ##
temp_image = sitk.ReadImage(temp_image_path)
original_shape=sitk.GetArrayFromImage(temp_image).shape
##convert to float format
temp_image = sitk.Cast(sitk.RescaleIntensity(temp_image), sitk.sitkFloat32)
temp_image_arr = sitk.GetArrayFromImage(temp_image)
original_shape=temp_image_arr.shape
##convert to float format
origin_spacing = temp_image.GetSpacing()
if self.if_resample:
## image resampling
......@@ -103,7 +104,7 @@ class CARDIAC_Predict_DATASET(data.Dataset):
new_image = temp_image
## new image shape
self.aft_resample_shape = sitk.GetArrayFromImage(new_image).shape
self.aft_resample_shape = temp_image_arr.shape
data = sitk.GetArrayFromImage(new_image).astype(float)
patient_id = self.patient_path_list[self.pid].split('/')[-1]
......
......@@ -42,23 +42,23 @@ def predict(model_path, input_image_path,
model.eval()
print('<------Loading data-------->')
### read image ##
temp_image = sitk.ReadImage(input_image_path)
original_im_arr=sitk.GetArrayFromImage(temp_image)
origin_image = sitk.ReadImage(input_image_path)
origin_image = sitk.Cast(sitk.RescaleIntensity(origin_image), sitk.sitkFloat32)
original_im_arr=sitk.GetArrayFromImage(origin_image)
original_shape=original_im_arr.shape
temp_image = sitk.Cast(sitk.RescaleIntensity(temp_image), sitk.sitkFloat32)
origin_spacing = temp_image.GetSpacing()
print('<------Preprocessing data-------->')
## image resampling ##
if if_resample:
new_image = resample_by_spacing(im=temp_image, new_spacing=[1.25,1.25,10],interpolator=sitk.sitkLinear,
new_image = resample_by_spacing(im=origin_image, new_spacing=[1.25,1.25,10],interpolator=sitk.sitkLinear,
keep_z_spacing=True)
else:
new_image = temp_image
new_image = origin_image
## new image shape
npy_data = sitk.GetArrayFromImage(new_image).astype(float)
aft_resample_shape = npy_data.shape
new_im_arr = sitk.GetArrayFromImage(new_image).astype(float)
aft_resample_shape = new_im_arr.shape
soft_prediction = np.zeros((aft_resample_shape[0], num_classes, aft_resample_shape[1], aft_resample_shape[2])) ##N*4*H*W
print('<------Batchwise prediction-------->')
......@@ -71,7 +71,7 @@ def predict(model_path, input_image_path,
batch_size = num_slices
n_batch = int(np.round(num_slices / batch_size))
for i in range(n_batch):
batch_data = npy_data[i * batch_size:(i + 1) * batch_size, :, :]
batch_data = new_im_arr[i * batch_size:(i + 1) * batch_size, :, :]
if batch_size == 1 and len(batch_data.shape) == 2:
batch_data = batch_data[np.newaxis, :, :]
......@@ -100,7 +100,7 @@ def predict(model_path, input_image_path,
one_class_im = sitk.GetImageFromArray(soft_prediction[:, i_class, :, :]) ##N*H*W
after_resampled_image = new_image
one_class_im.CopyInformation(after_resampled_image)
ref_im =temp_image
ref_im =origin_image
post_one_class_im = resample_by_ref(one_class_im, ref_im, interpolator=sitk.sitkLinear)
stacked_list.append(sitk.GetArrayFromImage(post_one_class_im))
stacked_prob = np.stack(stacked_list) # 4*N*H*W
......@@ -116,7 +116,7 @@ def predict(model_path, input_image_path,
if save_pred_path is not None:
print('Saving segmentation to {}'.format(save_pred_path))
post_im = sitk.GetImageFromArray(predict_result)
ref_im = temp_image
ref_im = origin_image
post_im.CopyInformation(ref_im)
sitk.WriteImage(post_im, save_pred_path, True)
return model,original_im_arr,predict_result
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment