Commit 4cbe93f3 authored by cc215

add super model

parent 91bad048
@@ -8,4 +8,6 @@ __pycache__/
# C extensions
*.so
test_results/
result/
log/
runs/
\ No newline at end of file
{
"python.pythonPath": "/vol/biomedic2/cc215/anaconda2/envs/maskrcnn_benchmark/bin/python"
}
\ No newline at end of file
@@ -95,3 +95,28 @@ Results will be saved under `test_results` by default
- You can change the data augmentation strategy by changing the value of "data_aug_policy" in the config file.
- For details about the data augmentation strategies, please refer to `dataset_loader/mytransform.py`.
## Model update (2021.3.9):
- A model trained on UKBB data (SAX slices) with adversarial data augmentation is available.
- This model is expected to have improved robustness (especially for images with bias-field artefacts).
- To deploy the model for segmentation, please run the following command to test first:
    - run `source ./demo_scripts/predict_test.sh`
    - this script performs the following steps (a minimal sketch of steps 2-4 is given after this section):
        - 1. load images from `test_data/` and load the model from `./checkpoints/UNet_LVSA_Adv_Compose.pth`
        - 2. resample images to a uniform pixel spacing of 1.25 x 1.25 mm
        - 3. centre-crop images to 192x192
        - 4. normalize intensities to [0, 1]
        - 5. predict the segmentation map
        - 6. recover the image size and resample the prediction back to the original image space
        - 7. save the predicted segmentation maps for `test_data/patient_id/LVSA/LVSA_img_{}.nii.gz` to `test_results/LVSA/patient_id/Adv_Compose_pred_ED.nii.gz`
- We also provide a script to process a single image at a time.
    - to use, please run the following command to test first:
        - run `source ./demo_scripts/predict_single.sh`
    - then you can modify the command to process your own data (`XXX.nii.gz`); the segmentation mask will be saved to `YYY.nii.gz`:
        - `python predict_single_LVSA.py -m './checkpoints/UNet_LVSA_Adv_Compose.pth' -i 'XXX.nii.gz' -o 'YYY.nii.gz' -c 192 -g 0 -b 8`
    - other options:
        - `-c`: crop size (to save memory); you can change it to any value divisible by 16, as long as the segmented objects still lie within the cropped image
        - `-g`: GPU id
        - `-b`: batch size (>=1)
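For reference, below is a minimal sketch (not the repository's code) of preprocessing steps 2-4 above: resampling to a 1.25 x 1.25 mm in-plane spacing, centre-cropping/padding to 192x192, and rescaling intensities to [0, 1]. The function name, its arguments, and the use of `scipy.ndimage.zoom` are illustrative assumptions; the actual pipeline lives in the prediction scripts.

```python
import numpy as np
from scipy.ndimage import zoom


def preprocess_slice(image, spacing, target_spacing=1.25, crop=192):
    """Resample a 2D slice to target_spacing (mm), centre-crop/pad to crop x crop,
    and rescale intensities to [0, 1]. Illustrative sketch only."""
    # resample in-plane so that the pixel spacing becomes target_spacing
    factors = (spacing[0] / target_spacing, spacing[1] / target_spacing)
    image = zoom(image.astype(np.float32), factors, order=1)

    # pad if smaller than the target size, then centre-crop
    h, w = image.shape
    pad_h, pad_w = max(crop - h, 0), max(crop - w, 0)
    image = np.pad(image, ((pad_h // 2, pad_h - pad_h // 2),
                           (pad_w // 2, pad_w - pad_w // 2)), mode='constant')
    h, w = image.shape
    top, left = (h - crop) // 2, (w - crop) // 2
    image = image[top:top + crop, left:left + crop]

    # min-max rescaling to [0, 1]
    return (image - image.min()) / (image.max() - image.min() + 1e-10)
```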
@@ -2,7 +2,70 @@
# Enter feature description here
# Enter scenario name here
# Enter steps here
import torch
import numpy as np
def switch_kv_in_dict(mydict):
switched_dict = {y: x for x, y in mydict.items()}
    return switched_dict
def unit_normalize(d):
d_abs_max = torch.max(
torch.abs(d.view(d.size(0), -1)), 1, keepdim=True)[0].view(
d.size(0), 1, 1, 1)
# print(d_abs_max.size())
d /= (1e-20 + d_abs_max) ## d' =d/d_max
d /= torch.sqrt(1e-6 + torch.sum(
torch.pow(d, 2.0), tuple(range(1, len(d.size()))), keepdim=True)) ##d'/sqrt(d'^2)
# print(torch.norm(d.view(d.size(0), -1), dim=1))
return d
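# Illustrative usage (assumption, not the repo's call site): unit_normalize is
# typically applied to a perturbation before taking an adversarial step, e.g.
#   d = torch.randn_like(images)   # random initial perturbation, N*C*H*W
#   d = unit_normalize(d)          # per-sample unit L2 norm
#   adv_images = images + epsilon * d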
def rescale_intensity(data,new_min=0,new_max=1):
    '''
    rescale a batch of pytorch data to a given intensity range
    :param data: N*1*H*W tensor
    :param new_min: lower bound of the target intensity range
    :param new_max: upper bound of the target intensity range
    :return: data with per-sample intensities rescaled to [new_min, new_max]
    '''
bs, c , h, w = data.size(0),data.size(1),data.size(2), data.size(3)
data = data.reshape(bs, -1)
old_max = torch.max(data, dim=1, keepdim=True).values
old_min = torch.min(data, dim=1, keepdim=True).values
new_data = (data - old_min) / (old_max - old_min + 1e-6)*(new_max-new_min)+new_min
new_data = new_data.reshape(bs, c, h, w)
return new_data
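# Illustrative usage (assumption): per-sample min-max rescaling of a batch
#   x = torch.rand(4, 1, 192, 192) * 1000.0
#   x01 = rescale_intensity(x)     # each sample rescaled to [0, 1]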
def transform2tensor(cPader, img_slice,if_z_score=False):
    '''
    transform npy data to a torch tensor
    :param cPader: padding callable that pads the image so its size is divisible by 16
    :param img_slice: npy array, N*H*W
    :param if_z_score: if True, apply per-slice z-score normalisation; otherwise rescale to [0, 1]
    :return: float tensor, N*1*H*W
    '''
###
new_img_slice = cPader(img_slice)
## normalize data
new_img_slice = new_img_slice * 1.0 ##N*H*W
new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
if if_z_score:
new_img_slice -= new_input_mean
new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
        new_std[np.abs(new_std) < 1e-3] = 1  # guard against (near-)zero std per slice
new_img_slice /= (new_std)
else:
##print ('0-1 rescaling')
min_val = np.min(new_img_slice,axis=(1, 2), keepdims=True)
max_val = np.max(new_img_slice,axis=(1, 2), keepdims=True)
new_img_slice =(new_img_slice-min_val)/(max_val-min_val+1e-10)
new_img_slice = new_img_slice[:, np.newaxis, :, :]
##transform to tensor
new_image_tensor = torch.from_numpy(new_img_slice).float()
return new_image_tensor
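# Illustrative usage (assumption): `cPader` is expected to be a callable that pads
# an N*H*W array so that H and W are divisible by 16, e.g.
#   tensor = transform2tensor(cPader, img_slices, if_z_score=False)  # -> N*1*H*W float tensor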
import pickle
import os
def save_dict(mydict, file_path):
    with open(file_path, "wb") as f:
        pickle.dump(mydict, f)
def load_dict(file_path):
with open(file_path,"rb") as f:
data = pickle.load(f)
return data
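# Illustrative usage (assumption):
#   save_dict({'patient001': 0.91}, 'result/dice_scores.pkl')
#   scores = load_dict('result/dice_scores.pkl')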
def check_dir(dir_path, create=False):
'''
check the existence of a dir, when create is True, will create the dir if it does not exist.
dir_path: str.
create: bool
return:
exists (1) or not (-1)
'''
if os.path.exists(dir_path):
return 1
else:
if create:
os.makedirs(dir_path)
return -1
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
def cross_entropy_2D(input, target, weight=None, size_average=True):
@@ -16,4 +14,198 @@ def cross_entropy_2D(input, target, weight=None, size_average=True):
loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
if size_average:
loss /= float(target.numel()+1e-10)
    return loss
def calc_segmentation_mse_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target, divergence_types=['mse'], divergence_weights=[1.0], mask=None)
    return loss


def calc_segmentation_kl_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target, divergence_types=['kl'], divergence_weights=[1.0], mask=None)
    return loss
def calc_segmentation_consistency(output, reference,divergence_types=['kl','contour'],
divergence_weights=[1.0,0.5],scales=[0],
mask=None):
"""
measuring the difference between two predictions (network logits before softmax)
Args:
        output (torch tensor 4d): network logits, N*C*H*W (after perturbation)
        reference (torch tensor 4d): reference network logits, N*C*H*W (before perturbation)
        divergence_types (list of str): specify loss types. Defaults to ['kl','contour'].
        divergence_weights (list of float): specify coefficients for each loss above. Defaults to [1.0,0.5].
        scales (list of int): specify a list of downsampling rates so that losses will be calculated on different scales. Defaults to [0].
        mask (tensor, 0-1 map): N*1*H*W (or N*C*H*W). No loss is computed on elements where mask=0. Defaults to None.
Raises:
NotImplementedError: when loss name is not in ['kl','mse','contour']
Returns:
loss (tensor float):
"""
dist = 0.
num_classes = reference.size(1)
reference = reference.detach()
if mask is None:
## apply masks so that only gradients on certain regions will be backpropagated.
mask = torch.ones_like(output).float().to(reference.device)
for scale in scales:
if scale>0:
output_reference = torch.nn.AvgPool2d(2 ** scale)(reference)
output_new = torch.nn.AvgPool2d(2 ** scale)(output)
else:
output_reference = reference
output_new = output
for divergence_type, d_weight in zip(divergence_types, divergence_weights):
loss = 0.
if divergence_type=='kl':
'''
standard kl loss
'''
loss = kl_divergence(pred=output_new,reference=output_reference.detach(),mask=mask)
elif divergence_type =='mse':
target_pred = torch.softmax(output_reference, dim=1)
input_pred = torch.softmax(output_new, dim=1)
loss = torch.nn.MSELoss(reduction='mean')(target = target_pred*mask, input = input_pred*mask)
elif divergence_type == 'contour': ## contour-based loss
target_pred = torch.softmax(output_reference, dim=1)
input_pred = torch.softmax(output_new, dim=1)
for i in range(1,num_classes):
loss += contour_loss(input=input_pred[:,[i],], target=(target_pred[:,[i]]).detach(), ignore_background=False,mask=mask,
one_hot_target=False)
loss = loss/(num_classes-1)
else:
raise NotImplementedError
# print ('{}:{}'.format(divergence_type,loss.item()))
dist += 2 ** scale*(d_weight * loss)
return dist / (1.0 * len(scales))
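# Illustrative usage (assumption, not a call site from the repo): consistency
# between clean and perturbed predictions of the same network, e.g.
#   logits_ref = model(images)          # N*C*H*W logits (before softmax)
#   logits_adv = model(adv_images)
#   loss = calc_segmentation_consistency(output=logits_adv, reference=logits_ref,
#                                        divergence_types=['kl', 'contour'],
#                                        divergence_weights=[1.0, 0.5])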
def contour_loss(input, target, size_average=True, use_gpu=True,ignore_background=True,one_hot_target=True,mask=None):
    '''
    calc the contour loss across object boundaries (excluding the background class)
    :param input: torch tensor, N*num_classes*H*W: pixel-wise probabilities for each class, e.g. the softmax output of a network
    :param target: ground-truth labels (N*H*W) or one-hot ground-truth maps (N*C*H*W)
    :param size_average: average over the batch
    :param use_gpu: boolean. default: True, run the edge filters on the GPU.
    :param ignore_background: boolean, ignore the background class. default: True
    :param one_hot_target: boolean. if True, first convert the target from N*H*W to N*C*H*W. Default: True.
    :param mask: optional 0-1 mask (N*C*H*W or broadcastable); gradients are only propagated where mask=1.
    :return: contour (edge-map) MSE loss between prediction and target
    '''
num_classes = input.size(1)
if one_hot_target:
onehot_mapper = One_Hot(depth=num_classes, use_gpu=use_gpu)
target = target.long()
onehot_target = onehot_mapper(target).contiguous().view(input.size(0), num_classes, input.size(2), input.size(3))
else:
onehot_target=target
assert onehot_target.size() == input.size(), 'pred size: {} must match target size: {}'.format(str(input.size()),str(onehot_target.size()))
if mask is None:
## apply masks so that only gradients on certain regions will be backpropagated.
mask = torch.ones_like(input).long().to(input.device)
mask.requires_grad = False
else:
pass
# print ('mask applied')
if ignore_background:
object_classes = num_classes - 1
target_object_maps = onehot_target[:, 1:].float()
input = input[:, 1:]
else:
target_object_maps=onehot_target
object_classes = num_classes
x_filter = np.array([[1, 0, -1],
[2, 0, -2],
[1, 0, -1]]).reshape(1, 1, 3, 3)
x_filter = np.repeat(x_filter, axis=1, repeats=object_classes)
x_filter = np.repeat(x_filter, axis=0, repeats=object_classes)
conv_x = nn.Conv2d(in_channels=object_classes, out_channels=object_classes, kernel_size=3, stride=1, padding=1,
dilation=1, bias=False)
conv_x.weight = nn.Parameter(torch.from_numpy(x_filter).float())
y_filter = np.array([[1, 2, 1],
[0, 0, 0],
[-1, -2, -1]]).reshape(1, 1, 3, 3)
y_filter = np.repeat(y_filter, axis=1, repeats=object_classes)
y_filter = np.repeat(y_filter, axis=0, repeats=object_classes)
conv_y = nn.Conv2d(in_channels=object_classes, out_channels=object_classes, kernel_size=3, stride=1, padding=1,
bias=False)
conv_y.weight = nn.Parameter(torch.from_numpy(y_filter).float())
if use_gpu:
conv_y = conv_y.cuda()
conv_x = conv_x.cuda()
for param in conv_y.parameters():
param.requires_grad = False
for param in conv_x.parameters():
param.requires_grad = False
g_x_pred = conv_x(input)*mask[:,:object_classes]
g_y_pred = conv_y(input)*mask[:,:object_classes]
g_y_truth = conv_y(target_object_maps)*mask[:,:object_classes]
g_x_truth = conv_x(target_object_maps)*mask[:,:object_classes]
## mse loss
loss =torch.nn.MSELoss(reduction='mean')(input=g_x_pred,target=g_x_truth) +torch.nn.MSELoss(reduction='mean')(input=g_y_pred,target=g_y_truth)
return loss
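# Illustrative usage (assumption): edge-map consistency between softmax
# predictions and integer ground-truth labels for the foreground classes, e.g.
#   probs = torch.softmax(logits, dim=1)             # N*C*H*W probabilities
#   loss = contour_loss(input=probs, target=labels,  # labels: N*H*W integer map
#                       one_hot_target=True, use_gpu=False)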
def kl_divergence(reference, pred,mask=None):
    '''
    calc the kl divergence between two network outputs p and q (logits before softmax): p(y1|x1), q(y2|x2)
    :param reference: p, logits from the network for the original input
    :param pred: q, logits from the network for the perturbed input
    :param mask: optional 0-1 mask, N*C*H*W
    :return: kl divergence: D_KL(P||Q) = mean( sum_{c=1}^{C} p^c * log(p^c / q^c) )
    '''
p=reference
q=pred
p_logit = F.softmax(p, dim=1)
if mask is None:
mask = torch.ones_like(p_logit, device =p_logit.device)
mask.requires_grad=False
cls_plogp = mask*p_logit * F.log_softmax(p, dim=1)
cls_plogq = mask*p_logit * F.log_softmax(q, dim=1)
plogp = torch.sum(cls_plogp,dim=1,keepdim=True)
plogq = torch.sum(cls_plogq,dim=1,keepdim=True)
kl_loss = torch.mean(plogp - plogq)
return kl_loss
class One_Hot(nn.Module):
def __init__(self, depth, use_gpu=True):
super(One_Hot, self).__init__()
self.depth = depth
if use_gpu:
self.ones = torch.sparse.torch.eye(depth).cuda()
else:
self.ones = torch.sparse.torch.eye(depth)
def forward(self, X_in):
n_dim = X_in.dim()
output_size = X_in.size() + torch.Size([self.depth])
num_element = X_in.numel()
X_in = X_in.data.long().view(num_element)
out = Variable(self.ones.index_select(0, X_in)).view(output_size)
return out.permute(0, -1, *range(1, n_dim)).squeeze(dim=2).float()
def __repr__(self):
return self.__class__.__name__ + "({})".format(self.depth)
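# Illustrative usage (assumption): convert N*H*W integer labels to N*C*H*W one-hot maps
#   mapper = One_Hot(depth=4, use_gpu=False)
#   onehot = mapper(torch.randint(0, 4, (2, 192, 192)))  # -> 2*4*192*192 float tensor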
{ "name": "Bias",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": true
},
"adversarial_augmentation":
{
"transformation_type":"bias"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Bias"
}
}
\ No newline at end of file
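For reference, a minimal sketch (not part of the repository) of how a config file like the one above could be loaded and a few of its fields consumed; the file path and variable names are assumptions:

```python
import json

# path is an assumption; configs in this commit follow the structure shown above
with open('config/Bias.json') as f:
    cfg = json.load(f)

crop_h, crop_w, _ = cfg['data']['crop_size']                        # 192, 192, 1
lr = cfg['segmentation_model']['lr']                                # 1e-05
aug_type = cfg['adversarial_augmentation']['transformation_type']   # 'bias'
save_dir = cfg['output']['save_dir']                                # './result/Bias'
```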
{ "name": "Bias",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv2",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.2
},
"adversarial_augmentation":
{
"transformation_type":"bias"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Bias"
}
}
\ No newline at end of file
{ "name": "Bias",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "demo_dataset/train",
"validate_dir": "demo_dataset/validate",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_affine_elastic",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": true,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 1000,
"max_iteration": 10000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": true
},
"adversarial_augmentation":
{
"transformation_type":"bias"
}
,
"output":
{
"save_epoch_every_num_epochs":50,
"save_dir":"./result/Bias"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": true
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv2",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.2
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file