cc215 / Cardiac_Multi_view_segmentation · Commits · 73fa4709
Commit 73fa4709, authored Mar 10, 2021 by cc215

fix bug and add mc dropout prediction

parent d8808541
Changes: 4 files
model/base_segmentation_model.py
@@ -10,13 +10,13 @@ import gc
from model.init_weight import init_weights
from model.unet import UNet
from model.model_utils import makeVariable
from common_utils.loss import cross_entropy_2D
from common_utils.metrics import runningScore
from common_utils.save import save_list_results_as_png

class SegmentationModel(nn.Module):
    def __init__(self, network_type, in_channels=1, num_classes=2, encoder_dropout=None, decoder_dropout=None, use_gpu=True, lr=0.001, resume_path=None):
@@ -39,8 +39,7 @@ class SegmentationModel(nn.Module):
        self.num_classes = num_classes
        self.lr = lr
        self.in_channels = in_channels
        self.encoder_dropout = encoder_dropout if isinstance(encoder_dropout, float) else None
-       self.decoder_dropout = decoder_dropout if isinstance(encoder_dropout, float) else None
        self.decoder_dropout = decoder_dropout if isinstance(decoder_dropout, float) else None
        self.model = self.get_network_from_model_library(self.network_type)
        assert not self.model is None, 'cannot find the model given the specified name'
@@ -137,8 +136,43 @@ class SegmentationModel(nn.Module):
        with torch.no_grad():
            output = self.model.forward(input)
            probs = torch.softmax(output, dim=1)
        torch.cuda.empty_cache()
        return probs

+    def MC_predict(self, input, n_times=5, decoder_dropout=0.1, disable_bn=False):
+        assert n_times >= 1
+        ## use MC dropout to get ensembled prediction
+        ## enable dropout
+        if self.decoder_dropout is None or self.decoder_dropout != decoder_dropout:
+            self.decoder_dropout = decoder_dropout
+            self.model = self.get_network_from_model_library(self.network_type)
+            self.init_model(self.network_type)
+            if self.use_gpu:
+                self.model.cuda()
+        self.model.train()
+        ## fix batch norm
+        if not disable_bn:
+            for module in self.model.modules():
+                # print(module)
+                if isinstance(module, nn.BatchNorm2d):
+                    if hasattr(module, 'weight'):
+                        module.weight.requires_grad_(False)
+                    if hasattr(module, 'bias'):
+                        module.bias.requires_grad_(False)
+                    module.eval()
+        # mc sampling
+        probs_list = []
+        for i in range(n_times):
+            image = input.detach()
+            output = self.model.forward(image)
+            probs_i = torch.softmax(output, dim=1)
+            probs_list.append(probs_i)
+        torch.cuda.empty_cache()
+        mean_probs = sum(probs_list) / len(probs_list)
+        return mean_probs, probs_list

    def evaluate(self, input, targets_npy):
        '''
        evaluate the model performance
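For readers of this commit, a minimal usage sketch (not part of the diff) of how the new MC_predict output could be turned into a hard segmentation and a per-pixel uncertainty map. `seg_model` (a trained SegmentationModel) and `image_batch` (an N x C x H x W input tensor on the correct device) are assumed to exist:

import torch

# Hypothetical example: run 10 stochastic forward passes with decoder dropout.
mean_probs, probs_list = seg_model.MC_predict(image_batch, n_times=10, decoder_dropout=0.1)

# Hard labels from the averaged softmax output (N x H x W).
segmentation = torch.argmax(mean_probs, dim=1)

# Predictive entropy of the mean distribution: larger values mean the
# MC-dropout samples disagree more about that pixel.
entropy_map = -(mean_probs * torch.log(mean_probs + 1e-8)).sum(dim=1)

# Alternative uncertainty measure: per-class standard deviation across samples.
std_map = torch.stack(probs_list, dim=0).std(dim=0)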
predict.py
@@ -75,11 +75,11 @@ def predict(sequence_name, root_dir, image_format_name,
    else:
        model.load_state_dict(torch.load(model_path))
    # Decide which device we want to run on
-   device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
+   device = torch.device("cuda:{}".format(args.gpu) if (torch.cuda.is_available()) else "cpu")
    if device.type == 'cuda':
        model.cuda()
-       torch.cuda.set_device(0)
+       torch.cuda.set_device(args.gpu)
    model.eval()
    testset = CARDIAC_Predict_DATASET(root_dir, image_format_name=image_format_name,
predict_single_LVSA.py
@@ -16,7 +16,7 @@ from common_utils.basic_operations import transform2tensor
def predict(model_path, input_image_path,
-           save_pred_path, batch_size=4, crop_size=256, if_resample=True, if_z_score=False):
+           save_pred_path=None, batch_size=4, crop_size=256, if_resample=True, if_z_score=False, gpu_id=0, mc_dropout=0, decoder_dropout_rate=0.1):
    '''
    :param model_path: path to the saved model parameters
@@ -26,21 +26,25 @@ def predict(model_path, input_image_path,
    :param crop_size: the size for image ROI cropping, need to be divided by 16.
    :param if_resample:if resamping image to a uniform pixel-spacing before predition
    :param if_z_score: if rescale images to have zero mean and std deviation, by default, we rescale it to 0-1.
-   :return:
+   :return:
+       prediction: numpy array in 3D format NHW
    '''
    print('<------Loading model-------->')
-   device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
+   device = torch.device("cuda:{}".format(gpu_id) if (torch.cuda.is_available()) else "cpu")
    if device.type == 'cuda':
        use_gpu = True
        print('use gpu')
    else:
        use_gpu = False
    num_classes = 4
-   model = SegmentationModel(network_type='UNet_64', in_channels=1, num_classes=num_classes, use_gpu=use_gpu, resume_path=model_path)
+   decoder_dropout = None if mc_dropout == 0 else decoder_dropout_rate
+   model = SegmentationModel(network_type='UNet_64', in_channels=1, num_classes=num_classes, use_gpu=use_gpu, decoder_dropout=decoder_dropout, resume_path=model_path)
    model.eval()
    print('<------Loading data-------->')
    ### read image ##
    temp_image = sitk.ReadImage(input_image_path)
-   original_shape = sitk.GetArrayFromImage(temp_image).shape
+   original_im_arr = sitk.GetArrayFromImage(temp_image)
+   original_shape = original_im_arr.shape
    temp_image = sitk.Cast(sitk.RescaleIntensity(temp_image), sitk.sitkFloat32)
    origin_spacing = temp_image.GetSpacing()
@@ -53,8 +57,8 @@ def predict(model_path, input_image_path,
        new_image = temp_image
    ## new image shape
-   aft_resample_shape = sitk.GetArrayFromImage(new_image).shape
    npy_data = sitk.GetArrayFromImage(new_image).astype(float)
+   aft_resample_shape = npy_data.shape
    soft_prediction = np.zeros((aft_resample_shape[0], num_classes, aft_resample_shape[1], aft_resample_shape[2]))  ##N*4*H*W
    print('<------Batchwise prediction-------->')
@@ -79,7 +83,9 @@ def predict(model_path, input_image_path,
        input = Variable(input_tensor)
        ### predict every batch
-       batch_output = model(input)
+       if mc_dropout > 0:
+           batch_output, batch_output_list = model.MC_predict(input, n_times=mc_dropout, decoder_dropout=decoder_dropout_rate)
+       else:
+           batch_output = model.predict(input)
        batch_output_npy = batch_output.data.cpu().numpy()  ##batch_size*n_cls*H'*W'
        temp = reversecroppad(batch_output_npy.squeeze())  ##batch_size*n_cls*H*W
        soft_prediction[i * batch_size:(i + 1) * batch_size] = temp
@@ -107,11 +113,13 @@ def predict(model_path, input_image_path,
    if len(predict_result.shape) < len(original_shape):
        predict_result = np.reshape(predict_result, original_shape)
-   print('Saving segmentation to {}'.format(save_pred_path))
-   post_im = sitk.GetImageFromArray(predict_result)
-   ref_im = temp_image
-   post_im.CopyInformation(ref_im)
-   sitk.WriteImage(post_im, save_pred_path, True)
+   if save_pred_path is not None:
+       print('Saving segmentation to {}'.format(save_pred_path))
+       post_im = sitk.GetImageFromArray(predict_result)
+       ref_im = temp_image
+       post_im.CopyInformation(ref_im)
+       sitk.WriteImage(post_im, save_pred_path, True)
+   return model, original_im_arr, predict_result
@@ -126,10 +134,15 @@ if __name__ == '__main__':
    parser.add_argument('-c', '--crop_size', type=int, default=192, help="crop images to save memory")
    parser.add_argument('-z', '--z_score', action="store_true", default=False, help="normalize the images to have zero mean and std deviation.")
+   parser.add_argument('-g', '--gpu', default=0, help='select GPU by masking shell environment variable CUDA_VISIBLE_DEVICES')
+   parser.add_argument('-d', '--mc_dropout', default=0, help='if >0, it will apply MC dropout for d times, by default the dropout rate=0.1')
    args = parser.parse_args()
+   ### GPU CONFIG
+   os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
+   gpu_id = args.gpu
    predict(args.model_path, args.input_image_path,
-           args.output_segmentation_path, batch_size=args.batch_size, crop_size=args.crop_size, if_resample=True, if_z_score=args.z_score)
\ No newline at end of file
+           args.output_segmentation_path, batch_size=args.batch_size, crop_size=args.crop_size, if_resample=True, if_z_score=args.z_score, gpu_id=gpu_id, mc_dropout=int(args.mc_dropout))
\ No newline at end of file
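A minimal usage sketch (not part of the commit) of the updated predict() entry point in predict_single_LVSA.py with MC dropout enabled; the checkpoint and image paths below are placeholders. Returning (model, original_im_arr, predict_result) also makes the function convenient for interactive use, presumably as in the test_prediction.ipynb notebook added by this commit:

from predict_single_LVSA import predict

# Hypothetical paths; replace with a real checkpoint and a NIfTI LVSA stack.
model, image_arr, seg = predict(
    'checkpoints/UNet_64.pth',        # model_path (placeholder)
    'data/patient001_LVSA.nii.gz',    # input_image_path (placeholder)
    save_pred_path=None,              # with None, the new code skips writing a segmentation file
    batch_size=4,
    crop_size=192,
    gpu_id=0,
    mc_dropout=10,                    # >0 switches to MC_predict with 10 stochastic forward passes
    decoder_dropout_rate=0.1,
)

From the command line, the same behaviour is exposed through the new -g/--gpu and -d/--mc_dropout flags added above.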
test_prediction.ipynb
0 → 100644
This diff is collapsed.