Commit 96ab1196 authored by cc215's avatar cc215 💬
Browse files

update readme

parent 824bb396
......@@ -127,9 +127,10 @@ e.g.
| Testing config <br> (batch_size=1, roi =256) | UKBB test (600) | | | ACDC (100) | | | M\&Ms (150) | | |
|:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |
| model: UNet_64 | LV | MYO | RV | LV | MYO | RV | LV | MYO | RV |
| Unet_LVSA_trained_from_UKBB.pkl | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| UNet_LVSA_Adv_Compose_(epochs=20).pth | 0.9360 | 0.8726 | 0.8966 | 0.8984 | 0.7973 | 0.8440 | 0.8873 | 0.7859 | 0.8343 |
| baseline | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| Finetune w. random DA | 0.9378 | 0.8768 | 0.8975 | 0.8884 | 0.7981 | 0.8295 | 0.8846 | 0.7893 | 0.8158 |
| Finetune w. random DA + adv Bias:<br>UNet_LVSA_Adv_bias_(epochs=20).pth | 0.9326 | 0.8722 | 0.8973 | 0.8809 | 0.7912 | 0.8395 | 0.8794 | 0.7812 | 0.8228 |
| Finetune w. random DA + adv Composite DA:<br><br>UNet_LVSA_Adv_Compose_(epochs=20).pth | 0.9360 | 0.8726 | 0.8966 | 0.8984 | 0.7973 | 0.8440 | 0.8873 | 0.7859 | 0.8343 |
- To deploy the model for segmentation, please run the following command to test first:
- run `source ./demo_scripts/predict_test.sh`
......
......@@ -93,3 +93,57 @@ def transform2tensor(cPader, img_slice,if_z_score=False):
##transform to tensor
new_image_tensor = torch.from_numpy(new_img_slice).float()
return new_image_tensor
def construct_input(segmentation,image=None,num_classes=None,temperature =1.0,apply_softmax=True, is_labelmap=False, smooth_label=False,shuffle=False,use_gpu=True):
"""
concat image and segmentation toghether to form an input to an external assessor
Args:
image ([4d float tensor]): a of batch of images N(Ch)HW, Ch is the image channel
segmentation ([4d float tensor] or 3d label map): corresponding segmentation map NCHW or 3 one hotmap NHW
shuffle (bool, optional): if true, it will shuffle the input image and segmentation before concat. Defaults to False.
"""
assert (apply_softmax and is_labelmap) is False
if not is_labelmap:
batch_size, h,w = segmentation.size(0),segmentation.size(2),segmentation.size(3)
else:
batch_size, h,w = segmentation.size(0),segmentation.size(1),segmentation.size(2)
device = torch.device('cuda') if use_gpu else torch.device('cpu')
if not is_labelmap:
if apply_softmax:
assert len(segmentation.size())==4
segmentation = segmentation/temperature
softmax_predict = torch.softmax(segmentation,dim=1)
segmentation =softmax_predict
else:
## make onehot maps
assert num_classes is not None, 'please specify num_classes'
flatten_y = segmentation.view(batch_size*h*w, 1)
y_onehot = torch.zeros(batch_size*h*w, num_classes,dtype = torch.float32,device=device)
y_onehot.scatter_(1, flatten_y, 1)
y_onehot =y_onehot.view(batch_size,h,w, num_classes)
y_onehot = y_onehot.permute(0,3,1,2)
y_onehot.requires_grad=False
if smooth_label:
## add noise to labels
smooth_factor =torch.rand(1, device=device)*0.2
y_onehot[y_onehot==1] = 1-smooth_factor
y_onehot[y_onehot==0] = smooth_factor/(num_classes-1)
segmentation = y_onehot
if shuffle and image is not None:
## shuffle images in a batch, such that the segmentations do not match anymore.
image = shuffle_tensor(image)
if image is not None:
tuple = torch.cat([segmentation,image],dim=1)
return tuple
else:
return segmentation
\ No newline at end of file
......@@ -6,7 +6,7 @@ import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
def cross_entropy_2D(input, target, weight=None, size_average=True):
def cross_entropy_2D(input, target, weight=None, size_average=True,mask=None):
"""[summary]
calc cross entropy loss computed on 2D images
Args:
......@@ -24,39 +24,82 @@ def cross_entropy_2D(input, target, weight=None, size_average=True):
n, c, h, w = input.size()
log_p = F.log_softmax(input, dim=1)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
if mask is None:
mask = torch.ones_like(log_p,device = log_p.device) ##
mask =mask.view(-1,c)
mask_region_size = torch.sum(mask[:,0])
if len(target.size())==3:
target = target.view(target.numel())
if not weight is None:
## sum(weight) =C, for numerical stability.
weight = torch.softmax(weight,dim=0)*c
loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
loss_vector = F.nll_loss(log_p, target, weight=weight, reduce=False)
loss_vector = loss_vector*mask[:,0]
loss = torch.sum(loss_vector)
if size_average:
loss /= float(target.numel())
loss /= float(mask_region_size) ## /N*H'*W'
elif len(target.size())==4:
## ce loss=-qlog(p)
reference= F.softmax(target, dim=1) #M,C
reference = reference.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) #M,C
if weight is None:
plogq = torch.sum(reference *log_p, dim=1)
plogq = torch.sum(reference *log_p*mask, dim=1)
plogq = torch.sum(plogq)
if size_average:
plogq/= float(n*h*w)
plogq/= float(mask_region_size)
else:
## sum(weight) =C
weight = torch.softmax(weight,dim=0)*c
plogq_class_wise =reference *log_p
plogq_class_wise =reference *log_p*mask
plogq_sum_class=0.
for i in range(c):
plogq_sum_class+=torch.sum(plogq_class_wise[:,i]*weight[i])
plogq = plogq_sum_class
if size_average:
plogq/= float(n*h*w)
plogq/= float(mask_region_size) # only average loss on the mask entries with value =1
loss=-1*plogq
else:
raise NotImplementedError
return loss
class SoftDiceLoss(nn.Module):
### Dice loss: code is from https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/layers/loss
# .py
def __init__(self, n_classes, use_gpu=True,squared_union=False):
super(SoftDiceLoss, self).__init__()
self.one_hot_encoder = One_Hot(n_classes, use_gpu).forward
self.n_classes = n_classes
self.squared_union =squared_union
def forward(self, input, target, weight=None):
smooth =0.01
batch_size = input.size(0)
input = F.softmax(input, dim=1).view(batch_size, self.n_classes, -1)
if len(target.size())==3:
target = self.one_hot_encoder(target).contiguous().view(batch_size, self.n_classes, -1)
elif len(target.size())==4 and target.size(1) ==input.size(1):
target = F.softmax(target, dim=1).view(batch_size, self.n_classes, -1)
target = target.view(batch_size, self.n_classes, -1)
else:
print ( 'the shapes for input and target do not match, input:{} target:{}'.format(str(input.size())),str(target.size()))
raise ValueError
inter = torch.sum(input * target, 2)
if self.squared_union:
##2pq/(|p|^2+|q|^2)
union = torch.sum(input**2, 2) + torch.sum(target**2, 2)
else:
##2pq/(|p|+|q|)
union = torch.sum(input, 2) + torch.sum(target, 2)
score = torch.sum((2.0 * inter+smooth) / (union+smooth))
score = 1.0 - score / (float(batch_size) * float(self.n_classes))
return score
def calc_segmentation_mse_consistency(input, target):
loss = calc_segmentation_consistency(output=input,reference=target,divergence_types=['mse'],divergence_weights=[1.0],class_weights=None,mask=None)
return loss
......@@ -66,7 +109,7 @@ def calc_segmentation_kl_consistency(input, target):
def calc_segmentation_consistency(output, reference,divergence_types=['kl','contour'],
divergence_weights=[1.0,0.5],scales=[0],
divergence_weights=[1.0,0.5],
mask=None):
"""
measuring the difference between two predictions (network logits before softmax)
......@@ -83,21 +126,15 @@ def calc_segmentation_consistency(output, reference,divergence_types=['kl','cont
loss (tensor float):
"""
dist = 0.
num_classes = reference.size(1)
num_classes = output.size(1)
reference = reference.detach()
if mask is None:
## apply masks so that only gradients on certain regions will be backpropagated.
mask = torch.ones_like(output).float().to(reference.device)
for scale in scales:
if scale>0:
output_reference = torch.nn.AvgPool2d(2 ** scale)(reference)
output_new = torch.nn.AvgPool2d(2 ** scale)(output)
else:
output_reference = reference
output_new = output
for divergence_type, d_weight in zip(divergence_types, divergence_weights):
output_reference = reference
output_new = output
for divergence_type, d_weight in zip(divergence_types, divergence_weights):
loss = 0.
if divergence_type=='kl':
'''
......@@ -105,12 +142,12 @@ def calc_segmentation_consistency(output, reference,divergence_types=['kl','cont
'''
loss = kl_divergence(pred=output_new,reference=output_reference.detach(),mask=mask)
elif divergence_type =='ce':
loss = cross_entropy_2D(input=output_new*mask,target=mask*output_reference.detach())
loss = cross_entropy_2D(input=output_new,target=output_reference.detach(),mask=mask)
elif divergence_type =='mse':
target_pred = torch.softmax(output_reference, dim=1)
input_pred = torch.softmax(output_new, dim=1)
loss = torch.nn.MSELoss(reduction='mean')(target = target_pred*mask, input = input_pred*mask)
loss = torch.nn.MSELoss(reduction='sum')(target = target_pred*mask, input = input_pred*mask)
loss = loss/torch.sum(mask[:,0])
elif divergence_type == 'contour': ## contour-based loss
target_pred = torch.softmax(output_reference, dim=1)
input_pred = torch.softmax(output_new, dim=1)
......@@ -126,9 +163,8 @@ def calc_segmentation_consistency(output, reference,divergence_types=['kl','cont
# print ('{}:{}'.format(divergence_type,loss.item()))
dist += 2 ** scale*(d_weight * loss)
return dist / (1.0 * len(scales))
dist += (d_weight * loss)
return dist
......
......@@ -6,6 +6,21 @@ from medpy.metric.binary import dc
from common_utils.measure import hd, hd_2D_stack, asd, volumesimilarity
import pandas as pd
from IPython.display import display, HTML
import scipy.stats as stats
def p_value_test(reference_df, test_df, attributes=['LV_Dice','MYO_Dice','RV_Dice']):
'''
given a reference_df and test_df for a specific dataset
return p-value dict for each attribute
'''
p_value_dict={}
for aclass in attributes:
ttest,lv_pval = stats.ttest_rel(test_df[aclass], reference_df[aclass])
p_value_dict[aclass]='{0:.4f}'.format(lv_pval)
# print(p_value_dict)
return p_value_dict
class runningScore(object):
......
{ "name": "Composite",
"data": {
"dataset_name":"ACDC" ,
"readable_frames": [ "ES"],
"train_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/labelled",
"validate_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/validate",
"image_format_name": "{frame}_img.nrrd",
"label_format_name": "{frame}_seg.nrrd",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.36719, 1.36719, -1],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": true,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_16",
"num_classes": 4,
"resume_path":"",
"lr": 0.0001,
"optimizer_name": "adam",
"n_epochs": 600,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout":false
},
"adversarial_augmentation":
{
"transformation_type":"bias",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"optimization_mode": "chain",
"n_iter":1,
"power_iteration":true
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"ACDC" ,
"readable_frames": [ "ES"],
"train_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/labelled",
"validate_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/validate",
"image_format_name": "{frame}_img.nrrd",
"label_format_name": "{frame}_seg.nrrd",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.36719, 1.36719, -1],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": true,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_16",
"num_classes": 4,
"resume_path":"",
"lr": 0.0001,
"optimizer_name": "adam",
"n_epochs": 600,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout":false
},
"adversarial_augmentation":
{
"transformation_type":"bias",
"divergence_types":["ce"],
"divergence_weights":[1],
"optimization_mode": "chain",
"n_iter":1,
"power_iteration":true
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"ACDC" ,
"readable_frames": [ "ES"],
"train_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/labelled",
"validate_dir": "/vol/biomedic3/cc215/data/ACDC/semi_supervised_learning/validate",
"image_format_name": "{frame}_img.nrrd",
"label_format_name": "{frame}_seg.nrrd",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.36719, 1.36719, -1],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": true,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_16",
"num_classes": 4,
"resume_path":"",
"lr": 0.0001,
"optimizer_name": "adam",
"n_epochs": 600,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout":false
},
"adversarial_augmentation":
{
"transformation_type":"composite",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"optimization_mode": "independent",
"n_iter":1,
"power_iteration":true
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
......@@ -36,8 +36,10 @@
"adversarial_augmentation":
{
"transformation_type":"composite"
"transformation_type":"composite",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"n_iter":1
}
,
"output":
......
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"optimization_mode": "independent",
"n_iter":1,
"power_iteration":true,
"random_select":true
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"optimization_mode": "chain",
"n_iter":1,
"power_iteration":false
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite",