Commit 84efc713 authored by cc215

add advchain submodule

parent a8edd375
[submodule "advchain"]
	path = advchain
	url = https://github.com/cherise215/advchain.git

Subproject commit d65e36207ba1baffd39930b9e77cfe66e4b26059
@@ -4,66 +4,69 @@
# Enter steps here
import torch
import numpy as np
def switch_kv_in_dict(mydict):
    switched_dict = {y: x for x, y in mydict.items()}
    return switched_dict
def unit_normalize(d):
    d_abs_max = torch.max(
        torch.abs(d.view(d.size(0), -1)), 1, keepdim=True)[0].view(
        d.size(0), 1, 1, 1)
    # print(d_abs_max.size())
    d /= (1e-20 + d_abs_max)  # d' = d/d_max
    d /= torch.sqrt(1e-6 + torch.sum(
        torch.pow(d, 2.0), tuple(range(1, len(d.size()))), keepdim=True))  # d'/sqrt(sum(d'^2))
    # print(torch.norm(d.view(d.size(0), -1), dim=1))
    return d
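
# A minimal usage sketch for unit_normalize (shapes assumed for illustration):
# after the two rescaling steps, each sample in the batch has ~unit L2 norm.
# d = torch.randn(2, 1, 8, 8)
# d = unit_normalize(d)
# print(torch.norm(d.view(2, -1), dim=1))  # ~1.0 per sample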
def intensity_norm_fn(intensity_norm_type):
    if intensity_norm_type == 'min_max':
        return rescale_intensity
    elif intensity_norm_type == 'z_score':
        return z_score_intensity
    else:
        raise ValueError(
            'unknown intensity norm type: {}'.format(intensity_norm_type))
def rescale_intensity(data, new_min=0, new_max=1, eps=1e-20):
    '''
    rescale pytorch batch data
    :param data: N*1*H*W
    :return: data with intensity ranging from new_min to new_max
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs*c, -1)
    old_max = torch.max(data, dim=1, keepdim=True).values
    old_min = torch.min(data, dim=1, keepdim=True).values
    new_data = (data - old_min) / (old_max - old_min + eps) * \
        (new_max - new_min) + new_min
    new_data = new_data.view(bs, c, h, w)
    return new_data
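
# A minimal sketch (assuming a random batch in the N*1*H*W layout above):
# x = 255.0 * torch.rand(2, 1, 8, 8)
# y = rescale_intensity(x)  # per-slice min-max rescaling to [0, 1]
# print(y.min().item(), y.max().item())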
def z_score_intensity(data):
    '''
    rescale pytorch batch data
    :param data: N*C*H*W
    :return: data with zero mean and unit std per channel
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs*c, -1)
    mean = torch.mean(data, dim=1, keepdim=True)
    data_dmean = data - mean.detach()
    std = torch.std(data_dmean, dim=1, keepdim=True)
    std = std.detach()
    std[abs(std) == 0] = 1
    new_data = data_dmean / std
    new_data = new_data.view(bs, c, h, w)
    return new_data
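
# Sketch (shapes assumed): per-slice standardisation of a shifted, scaled batch.
# x = torch.randn(2, 1, 8, 8) * 3.0 + 5.0
# y = z_score_intensity(x)  # per-slice mean ~0, std ~1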
def transform2tensor(cPader, img_slice, if_z_score=False):
    '''
    transform npy data to torch tensor
    :param cPader: pads the image so that its size is divisible by 16
@@ -74,75 +77,73 @@ def transform2tensor(cPader, img_slice, if_z_score=False):
    ###
    new_img_slice = cPader(img_slice)
    # normalize data
    new_img_slice = new_img_slice * 1.0  # N*H*W
    new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
    if if_z_score:
        new_img_slice -= new_input_mean
        new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
        # guard against (near-)constant slices; elementwise so it also works for batches
        new_std[np.abs(new_std) < 1e-3] = 1
        new_img_slice /= new_std
    else:
        # 0-1 rescaling
        min_val = np.min(new_img_slice, axis=(1, 2), keepdims=True)
        max_val = np.max(new_img_slice, axis=(1, 2), keepdims=True)
        new_img_slice = (new_img_slice - min_val) / (max_val - min_val + 1e-10)
    new_img_slice = new_img_slice[:, np.newaxis, :, :]
    # transform to tensor
    new_image_tensor = torch.from_numpy(new_img_slice).float()
    return new_image_tensor
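
# Sketch with a hypothetical identity padder standing in for the real
# pad-to-divisible-by-16 callable (cPader is assumed to map N*H*W -> N*H'*W'):
# identity_pad = lambda arr: arr
# img = np.random.rand(3, 16, 16)
# t = transform2tensor(identity_pad, img)  # -> float tensor, N*1*H*W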
def construct_input(segmentation, image=None, num_classes=None, temperature=1.0, apply_softmax=True, is_labelmap=False, smooth_label=False, use_gpu=True):
    """
    concat image and segmentation together to form an input to an external assessor
    Args:
        image ([4d float tensor]): a batch of images N(Ch)HW, Ch is the image channel
        segmentation ([4d float tensor] or 3d label map): corresponding segmentation map NCHW or a 3d label map NHW
    """
    assert (apply_softmax and is_labelmap) is False
    if not is_labelmap:
        batch_size, h, w = segmentation.size(
            0), segmentation.size(2), segmentation.size(3)
    else:
        batch_size, h, w = segmentation.size(
            0), segmentation.size(1), segmentation.size(2)
    device = torch.device('cuda') if use_gpu else torch.device('cpu')
    if not is_labelmap:
        if apply_softmax:
            assert len(segmentation.size()) == 4
            segmentation = segmentation/temperature
            softmax_predict = torch.softmax(segmentation, dim=1)
            segmentation = softmax_predict
    else:
        # make onehot maps
        assert num_classes is not None, 'please specify num_classes'
        flatten_y = segmentation.view(batch_size*h*w, 1)
        y_onehot = torch.zeros(batch_size*h*w, num_classes,
                               dtype=torch.float32, device=device)
        y_onehot.scatter_(1, flatten_y, 1)
        y_onehot = y_onehot.view(batch_size, h, w, num_classes)
        y_onehot = y_onehot.permute(0, 3, 1, 2)
        y_onehot.requires_grad = False
        if smooth_label:
            # add noise to labels
            smooth_factor = torch.rand(1, device=device)*0.2
            y_onehot[y_onehot == 1] = 1-smooth_factor
            y_onehot[y_onehot == 0] = smooth_factor/(num_classes-1)
        segmentation = y_onehot
    if image is not None:
        concat_input = torch.cat([segmentation, image], dim=1)
        return concat_input
    else:
        ......
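
# Sketch: one-hot a 3D labelmap and concat it with the image (CPU and 4 classes assumed):
# seg = torch.randint(0, 4, (2, 8, 8))
# img = torch.rand(2, 1, 8, 8)
# x = construct_input(seg, image=img, num_classes=4, apply_softmax=False,
#                     is_labelmap=True, use_gpu=False)  # -> N*(4+1)*H*W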
@@ -91,6 +91,6 @@ def plot_training_results(model_dir, plot_history):
    plt.clf()


# if __name__ == '__main__':
#     params = Params('/vol/medic01/users/cc215/Dropbox/projects/DeformADA/configs/gat_loss.json')
#     print(params.dict)
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
def cross_entropy_2D(input, target, weight=None, size_average=True, mask=None):
    """
    calc cross entropy loss computed on 2D images
    Args:
        input ([torch tensor]): [4d logit] in the format of NCHW
        target ([torch tensor]): 3D labelmap or 4d logit (before softmax), in the format of NCHW
        weight ([type], optional): weights for classes. Defaults to None.
        size_average (bool, optional): take the average across the spatial domain. Defaults to True.
        mask ([torch tensor], optional): 0-1 mask; positions with 0 are excluded from the loss. Defaults to None.
    Raises:
        NotImplementedError: when target is neither a 3D labelmap nor a 4D logit
    Returns:
        loss (tensor float)
    """
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    if mask is None:
        mask = torch.ones_like(log_p, device=log_p.device)
    mask = mask.view(-1, c)
    mask_region_size = torch.sum(mask[:, 0])
    if len(target.size()) == 3:
        target = target.view(target.numel())
        if weight is not None:
            # normalize so that sum(weight) = C, for numerical stability.
            weight = torch.softmax(weight, dim=0)*c
        loss_vector = F.nll_loss(log_p, target, weight=weight, reduction='none')
        loss_vector = loss_vector*mask[:, 0]
        loss = torch.sum(loss_vector)
        if size_average:
            loss /= float(mask_region_size)  # / (N*H'*W')
    elif len(target.size()) == 4:
        # ce loss = -q*log(p)
        reference = F.softmax(target, dim=1)  # M,C
        reference = reference.transpose(1, 2).transpose(
            2, 3).contiguous().view(-1, c)  # M,C
        if weight is None:
            plogq = torch.sum(reference*log_p*mask, dim=1)
            plogq = torch.sum(plogq)
            if size_average:
                plogq /= float(mask_region_size)
        else:
            # normalize so that sum(weight) = C
            weight = torch.softmax(weight, dim=0)*c
            plogq_class_wise = reference*log_p*mask
            plogq_sum_class = 0.
            for i in range(c):
                plogq_sum_class += torch.sum(plogq_class_wise[:, i]*weight[i])
            plogq = plogq_sum_class
            if size_average:
                # only average the loss over mask entries with value 1
                plogq /= float(mask_region_size)
        loss = -1*plogq
    else:
        raise NotImplementedError
    return loss
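
# Sketch: CE on random logits with a 3D labelmap target (mask defaults to all-ones):
# logits = torch.randn(2, 4, 8, 8)
# labels = torch.randint(0, 4, (2, 8, 8))
# loss = cross_entropy_2D(logits, labels)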
class SoftDiceLoss(nn.Module):
    # Dice loss: code is adapted from
    # https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/layers/loss.py
    def __init__(self, n_classes, use_gpu=True, squared_union=False):
        super(SoftDiceLoss, self).__init__()
        self.one_hot_encoder = One_Hot(n_classes, use_gpu).forward
        self.n_classes = n_classes
        self.squared_union = squared_union

    def forward(self, input, target, weight=None):
        smooth = 0.01
        batch_size = input.size(0)
        input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1)
        if len(target.size()) == 3:
            target = self.one_hot_encoder(target).contiguous().view(
                batch_size, self.n_classes, -1)
        elif len(target.size()) == 4 and target.size(1) == input.size(1):
            target = F.softmax(target, dim=1).view(
                batch_size, self.n_classes, -1)
        else:
            print('the shapes for input and target do not match, input:{} target:{}'.format(
                str(input.size()), str(target.size())))
            raise ValueError
        inter = torch.sum(input * target, 2)
        if self.squared_union:
            # 2pq/(|p|^2+|q|^2)
            union = torch.sum(input**2, 2) + torch.sum(target**2, 2)
        else:
            # 2pq/(|p|+|q|)
            union = torch.sum(input, 2) + torch.sum(target, 2)
        score = torch.sum((2.0 * inter + smooth) / (union + smooth))
        score = 1.0 - score / (float(batch_size) * float(self.n_classes))
        return score
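
# Sketch: soft Dice between logits and integer labels (CPU; One_Hot as defined below):
# criterion = SoftDiceLoss(n_classes=4, use_gpu=False)
# logits = torch.randn(2, 4, 8, 8)
# labels = torch.randint(0, 4, (2, 8, 8))
# dice_loss = criterion(logits, labels)  # scalar, smaller is better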
def calc_segmentation_consistency(output, reference, divergence_types=['kl', 'contour'],
                                  divergence_weights=[1.0, 0.5], class_weights=None, scales=[0],
                                  mask=None, is_gt=False):
    """
    measuring the difference between two predictions (network logits before softmax)
    Args:
        output (torch tensor 4d): network predictions: NCHW (after perturbation)
        reference (torch tensor 4d): network references: NCHW (before perturbation)
        divergence_types (list, string): specifying loss types. Defaults to ['kl','contour'].
        divergence_weights (list, float): specifying coefficients for each loss above. Defaults to [1.0,0.5].
        class_weights (list of scalars): specifying class weights for loss computation
        scales (list of int): specifying a list of downsampling rates so that losses will be calculated on different scales. Defaults to [0].
        mask ([tensor], 0-1 onehotmap): [N*1*H*W]. disable loss computation on corresponding elements with mask=0. Defaults to None.
        is_gt (bool): if True, use the one-hot encoded `reference` directly instead of probability maps (softmax outputs) to compute the consistency loss
    Raises:
        NotImplementedError: when loss name is not in ['kl','mse','contour']
    Returns:
        loss (tensor float):
    """
    if class_weights is not None:
        raise NotImplementedError
    dist = 0.
    reference = reference.detach()
    num_classes = reference.size(1)
    if mask is None:
        # apply masks so that only gradients on non-zero regions will be backpropagated.
        mask = torch.ones_like(output).float().to(reference.device)
    for scale in scales:
        if scale > 0:
            output_reference = torch.nn.AvgPool2d(2 ** scale)(reference)
            output_new = torch.nn.AvgPool2d(2 ** scale)(output)
        else:
            output_reference = reference
            output_new = output
        for divergence_type, d_weight in zip(divergence_types, divergence_weights):
            loss = 0.
            if divergence_type == 'kl':
                '''
                standard kl loss
                '''
                loss = kl_divergence(
                    pred=output_new, reference=output_reference, mask=mask, is_gt=is_gt)
            elif divergence_type == 'mse':
                n, h, w = output_new.size(
                    0), output_new.size(2), output_new.size(3)
                if not is_gt:
                    target_pred = torch.softmax(output_reference, dim=1)
                else:
                    target_pred = output_reference
                input_pred = torch.softmax(output_new, dim=1)
                loss = torch.nn.MSELoss(reduction='sum')(
                    target=target_pred*mask, input=input_pred*mask)
                loss = loss/(n*h*w)
            elif divergence_type == 'contour':  # contour-based loss
                if not is_gt:
                    target_pred = torch.softmax(output_reference, dim=1)
                else:
                    target_pred = output_reference
                input_pred = torch.softmax(output_new, dim=1)
                cnt = 0
                for i in range(1, num_classes):
                    cnt += 1
                    loss += contour_loss(input=input_pred[:, [i], ], target=(target_pred[:, [i]]), ignore_background=False, mask=mask,
                                         one_hot_target=False)
                if cnt > 0:
                    loss /= cnt
            else:
                raise NotImplementedError
            # print ('{}:{}'.format(divergence_type, loss.item()))
            dist += 2 ** scale*(d_weight * loss)
    return dist / (1.0 * len(scales))
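
# Sketch: consistency between clean and perturbed logits at full resolution (defaults assumed):
# p_clean = torch.randn(2, 4, 8, 8)
# p_perturbed = torch.randn(2, 4, 8, 8, requires_grad=True)
# loss = calc_segmentation_consistency(output=p_perturbed, reference=p_clean,
#                                      divergence_types=['kl'], divergence_weights=[1.0])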
def calc_segmentation_mse_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target, divergence_types=[
        'mse'], divergence_weights=[1.0], class_weights=None, mask=None)
    return loss


def calc_segmentation_kl_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target, divergence_types=[
        'kl'], divergence_weights=[1.0], class_weights=None, mask=None)
    return loss
def contour_loss(input, target, use_gpu=True, ignore_background=True, one_hot_target=True, mask=None):
    '''
    calc the contour loss across object boundaries (WITHOUT background class)
    :param input: NDArray. N*num_classes*H*W : pixelwise probs. for each class e.g. the softmax output from a neural network
    :param target: ground truth labels (NHW) or one-hot ground truth maps N*C*H*W
    :param use_gpu: boolean. default: True, use GPU.
    :param ignore_background: boolean, ignore the background class. default: True
    :param one_hot_target: boolean. if true, will first convert the target from NHW to NCHW. Default: True.
    :return:
    '''
    n, num_classes, h, w = input.size(0), input.size(
        1), input.size(2), input.size(3)
    if one_hot_target:
        onehot_mapper = One_Hot(depth=num_classes, use_gpu=use_gpu)
        target = target.long()
        onehot_target = onehot_mapper(target).contiguous().view(
            input.size(0), num_classes, input.size(2), input.size(3))
    else:
        onehot_target = target
    assert onehot_target.size() == input.size(), 'pred size: {} must match target size: {}'.format(
        str(input.size()), str(onehot_target.size()))
    if mask is None:
        # apply masks so that only gradients on certain regions will be backpropagated.
        mask = torch.ones_like(input).long().to(input.device)
        mask.requires_grad = False
    if ignore_background:
        object_classes = num_classes - 1
        target_object_maps = onehot_target[:, 1:].float()
        input = input[:, 1:]
    else:
        target_object_maps = onehot_target
        object_classes = num_classes

    x_filter = np.array([[1, 0, -1],
@@ -235,41 +158,44 @@ def contour_loss(input, target, size_average=True, use_gpu=True,ignore_backgroun
    for param in conv_x.parameters():
        param.requires_grad = False

    g_x_pred = conv_x(input)*mask[:, :object_classes]
    g_y_pred = conv_y(input)*mask[:, :object_classes]
    g_y_truth = conv_y(target_object_maps)*mask[:, :object_classes]
    g_x_truth = conv_x(target_object_maps)*mask[:, :object_classes]
    # mse loss on the spatial gradients of the two maps
    loss = 0.5*(torch.nn.MSELoss(reduction='mean')(input=g_x_pred, target=g_x_truth) +
                torch.nn.MSELoss(reduction='mean')(input=g_y_pred, target=g_y_truth))
    return loss
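
# Sketch: boundary-gradient agreement between softmaxed probs and labels (CPU assumed;
# conv_x/conv_y are the Sobel-style filters built from x_filter above):
# probs = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
# labels = torch.randint(0, 4, (2, 8, 8))
# loss = contour_loss(probs, labels, use_gpu=False, one_hot_target=True)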
def kl_divergence(reference, pred, mask=None, is_gt=False):
    '''
    calc the kl divergence between two outputs p and q from a network/model: p(y1|x1), p(y2|x2).
    :param reference: p: logits from the network given the original input (before softmax)
    :param pred: q: approximate output: logits from the network given the perturbed input (before softmax)
    :param is_gt: if True, `reference` is a one-hot map instead of logits
    :return: kl divergence: DKL(P||Q) = mean(\sum_{c=1}^{C} p^c log(p^c / q^c))
    '''
    q = pred
    if mask is None:
        mask = torch.ones_like(q, device=q.device)
        mask.requires_grad = False
    if not is_gt:
        p = F.softmax(reference, dim=1)
        log_p = F.log_softmax(reference, dim=1)
    else:
        p = torch.where(reference == 0, 1e-8, 1-1e-8)
        log_p = torch.log(p)  # avoid NaN from log(0)
    cls_plogp = mask*(p * log_p)
    cls_plogq = mask*(p * F.log_softmax(q, dim=1))
    plogp = torch.sum(cls_plogp, dim=1, keepdim=True)
    plogq = torch.sum(cls_plogq, dim=1, keepdim=True)
    kl_loss = torch.mean(plogp - plogq)
    return kl_loss
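
# Sanity-check sketch: the KL of a distribution with itself is ~0:
# logits = torch.randn(2, 4, 8, 8)
# print(kl_divergence(reference=logits, pred=logits).item())  # ~0.0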
class One_Hot(nn.Module):
    def __init__(self, depth, use_gpu=True):
        super(One_Hot, self).__init__()
@@ -289,3 +215,59 @@ class One_Hot(nn.Module):
    def __repr__(self):
        return self.__class__.__name__ + "({})".format(self.depth)
def cross_entropy_2D(input, target, weight=None, size_average=True):
    """
    calc cross entropy loss computed on 2D images
    Args:
        input ([torch tensor]): [4d logit] in the format of NCHW
        target ([torch tensor]): 3D labelmap or 4d probability map (e.g. softmax output), in the format of NCHW
        weight ([type], optional): weights for classes. Defaults to None.
        size_average (bool, optional): take the average across the spatial domain. Defaults to True.
    Raises:
        NotImplementedError: when target is neither a 3D labelmap nor a 4D map
    Returns:
        loss (tensor float)
    """
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    if len(target.size()) == 3:
        target = target.view(target.numel())
        if weight is not None:
            # normalize so that sum(weight) = C, for numerical stability.
            weight = weight/weight.sum()*c
        loss_vector = F.nll_loss(
            log_p, target, weight=weight, reduction="none")
        loss = torch.sum(loss_vector)
        if size_average:
            loss /= (n*h*w)
    elif len(target.size()) == 4:
        # ce loss = -q*log(p)
        reference = target
        reference = reference.transpose(1, 2).transpose(
            2, 3).contiguous().view(-1, c)  # M,C
        if weight is None:
            plogq = torch.sum(reference * log_p, dim=1)
            plogq = torch.sum(plogq)
            if size_average:
                plogq /= (n*h*w)
        else:
            # normalize so that sum(weight) = C
            weight = weight/weight.sum()*c
            plogq_class_wise = reference * log_p
            plogq_sum_class = 0.
            for i in range(c):
                plogq_sum_class += torch.sum(plogq_class_wise[:, i]*weight[i])
            plogq = plogq_sum_class
            if size_average:
                plogq /= (n*h*w)
        loss = -1*plogq
    else:
        raise NotImplementedError
    return loss
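
# Sketch: the 4D-target branch treats `target` as probabilities q and computes -sum(q*log p):
# logits = torch.randn(2, 4, 8, 8)
# soft_target = torch.softmax(torch.randn(2, 4, 8, 8), dim=1)
# loss = cross_entropy_2D(logits, soft_target)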