Commit d4d9021c authored by cc215

update model

parent b58b4c62
......@@ -122,11 +122,19 @@ e.g.
## Model update (2021.3.9):
- A model trained on UKBB data (SAX slices) with adversarial data augmentation (adversarial noise, adversarial bias field, adversarial morphological deformation, and adversarial affine transformation) is available.
- This model is expected to have improved robustness on cross-domain data (especially for images affected by bias fields)
- This model is expected to have improved robustness on cross-domain data, especially for right ventricle segmentation. See the test results below on the intra-domain test set (UKBB) and the *unseen* cross-domain sets ACDC and M\&Ms.
| Testing config <br> (batch_size=1, roi=256) | UKBB test (600) | | | ACDC (100) | | | M\&Ms (150) | | |
|:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |
| model: UNet_64 | LV | MYO | RV | LV | MYO | RV | LV | MYO | RV |
| Unet_LVSA_trained_from_UKBB.pkl | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| UNet_LVSA_Adv_Compose_(epochs=20).pth | 0.9360 | 0.8726 | 0.8966 | 0.8984 | 0.7973 | 0.8440 | 0.8873 | 0.7859 | 0.8343 |
- To deploy the model for segmentation, please run the following command to test first:
- run `source ./demo_scripts/predict_test.sh`
- this script will perform the following steps:
- 1. load images from disk 'test_data/' and load model from `./checkpoints/UNet_LVSA_Adv_Compose.pth`
- 1. load images from disk 'test_data/' and load the model from `./checkpoints/UNet_LVSA_Adv_Compose_(epochs=20).pth`
- 2. resample images to a uniform in-plane pixel spacing of 1.25 x 1.25 mm
- 3. centrally crop images to 192x192
- 4. normalize intensities to [0,1]
......@@ -138,13 +146,14 @@ e.g.
- before use, please run the following command to test first:
run `source ./demo_scripts/predict_single.sh`
- then you can modify the command to process your own data (XXX.nii.gz); the segmentation mask will be saved to 'YYY.nii.gz':
- run `python predict_single_LVSA.py -m './checkpoints/UNet_LVSA_Adv_Compose.pth' -i 'XXX.nii.gz' -o 'YYY.nii.gz' -c 192 -g 0 -b 8`
- run `python predict_single_LVSA.py -m './checkpoints/UNet_LVSA_Adv_Compose_(epochs=20).pth' -i 'XXX.nii.gz' -o 'YYY.nii.gz' -c 256 -g 0 -b 1`
* notes:
- m: model path
- i: input image path
- o: output path for prediction
- c: crop image to save memory; you can change it to any size as long as it is divisible by 16 and the targeted structures are still within the image region after cropping
- c: the crop size, used to save memory. You can change it to any size divisible by 16, as long as the targeted structures remain within the image region after cropping. When set to -1, each image is cropped to its largest rectangle whose height and width are multiples of 16 (see the sketch after this list). Default: 256.
- g: int, gpu id
- b: int, batch size (>=1)
- z: boolean. If set to true, min-max intensity normalization (mapping intensities to the [0,1] range) will be used to preprocess images. By default, this is deactivated; we found std (z-score) normalization yields better cross-domain segmentation performance than min-max rescaling.
- b: int, batch size (>=1). For optimal performance, we found that segmentation with instance normalization (b=1) is more robust than with batch normalization (b>1). However, b=1 slows down inference due to the slice-by-slice prediction scheme.
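
As a rough illustration of the `-c` behaviour described above (any crop size divisible by 16, or `-1` for the largest 16-divisible rectangle), here is a minimal Python sketch. The helper names and the centre-crop logic are illustrative assumptions, not the script's actual implementation; the divisibility-by-16 constraint presumably comes from the UNet's repeated 2x downsampling.

```python
import numpy as np

def largest_crop_16(height, width):
    # largest crop whose height and width are both multiples of 16 (the `-c -1` behaviour described above)
    return (height // 16) * 16, (width // 16) * 16

def center_crop(img, crop_h, crop_w):
    # img: H x W array; the crop must keep the target structures inside the retained region
    h, w = img.shape[:2]
    top, left = (h - crop_h) // 2, (w - crop_w) // 2
    return img[top:top + crop_h, left:left + crop_w]

img = np.random.rand(200, 210)            # e.g. one resampled SAX slice
print(largest_crop_16(*img.shape))        # -> (192, 208)
print(center_crop(img, 192, 192).shape)   # -> (192, 192)
```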
......@@ -21,21 +21,45 @@ def unit_normalize(d):
    # print(torch.norm(d.view(d.size(0), -1), dim=1))
    return d
def intensity_norm_fn(intensity_norm_type):
    if intensity_norm_type == 'min_max':
        return rescale_intensity
    elif intensity_norm_type == 'z_score':
        return z_score_intensity
    else:
        raise ValueError(f'unsupported intensity norm type: {intensity_norm_type}')
def rescale_intensity(data,new_min=0,new_max=1):
def rescale_intensity(data, new_min=0, new_max=1, eps=1e-20):
    '''
    rescale pytorch batch data
    :param data: N*C*H*W
    :return: data with intensity ranging from 0 to 1
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs, -1)
    data = data.view(bs*c, -1)
    old_max = torch.max(data, dim=1, keepdim=True).values
    old_min = torch.min(data, dim=1, keepdim=True).values
    new_data = (data - old_min) / (old_max - old_min + 1e-6)*(new_max-new_min)+new_min
    new_data = (data - old_min) / (old_max - old_min + eps) * (new_max - new_min) + new_min
    new_data = new_data.view(bs, c, h, w)
    return new_data
def z_score_intensity(data):
    '''
    standardize pytorch batch data
    :param data: N*C*H*W
    :return: data with zero mean and unit std per channel
    '''
    bs, c, h, w = data.size(0), data.size(1), data.size(2), data.size(3)
    data = data.view(bs*c, -1)
    mean = torch.mean(data, dim=1, keepdim=True)
    data_dmean = data - mean.detach()
    std = torch.std(data_dmean, dim=1, keepdim=True)
    std = std.detach()
    std[abs(std) == 0] = 1
    new_data = data_dmean / std
    new_data = new_data.view(bs, c, h, w)
    return new_data
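
A minimal usage sketch of the two intensity normalization functions above; the batch shape and intensity range are dummy values for illustration only.

```python
import torch

x = torch.rand(2, 1, 192, 192) * 1000.0   # dummy N*C*H*W batch with a large intensity range

x_minmax = rescale_intensity(x)           # per-channel min-max rescaling to [0, 1]
x_zscore = z_score_intensity(x)           # per-channel zero-mean, unit-std standardization

print(x_minmax.min().item(), x_minmax.max().item())   # ~0.0, ~1.0
print(x_zscore.mean().item(), x_zscore.std().item())  # ~0.0, ~1.0
```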
def transform2tensor(cPader, img_slice,if_z_score=False):
......
......@@ -7,17 +7,56 @@ import torch.nn as nn
from torch.autograd import Variable
def cross_entropy_2D(input, target, weight=None, size_average=True):
    """
    calc cross entropy loss computed on 2D images
    Args:
        input ([torch tensor]): [4d logit] in the format of NCHW
        target ([torch tensor]): 3D labelmap or 4d logit (before softmax), in the format of NCHW
        weight ([type], optional): weights for classes. Defaults to None.
        size_average (bool, optional): take the average across the spatial domain. Defaults to True.
    Raises:
        NotImplementedError: if the target is neither a 3D labelmap nor a 4D logit tensor
    Returns:
        loss (torch tensor): scalar cross-entropy loss
    """
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(target.numel())
    loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
    if size_average:
        loss /= float(target.numel())
    if len(target.size()) == 3:
        target = target.view(target.numel())
        if weight is not None:
            ## sum(weight) = C, for numerical stability.
            weight = torch.softmax(weight, dim=0) * c
        loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
        if size_average:
            loss /= float(target.numel())
    elif len(target.size()) == 4:
        ## ce loss = -q*log(p)
        reference = F.softmax(target, dim=1)  # M,C
        reference = reference.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)  # M,C
        if weight is None:
            plogq = torch.sum(reference * log_p, dim=1)
            plogq = torch.sum(plogq)
            if size_average:
                plogq /= float(n * h * w)
        else:
            ## sum(weight) = C
            weight = torch.softmax(weight, dim=0) * c
            plogq_class_wise = reference * log_p
            plogq_sum_class = 0.
            for i in range(c):
                plogq_sum_class += torch.sum(plogq_class_wise[:, i] * weight[i])
            plogq = plogq_sum_class
            if size_average:
                plogq /= float(n * h * w)
        loss = -1 * plogq
    else:
        raise NotImplementedError
    return loss
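
For reference, a small sketch of the two target formats the updated `cross_entropy_2D` accepts; the shapes and class-weight values are illustrative only.

```python
import torch

logits = torch.randn(2, 4, 192, 192)                 # N, C, H, W network output (before softmax)
hard_labels = torch.randint(0, 4, (2, 192, 192))     # N, H, W integer labelmap
soft_targets = torch.randn(2, 4, 192, 192)           # N, C, H, W logits, e.g. from a reference branch
class_weights = torch.tensor([0.5, 1.0, 1.0, 1.0])   # re-normalized internally so the weights sum to C

loss_hard = cross_entropy_2D(logits, hard_labels, weight=class_weights)
loss_soft = cross_entropy_2D(logits, soft_targets)   # -sum(q * log p), averaged over N*H*W
```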
def calc_segmentation_mse_consistency(input, target):
    loss = calc_segmentation_consistency(output=input, reference=target, divergence_types=['mse'], divergence_weights=[1.0], class_weights=None, mask=None)
    return loss
......@@ -65,6 +104,8 @@ def calc_segmentation_consistency(output, reference,divergence_types=['kl','cont
            standard kl loss
            '''
            loss = kl_divergence(pred=output_new, reference=output_reference.detach(), mask=mask)
        elif divergence_type == 'ce':
            loss = cross_entropy_2D(input=output_new*mask, target=mask*output_reference.detach())
        elif divergence_type == 'mse':
            target_pred = torch.softmax(output_reference, dim=1)
            input_pred = torch.softmax(output_new, dim=1)
......@@ -73,10 +114,12 @@ def calc_segmentation_consistency(output, reference,divergence_types=['kl','cont
        elif divergence_type == 'contour': ## contour-based loss
            target_pred = torch.softmax(output_reference, dim=1)
            input_pred = torch.softmax(output_new, dim=1)
            cnt = 0
            for i in range(1, num_classes):
                cnt += 1
                loss += contour_loss(input=input_pred[:, [i], ], target=(target_pred[:, [i]]).detach(), ignore_background=False, mask=mask,
                                     one_hot_target=False)
            loss = loss/(num_classes-1)
            # if cnt > 0: loss /= cnt
        else:
            raise NotImplementedError
......@@ -102,7 +145,7 @@ def contour_loss(input, target, size_average=True, use_gpu=True,ignore_backgroun
    :param one_hot_target: boolean. if true, will first convert the target from NHW to NCHW. Default: True.
    :return:
    '''
    num_classes = input.size(1)
    n, num_classes, h, w = input.size(0), input.size(1), input.size(2), input.size(3)
    if one_hot_target:
        onehot_mapper = One_Hot(depth=num_classes, use_gpu=use_gpu)
        target = target.long()
......@@ -162,8 +205,8 @@ def contour_loss(input, target, size_average=True, use_gpu=True,ignore_backgroun
    g_x_truth = conv_x(target_object_maps)*mask[:, :object_classes]
    ## mse loss
    loss = torch.nn.MSELoss(reduction='mean')(input=g_x_pred, target=g_x_truth) + torch.nn.MSELoss(reduction='mean')(input=g_y_pred, target=g_y_truth)
    loss = torch.nn.MSELoss(reduction='sum')(input=g_x_pred, target=g_x_truth) + torch.nn.MSELoss(reduction='sum')(input=g_y_pred, target=g_y_truth)
    loss /= torch.sum(mask[:, 0, :, :])
    return loss
......@@ -185,7 +228,8 @@ def kl_divergence(reference, pred,mask=None):
    plogp = torch.sum(cls_plogp, dim=1, keepdim=True)
    plogq = torch.sum(cls_plogq, dim=1, keepdim=True)
    kl_loss = torch.mean(plogp - plogq)
    kl_loss = torch.sum(plogp - plogq)
    kl_loss /= torch.sum(mask[:, 0, :, :])
    return kl_loss
......
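
For readability, here is a standalone sketch of the masked KL normalization that the hunk above switches to (summing the per-pixel KL over the masked region and dividing by the mask size). The variable names and the exact point at which the mask is applied are assumptions, since the full function body is not part of this diff.

```python
import torch

def masked_kl_sketch(pred_logits, ref_logits, mask):
    # pred_logits, ref_logits: N*C*H*W logits; mask: N*1*H*W with 1 = pixel contributes to the loss
    p = torch.softmax(ref_logits, dim=1)
    plogp = torch.sum(p * torch.log_softmax(ref_logits, dim=1), dim=1, keepdim=True)
    plogq = torch.sum(p * torch.log_softmax(pred_logits, dim=1), dim=1, keepdim=True)
    kl = torch.sum((plogp - plogq) * mask)       # sum of per-pixel KL over the masked region
    return kl / torch.sum(mask[:, 0, :, :])      # normalize by the number of masked pixels
```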
......@@ -6,13 +6,13 @@
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv3",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
......@@ -26,8 +26,8 @@
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "SGD",
"n_epochs": 1000,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
......
......@@ -6,13 +6,13 @@
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced_z_score",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
......@@ -27,7 +27,7 @@
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
......@@ -36,7 +36,7 @@
"adversarial_augmentation":
{
"transformation_type":"composite"
"transformation_type":"bias"
}
,
"output":
......
......@@ -6,7 +6,7 @@
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv2",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
......@@ -27,16 +27,17 @@
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.2
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
......
......@@ -6,13 +6,13 @@
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv3",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
......@@ -26,8 +26,8 @@
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "SGD",
"n_epochs": 1000,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
......@@ -37,7 +37,9 @@
"adversarial_augmentation":
{
"transformation_type":"composite",
"optimization_mode":"independent"
"divergence_types":["kl","contour"],
"divergence_weights":[1,1]
}
,
"output":
......
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv3",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced_z_score",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced_z_score",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv3",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "SGD",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
......@@ -7,7 +7,6 @@
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_affine_elastic",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
......
{ "name": "Bias",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_best.pkl",
"lr": 0.00001,
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": true
},
"adversarial_augmentation":
{
"transformation_type":"bias"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Bias"
}
}
\ No newline at end of file
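
These experiment configs share the same overall structure (a `data` block, a `segmentation_model` block, an `adversarial_augmentation` block, and an `output` block). Below is a minimal sketch of inspecting one of them with the standard library; the file path is hypothetical and the training code's actual config loader may differ.

```python
import json

# Hypothetical path; substitute the config file you want to inspect.
with open('config/bias.json') as f:
    cfg = json.load(f)

print(cfg['name'])                                             # e.g. "Bias"
print(cfg['data']['data_aug_policy'])                          # e.g. "UKBB_advanced"
print(cfg['segmentation_model']['batch_size'])                 # e.g. 20
print(cfg['adversarial_augmentation']['transformation_type'])  # e.g. "bias"
print(cfg['output']['save_dir'])                               # e.g. "./result/Bias"
```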