cc215 / Cardiac_Multi_view_segmentation / Commits / 84efc713

Commit 84efc713, authored Aug 02, 2021 by cc215

    add advchain submodule

parent a8edd375
Changes: 21 files
.gitmodules (new file, mode 0 → 100644)

[submodule "advchain"]
    path = advchain
    url = https://github.com/cherise215/advchain.git

advchain (new submodule)

Subproject commit d65e36207ba1baffd39930b9e77cfe66e4b26059
common_utils/basic_operations.py
@@ -4,66 +4,69 @@
# Enter steps here
import torch
import numpy as np


def switch_kv_in_dict(mydict):
    switched_dict = {y: x for x, y in mydict.items()}
    return switched_dict


def unit_normalize(d):
    d_abs_max = torch.max(
        torch.abs(d.view(d.size(0), -1)), 1, keepdim=True)[0].view(
            d.size(0), 1, 1, 1)
    # print(d_abs_max.size())
    d /= (1e-20 + d_abs_max)  # d' = d / d_max
    d /= torch.sqrt(1e-6 + torch.sum(
        torch.pow(d, 2.0), tuple(range(1, len(d.size()))),
        keepdim=True))  # d' / sqrt(sum(d'^2))
    # print(torch.norm(d.view(d.size(0), -1), dim=1))
    return d
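
A minimal usage sketch for unit_normalize (illustrative shapes and values, not part of the commit): each sample of a 4D NCHW perturbation batch comes out with roughly unit L2 norm.

import torch

d = torch.randn(2, 1, 8, 8)                      # a batch of perturbations, NCHW
d = unit_normalize(d)
print(torch.norm(d.view(d.size(0), -1), dim=1))  # ~tensor([1., 1.])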
def intensity_norm_fn(intensity_norm_type):
    if intensity_norm_type == 'min_max':
        return rescale_intensity
    elif intensity_norm_type == 'z_score':
        return z_score_intensity
    else:
        raise ValueError


def rescale_intensity(data, new_min=0, new_max=1, eps=1e-20):
    '''
    rescale pytorch batch data
    :param data: N*1*H*W
    :return: data with intensity ranging from 0 to 1
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs * c, -1)
    old_max = torch.max(data, dim=1, keepdim=True).values
    old_min = torch.min(data, dim=1, keepdim=True).values
    new_data = (data - old_min) / (old_max - old_min + eps) * (new_max - new_min) + new_min
    new_data = new_data.view(bs, c, h, w)
    return new_data


def z_score_intensity(data):
    '''
    rescale pytorch batch data
    :param data: N*c*H*W
    :return: data with intensity with zero mean and 1 std.
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs * c, -1)
    mean = torch.mean(data, dim=1, keepdim=True)
    data_dmean = data - mean.detach()
    std = torch.std(data_dmean, dim=1, keepdim=True)
    std = std.detach()
    std[abs(std) == 0] = 1
    new_data = (data_dmean) / (std)
    new_data = new_data.view(bs, c, h, w)
    return new_data
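
A usage sketch for the dispatcher and both normalizers (illustrative tensors, not from the diff):

import torch

batch = torch.rand(4, 1, 8, 8) * 255.0      # N*1*H*W intensities
to_unit = intensity_norm_fn('min_max')      # returns rescale_intensity
to_zscore = intensity_norm_fn('z_score')    # returns z_score_intensity
print(to_unit(batch).min().item(), to_unit(batch).max().item())  # ~0.0 ~1.0
print(to_zscore(batch).mean().item())                            # ~0.0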
def transform2tensor(cPader, img_slice, if_z_score=False):
    '''
    transform npy data to torch tensor
    :param cPader: pad image to be divided by 16
    ...

@@ -74,76 +77,74 @@ def transform2tensor(cPader, img_slice, if_z_score=False):
    ###
    new_img_slice = cPader(img_slice)

    # normalize data
    new_img_slice = new_img_slice * 1.0  # N*H*W
    new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
    if if_z_score:
        new_img_slice -= new_input_mean
        new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
        if abs(new_std - 0) < 1e-3:
            new_std = 1
        new_img_slice /= (new_std)
    else:
        # print ('0-1 rescaling')
        min_val = np.min(new_img_slice, axis=(1, 2), keepdims=True)
        max_val = np.max(new_img_slice, axis=(1, 2), keepdims=True)
        new_img_slice = (new_img_slice - min_val) / (max_val - min_val + 1e-10)
    new_img_slice = new_img_slice[:, np.newaxis, :, :]

    # transform to tensor
    new_image_tensor = torch.from_numpy(new_img_slice).float()
    return new_image_tensor
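
A usage sketch for transform2tensor; `identity_pader` is a hypothetical stand-in for the repo's padding callable, workable here because the slices are already divisible by 16:

import numpy as np

identity_pader = lambda x: x               # hypothetical stand-in for cPader
img_slices = np.random.rand(4, 16, 16)     # N*H*W numpy slices
t = transform2tensor(identity_pader, img_slices, if_z_score=False)
print(t.shape, t.dtype)                    # torch.Size([4, 1, 16, 16]) torch.float32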
def construct_input(segmentation, image=None, num_classes=None, temperature=1.0,
                    apply_softmax=True, is_labelmap=False, smooth_label=False,
                    use_gpu=True):
    """
    concat image and segmentation together to form an input to an external assessor
    Args:
        image ([4d float tensor]): a batch of images N(Ch)HW, Ch is the image channel
        segmentation ([4d float tensor] or 3d label map): corresponding segmentation map NCHW or 3d onehot map NHW
    """
    assert (apply_softmax and is_labelmap) is False
    if not is_labelmap:
        batch_size, h, w = segmentation.size(0), segmentation.size(2), segmentation.size(3)
    else:
        batch_size, h, w = segmentation.size(0), segmentation.size(1), segmentation.size(2)
    device = torch.device('cuda') if use_gpu else torch.device('cpu')

    if not is_labelmap:
        if apply_softmax:
            assert len(segmentation.size()) == 4
            segmentation = segmentation / temperature
            softmax_predict = torch.softmax(segmentation, dim=1)
            segmentation = softmax_predict
    else:
        # make onehot maps
        assert num_classes is not None, 'please specify num_classes'
        flatten_y = segmentation.view(batch_size * h * w, 1)
        y_onehot = torch.zeros(batch_size * h * w, num_classes,
                               dtype=torch.float32, device=device)
        y_onehot.scatter_(1, flatten_y, 1)
        y_onehot = y_onehot.view(batch_size, h, w, num_classes)
        y_onehot = y_onehot.permute(0, 3, 1, 2)
        y_onehot.requires_grad = False

        if smooth_label:
            # add noise to labels
            smooth_factor = torch.rand(1, device=device) * 0.2
            y_onehot[y_onehot == 1] = 1 - smooth_factor
            y_onehot[y_onehot == 0] = smooth_factor / (num_classes - 1)
        segmentation = y_onehot

    if image is not None:
        tuple = torch.cat([segmentation, image], dim=1)
        return tuple
    else:
        return segmentation
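
A usage sketch for construct_input, building an assessor input from an integer label map plus its image (illustrative shapes; CPU via use_gpu=False):

import torch

labels = torch.randint(0, 4, (2, 8, 8))    # N*H*W label map, values in [0, 4)
image = torch.randn(2, 1, 8, 8)            # N*1*H*W image batch
x = construct_input(labels, image=image, num_classes=4,
                    apply_softmax=False, is_labelmap=True, use_gpu=False)
print(x.shape)                             # torch.Size([2, 5, 8, 8]): onehot map + image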
common_utils/load_args.py

@@ -91,6 +91,6 @@ def plot_training_results(model_dir, plot_history):
    plt.clf()


# if __name__ == '__main__':
#     params = Params('/vol/medic01/users/cc215/Dropbox/projects/DeformADA/configs/gat_loss.json')
#     print(params.dict)
common_utils/loss.py

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
def cross_entropy_2D(input, target, weight=None, size_average=True, mask=None):
    """[summary]
    calc cross entropy loss computed on 2D images
    Args:
        input ([torch tensor]): [4d logit] in the format of NCHW
        target ([torch tensor]): 3D labelmap or 4d logit (before softmax), in the format of NCHW
        weight ([type], optional): weights for classes. Defaults to None.
        size_average (bool, optional): take the average across the spatial domain. Defaults to True.
    Raises:
        NotImplementedError: [description]
    Returns:
        [type]: [description]
    """
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    if mask is None:
        mask = torch.ones_like(log_p, device=log_p.device)
    mask = mask.view(-1, c)
    mask_region_size = torch.sum(mask[:, 0])
    if len(target.size()) == 3:
        target = target.view(target.numel())
        if not weight is None:
            # sum(weight) = C, for numerical stability.
            weight = torch.softmax(weight, dim=0) * c
        loss_vector = F.nll_loss(log_p, target, weight=weight, reduce=False)
        loss_vector = loss_vector * mask[:, 0]
        loss = torch.sum(loss_vector)
        if size_average:
            loss /= float(mask_region_size)  # / (N*H'*W')
    elif len(target.size()) == 4:
        # ce loss = -q*log(p)
        reference = F.softmax(target, dim=1)  # M,C
        reference = reference.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)  # M,C
        if weight is None:
            plogq = torch.sum(reference * log_p * mask, dim=1)
            plogq = torch.sum(plogq)
            if size_average:
                plogq /= float(mask_region_size)
        else:
            # sum(weight) = C
            weight = torch.softmax(weight, dim=0) * c
            plogq_class_wise = reference * log_p * mask
            plogq_sum_class = 0.
            for i in range(c):
                plogq_sum_class += torch.sum(plogq_class_wise[:, i] * weight[i])
            plogq = plogq_sum_class
            if size_average:
                # only average loss on the mask entries with value = 1
                plogq /= float(mask_region_size)
        loss = -1 * plogq
    else:
        raise NotImplementedError
    return loss
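
A usage sketch for cross_entropy_2D on random logits and labels (illustrative, not part of the commit):

import torch

logits = torch.randn(2, 4, 8, 8, requires_grad=True)  # N*C*H*W logits
labels = torch.randint(0, 4, (2, 8, 8))                # N*H*W integer labels
loss = cross_entropy_2D(logits, labels)
loss.backward()
print(loss.item())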
class SoftDiceLoss(nn.Module):
    # Dice loss: code is from https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/layers/loss.py
    def __init__(self, n_classes, use_gpu=True, squared_union=False):
        super(SoftDiceLoss, self).__init__()
        self.one_hot_encoder = One_Hot(n_classes, use_gpu).forward
        self.n_classes = n_classes
        self.squared_union = squared_union

    def forward(self, input, target, weight=None):
        smooth = 0.01
        batch_size = input.size(0)
        input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1)
        if len(target.size()) == 3:
            target = self.one_hot_encoder(target).contiguous().view(
                batch_size, self.n_classes, -1)
        elif len(target.size()) == 4 and target.size(1) == input.size(1):
            target = F.softmax(target, dim=1).view(batch_size, self.n_classes, -1)
        else:
            print('the shapes for input and target do not match, input:{} target:{}'.format(
                str(input.size()), str(target.size())))
            raise ValueError

        inter = torch.sum(input * target, 2)
        if self.squared_union:
            # 2pq/(|p|^2+|q|^2)
            union = torch.sum(input ** 2, 2) + torch.sum(target ** 2, 2)
        else:
            # 2pq/(|p|+|q|)
            union = torch.sum(input, 2) + torch.sum(target, 2)
        score = torch.sum((2.0 * inter + smooth) / (union + smooth))
        score = 1.0 - score / (float(batch_size) * float(self.n_classes))
        return score
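
A usage sketch for SoftDiceLoss (CPU, hence use_gpu=False; it relies on the One_Hot helper defined further down in this file):

import torch

criterion = SoftDiceLoss(n_classes=4, use_gpu=False)
logits = torch.randn(2, 4, 8, 8)           # raw network outputs, N*C*H*W
labels = torch.randint(0, 4, (2, 8, 8))    # N*H*W integer labels
print(criterion(logits, labels).item())    # soft Dice loss; lower is better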
def calc_segmentation_mse_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target,
                                         divergence_types=['mse'], divergence_weights=[1.0],
                                         class_weights=None, mask=None)
    return loss


def calc_segmentation_kl_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target,
                                         divergence_types=['kl'], divergence_weights=[1.0],
                                         class_weights=None, mask=None)
    return loss


def calc_segmentation_consistency(output, reference, divergence_types=['kl', 'contour'],
                                  divergence_weights=[1.0, 0.5], class_weights=None,
                                  scales=[0], mask=None, is_gt=False):
    """
    measuring the difference between two predictions (network logits before softmax)
    Args:
        output (torch tensor 4d): network predicts: NCHW (after perturbation)
        reference (torch tensor 4d): network references: NCHW (before perturbation)
        divergence_types (list, string): specifying loss types. Defaults to ['kl','contour'].
        divergence_weights (list, float): specifying coefficients for each loss above. Defaults to [1.0,0.5].
        class_weights (list of scalars): specifying class weights for loss computation
        scales (list of int): specifying a list of downsampling rates so that losses will be calculated on different scales. Defaults to [0].
        mask ([tensor], 0-1 onehotmap): [N*1*H*W]. disable loss computation on corresponding elements with mask=0. Defaults to None.
        is_gt (bool): if true, will use the one-hot encoded `reference' instead of probability maps (after applying softmax) to compute the consistency loss
    Raises:
        NotImplementedError: when loss name is not in ['kl','mse','contour']
    Returns:
        loss (tensor float):
    """
    if class_weights is not None:
        raise NotImplementedError
    dist = 0.
    reference = reference.detach()
    num_classes = reference.size(1)
    if mask is None:
        # apply masks so that only gradients on non-zero regions will be backpropagated.
        mask = torch.ones_like(output).float().to(reference.device)

    for scale in scales:
        if scale > 0:
            output_reference = torch.nn.AvgPool2d(2 ** scale)(reference)
            output_new = torch.nn.AvgPool2d(2 ** scale)(output)
        else:
            output_reference = reference
            output_new = output
        for divergence_type, d_weight in zip(divergence_types, divergence_weights):
            loss = 0.
            if divergence_type == 'kl':
                '''
                standard kl loss
                '''
                loss = kl_divergence(pred=output_new, reference=output_reference,
                                     mask=mask, is_gt=is_gt)
            elif divergence_type == 'mse':
                n, h, w = output_new.size(0), output_new.size(2), output_new.size(3)
                if not is_gt:
                    target_pred = torch.softmax(output_reference, dim=1)
                else:
                    target_pred = output_reference
                input_pred = torch.softmax(output_new, dim=1)
                loss = torch.nn.MSELoss(reduction='sum')(
                    target=target_pred * mask, input=input_pred * mask)
                loss = loss / (n * h * w)
            elif divergence_type == 'contour':
                # contour-based loss
                if not is_gt:
                    target_pred = torch.softmax(output_reference, dim=1)
                else:
                    target_pred = output_reference
                input_pred = torch.softmax(output_new, dim=1)
                cnt = 0
                for i in range(1, num_classes):
                    cnt += 1
                    loss += contour_loss(input=input_pred[:, [i], ],
                                         target=(target_pred[:, [i]]),
                                         ignore_background=False, mask=mask,
                                         one_hot_target=False)
                if cnt > 0:
                    loss /= cnt
            else:
                raise NotImplementedError
            # print ('{}:{}'.format(divergence_type, loss.item()))
            dist += 2 ** scale * (d_weight * loss)
    return dist / (1.0 * len(scales))
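
A usage sketch for calc_segmentation_consistency between clean and perturbed logits; only the 'kl' and 'mse' terms are exercised here, since the 'contour' term builds CUDA convolution kernels by default (illustrative tensors):

import torch

output = torch.randn(2, 4, 8, 8, requires_grad=True)  # logits after perturbation
reference = torch.randn(2, 4, 8, 8)                    # logits before perturbation
loss = calc_segmentation_consistency(output, reference,
                                     divergence_types=['kl', 'mse'],
                                     divergence_weights=[1.0, 1.0])
loss.backward()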
def contour_loss(input, target, use_gpu=True, ignore_background=True,
                 one_hot_target=True, mask=None):
    '''
    calc the contour loss across object boundaries (WITHOUT background class)
    :param input: NDArray. N*num_classes*H*W : pixelwise probs. for each class, e.g. the softmax output from a neural network
    :param target: ground truth labels (NHW) or one-hot ground truth maps N*C*H*W
    :param use_gpu: boolean. default: True, use GPU.
    :param ignore_background: boolean, ignore the background class. default: True
    :param one_hot_target: boolean. if true, will first convert the target from NHW to NCHW. Default: True.
    :return:
    '''
    n, num_classes, h, w = input.size(0), input.size(1), input.size(2), input.size(3)
    if one_hot_target:
        onehot_mapper = One_Hot(depth=num_classes, use_gpu=use_gpu)
        target = target.long()
        onehot_target = onehot_mapper(target).contiguous().view(
            input.size(0), num_classes, input.size(2), input.size(3))
    else:
        onehot_target = target
    assert onehot_target.size() == input.size(), 'pred size: {} must match target size: {}'.format(
        str(input.size()), str(onehot_target.size()))

    if mask is None:
        # apply masks so that only gradients on certain regions will be backpropagated.
        mask = torch.ones_like(input).long().to(input.device)
        mask.requires_grad = False
    else:
        pass
        # print ('mask applied')

    if ignore_background:
        object_classes = num_classes - 1
        target_object_maps = onehot_target[:, 1:].float()
        input = input[:, 1:]
    else:
        target_object_maps = onehot_target
        object_classes = num_classes

    x_filter = np.array([[1, 0, -1],
                         [2, 0, -2],
    ...
@@ -224,7 +147,7 @@ def contour_loss(input, target, size_average=True, use_gpu=True, ignore_backgroun
    y_filter = np.repeat(y_filter, axis=1, repeats=object_classes)
    y_filter = np.repeat(y_filter, axis=0, repeats=object_classes)
    conv_y = nn.Conv2d(in_channels=object_classes, out_channels=object_classes,
                       kernel_size=3, stride=1, padding=1, bias=False)
    conv_y.weight = nn.Parameter(torch.from_numpy(y_filter).float())
    if use_gpu:
    ...
@@ -235,41 +158,44 @@ def contour_loss(input, target, size_average=True, use_gpu=True, ignore_backgroun
    for param in conv_x.parameters():
        param.requires_grad = False

    g_x_pred = conv_x(input) * mask[:, :object_classes]
    g_y_pred = conv_y(input) * mask[:, :object_classes]
    g_y_truth = conv_y(target_object_maps) * mask[:, :object_classes]
    g_x_truth = conv_x(target_object_maps) * mask[:, :object_classes]

    # mse loss
    loss = 0.5 * (torch.nn.MSELoss(reduction='mean')(input=g_x_pred, target=g_x_truth)
                  + torch.nn.MSELoss(reduction='mean')(input=g_y_pred, target=g_y_truth))
    return loss
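
A usage sketch for contour_loss (use_gpu=False keeps the edge filters on CPU; this assumes the elided x_filter/conv_x setup mirrors the y-direction code shown above):

import torch

probs = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # N*C*H*W probabilities
labels = torch.randint(0, 3, (2, 8, 8))                 # N*H*W integer labels
loss = contour_loss(probs, labels, use_gpu=False,
                    ignore_background=True, one_hot_target=True)
print(loss.item())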
def kl_divergence(reference, pred, mask=None, is_gt=False):
    '''
    calc the kl div distance between two outputs p and q from a network/model: p(y1|x1).p(y2|x2).
    :param reference p: direct output from the network using the original input, before softmax
    :param pred q: approximate output: direct output from the network using the perturbed input, before softmax
    :param is_gt: whether `reference' is a one-hot ground-truth map
    :return: kl divergence: D_KL(P||Q) = mean(sum_{c=1}^{C} p^c * log(p^c / q^c))
    '''
    q = pred
    if mask is None:
        mask = torch.ones_like(q, device=q.device)
        mask.requires_grad = False
    if not is_gt:
        p = F.softmax(reference, dim=1)
        log_p = F.log_softmax(reference, dim=1)
    else:
        p = torch.where(reference == 0, 1e-8, 1 - 1e-8)
        log_p = torch.log(p)  # avoid NaN from log(0)
    cls_plogp = mask * (p * log_p)
    cls_plogq = mask * (p * F.log_softmax(q, dim=1))
    plogp = torch.sum(cls_plogp, dim=1, keepdim=True)
    plogq = torch.sum(cls_plogq, dim=1, keepdim=True)
    kl_loss = torch.mean(plogp - plogq)
    return kl_loss
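
A usage sketch for kl_divergence on two logit maps (illustrative; is_gt left at its default False):

import torch

ref_logits = torch.randn(2, 4, 8, 8)   # reference branch, before softmax
new_logits = torch.randn(2, 4, 8, 8)   # perturbed branch, before softmax
print(kl_divergence(ref_logits, new_logits).item())  # >= 0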
class One_Hot(nn.Module):
    def __init__(self, depth, use_gpu=True):
        super(One_Hot, self).__init__()
    ...

@@ -289,3 +215,59 @@ class One_Hot(nn.Module):
    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)
def cross_entropy_2D(input, target, weight=None, size_average=True):
    """[summary]
    calc cross entropy loss computed on 2D images
    Args:
        input ([torch tensor]): [4d logit] in the format of NCHW
        target ([torch tensor]): 3D labelmap or 4d logit (before softmax), in the format of NCHW
        weight ([type], optional): weights for classes. Defaults to None.
        size_average (bool, optional): take the average across the spatial domain. Defaults to True.
    Raises:
        NotImplementedError: [description]
    Returns: