Commit 9eabccd2 authored by cc215

initial commit

<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
</component>
</module>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.5 (pytorch3)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Cardiac_Carlo.iml" filepath="$PROJECT_DIR$/.idea/Cardiac_Carlo.iml" />
</modules>
</component>
</project>
\ No newline at end of file
# Cardiac MRI Image Multi-view Segmentation
author: Chen Chen (cc215@ic.ac.uk)
## Features
- LV/MYO/RV segmentation on short axis view (LVSA):
- LV: left ventricle cavity, MYO: myocardium of left ventricle, RV: right ventricle
- MYO segmentation on the cardiac long axis views
- 4CH: 4 chamber view
- VLA: Vertical Long Axis
- LVOT: Left Ventricular Outflow Tract
## Environment
- Python 3.5
- PyTorch 1.0
- CUDA 9.0
## Dependencies
- see requirements.txt
- install them by running `pip install -r requirements.txt`
## Test segmentation
- LVSA segmentation
- `python predict.py --sequence LVSA`
- VLA segmentation
- `python predict.py --sequence VLA`
- 4CH segmentation
- `python predict.py --sequence 4CH`
- LVOT segmentation
- `python predict.py --sequence LVOT`
## Customize for your needs
- Please read `predict.py` for advanced settings.
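- To adapt the pipeline to your own data, the dataset loader shipped in this commit (`CARDIAC_Predict_DATASET`) can also be driven directly. Below is a minimal sketch; the import path and the data paths are assumptions and should be adjusted to your setup:

```python
from dataset.cardiac_predict_dataset import CARDIAC_Predict_DATASET  # module path is an assumption

# root_dir holds one sub-directory per patient; {} in the format string is filled with 'ED'/'ES'
testset = CARDIAC_Predict_DATASET(root_dir='/path/to/test_data',
                                  image_format_name='LVOT/LVOT_img_{}.nii.gz',
                                  readable_frames=['ED', 'ES'])
patient = testset.get_next_patient()   # dict keyed by frame ('ED'/'ES') plus 'patient_id'
print(patient['ED']['image'].shape)    # resampled numpy volume, slices*H*W
```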
import os
import numpy as np
import SimpleITK as sitk
from torch.utils import data
import torch
from dataset.utils import resample_by_spacing
class CARDIAC_Predict_DATASET(data.Dataset):
def __init__(self,
root_dir,
image_format_name,
readable_frames=['ED', 'ES'],
if_resample=True,
new_spacing=[1.25, 1.25, 10],
keep_z_spacing=True,
):
        '''
        :param root_dir: test folder containing one sub-directory per patient
        :param image_format_name: image name format, e.g. 'LVOT/LVOT_img_{}.nii.gz', where {} denotes the ED or ES frame
        :param readable_frames: list of frame names to load, e.g. ['ED', 'ES']
        :param if_resample: if True, resample all images to the same pixel spacing
        :param new_spacing: new spacing [x, y, z]
        :param keep_z_spacing: if True, keep the original spacing along the z axis (no rescaling in z). Default: True.
        '''
super(CARDIAC_Predict_DATASET, self).__init__()
self.readable_frames = readable_frames
dataset_dir = root_dir
self.dataset_dir = dataset_dir
p_list, p_path_list = self.get_p_list(self.dataset_dir)
self.patient_list = p_list
self.patient_path_list = p_path_list
self.data_size = len(self.patient_path_list)
print('Number of images: {} '.format(self.data_size))
self.image_format_name = image_format_name
self.new_spacing = new_spacing
self.if_resample = if_resample
self.pid = 0
self.not_found = [] ##record all missing data path
self.keep_z_spacing = keep_z_spacing
def get_p_list(self, dir):
        '''
        get patient ids and paths under the given directory
        :param dir: dataset root directory
        :return: patient_id list, patient_path_list
        '''
p_list = []
path_list = []
for pid in sorted(os.listdir(dir)):
p_path = os.path.join(dir, pid)
if os.path.exists(p_path):
p_list.append(pid)
path_list.append(p_path)
# report the number of images in the dataset
return p_list, path_list
def get_size(self):
return len(self.patient_list)
def reset_pid(self):
self.pid = 0
def get_next_patient(self):
'''
:return: ED and ES frame of one patient data
'''
patient_data = {}
if not self.pid == len(self.patient_path_list):
for frame in self.readable_frames:
temp_image_path = os.path.join(self.patient_path_list[self.pid], self.image_format_name.format(frame))
## check path exists
print('try to read {}'.format(temp_image_path))
if not os.path.exists(temp_image_path):
print('not found, ignore it')
continue
print('load success')
### read image ##
temp_image = sitk.ReadImage(temp_image_path)
original_shape=sitk.GetArrayFromImage(temp_image).shape
##convert to float format
temp_image = sitk.Cast(sitk.RescaleIntensity(temp_image), sitk.sitkFloat32)
origin_spacing = temp_image.GetSpacing()
if self.if_resample:
## image resampling
new_image = resample_by_spacing(im=temp_image, new_spacing=self.new_spacing,
interpolator=sitk.sitkLinear,
keep_z_spacing=self.keep_z_spacing)
else:
new_image = temp_image
## new image shape
self.aft_resample_shape = sitk.GetArrayFromImage(new_image).shape
data = sitk.GetArrayFromImage(new_image).astype(float)
patient_id = self.patient_path_list[self.pid].split('/')[-1]
## save frame data
patient_data[frame] = {'image': data, ##npy data
'origin_itk_image': temp_image, ##original data
'temp_image_path': temp_image_path,
'new_spacing': self.new_spacing,
'original_shape':original_shape,
'aft_resample_shape': self.aft_resample_shape,
'origin_spacing': origin_spacing,
'after_resampled_image': new_image,
'patient_id': patient_id
}
patient_data['patient_id'] = self.patient_list[self.pid]
self.pid += 1
return patient_data
def __len__(self):
return self.data_size
def get_name(self):
print('dataset loader')
@staticmethod
def transform2tensor(cPader, img_slice):
        '''
        transform npy data to a torch tensor
        :param cPader: padding function that pads each image so its size is divisible by 16
        :param img_slice: N*H*W numpy array
        :return: N*1*H*W float tensor (per-slice mean/std normalised)
        '''
###
new_img_slice = cPader(img_slice)
## normalize data
new_img_slice = new_img_slice * 1.0 ##N*H*W
new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
new_img_slice -= new_input_mean
new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
new_img_slice /= (new_std)
new_img_slice = new_img_slice[:, np.newaxis, :, :]
##transform to tensor
new_image_tensor = torch.from_numpy(new_img_slice).float()
return new_image_tensor
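    # Usage sketch for transform2tensor (assumption: `cPader` is the pad-to-multiple-of-16
    # helper used by predict.py; an identity function stands in here as a placeholder):
    #   img_stack = np.random.rand(10, 224, 224)                                  # N*H*W
    #   tensor = CARDIAC_Predict_DATASET.transform2tensor(lambda x: x, img_stack)
    #   # tensor.shape == torch.Size([10, 1, 224, 224])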
if __name__ == '__main__':
import torch
import matplotlib.pyplot as plt
root_dir = '/vol/medic01/users/cc215/data/Carlo_Pathology/'
image_format_name = 'LVOT/LVOT_img_{}.nii.gz'
readable_frames = ['ED', 'ES']
n_classes = 2
    testset = CARDIAC_Predict_DATASET(root_dir=root_dir, if_resample=True,
                                      image_format_name=image_format_name,
                                      readable_frames=readable_frames,
                                      )
torch.cuda.set_device(0)
# train_iter=itertools.iter(train_loader)
n = 1
fail_cases = []
for i in range(testset.get_size()):
data = testset.get_next_patient()
for frame in readable_frames:
try:
print(data['ED']['aft_resample_shape']) # 10*224*224*1
# print(data['ES']['aft_resample_shape']) # 10*224*224*1
# print(data['ED']['image'].shape) # 10*224*224*1
#
# plt.figure(figsize=(30,30))
if n == 1:
prev = data['ED']['image'][0, :, :]
plt.title(data['patient_id'])
plt.imshow(data['ED']['image'][0, :, :])
                    plt.colorbar()
                    plt.show()
            except Exception:
                fail_cases.append(str(data['patient_id']) + frame)
continue
n += 1
from torch.nn import init
def weights_init_normal(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='normal'):
#print('initialization method [%s]' % init_type)
if init_type == 'normal':
net.apply(weights_init_normal)
elif init_type == 'xavier':
net.apply(weights_init_xavier)
elif init_type == 'kaiming':
net.apply(weights_init_kaiming)
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
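# Usage sketch for init_weights (the network below is a placeholder; any nn.Module works):
#   import torch.nn as nn
#   net = nn.Sequential(nn.Conv2d(1, 16, 3), nn.BatchNorm2d(16))
#   init_weights(net, init_type='kaiming')   # applies weights_init_kaiming to every sub-module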
\ No newline at end of file
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.optim import lr_scheduler
import math
def get_scheduler(optimizer, lr_policy,lr_decay_iters=5,epoch_count=None,niter=None,niter_decay=None):
print('lr_policy = [{}]'.format(lr_policy))
if lr_policy == 'lambda':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.5)
elif lr_policy == 'step2':
scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1)
elif lr_policy == 'plateau':
        print('scheduler: plateau')
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.01, patience=5)
elif lr_policy == 'plateau2':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif lr_policy == 'step_warmstart':
def lambda_rule(epoch):
#print(epoch)
if epoch < 5:
lr_l = 0.1
elif 5 <= epoch < 100:
lr_l = 1
elif 100 <= epoch < 200:
lr_l = 0.1
elif 200 <= epoch:
lr_l = 0.01
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif lr_policy == 'step_warmstart2':
def lambda_rule(epoch):
#print(epoch)
if epoch < 5:
lr_l = 0.1
elif 5 <= epoch < 50:
lr_l = 1
elif 50 <= epoch < 100:
lr_l = 0.1
elif 100 <= epoch:
lr_l = 0.01
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    else:
        raise NotImplementedError('learning rate policy [{}] is not implemented'.format(lr_policy))
return scheduler
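# Usage sketch for get_scheduler (optimizer/model and the training step are placeholders):
#   optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
#   scheduler = get_scheduler(optimizer, lr_policy='step_warmstart')
#   for epoch in range(300):
#       train_one_epoch(...)   # hypothetical training step
#       scheduler.step()       # lr follows base_lr * {0.1, 1, 0.1, 0.01} over the warm-start stages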
def spatial_pyramid_pool(previous_conv, batch_size, previous_conv_size, out_bin_sizes):
    '''
    ref: Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
    :param previous_conv: feature map tensor from the previous convolution layer (N*C*H*W)
    :param batch_size: number of images in the batch
    :param previous_conv_size: [height, width] of the previous convolution layer's feature map
    :param out_bin_sizes: list of expected output bin sizes for the max pooling levels
    :return: a tensor of shape [N x n], the concatenation of the multi-level pooled features
    '''
# print(previous_conv.size())
for i in range(0, len(out_bin_sizes)):
print(previous_conv_size)
#assert previous_conv_size[0] % out_bin_sizes[i]==0, 'please make sure feature size can be devided by bins'
h_wid = int(math.ceil(previous_conv_size[0] / out_bin_sizes[i]))
w_wid = int(math.ceil(previous_conv_size[1] / out_bin_sizes[i]))
# h_stride = int(math.floor(previous_conv_size[0] / out_bin_sizes[i]))
# w_stride = int(math.floor(previous_conv_size[1] / out_bin_sizes[i]))
h_pad = (h_wid * out_bin_sizes[i] - previous_conv_size[0] + 1) // 2
w_pad = (w_wid * out_bin_sizes[i] - previous_conv_size[1] + 1) // 2
maxpool = nn.MaxPool2d(kernel_size=(h_wid, w_wid), stride=(h_wid, w_wid),padding=(h_pad,w_pad))
x = maxpool(previous_conv)
if (i == 0):
spp = x.view(batch_size, -1)
#print("spp size:",spp.size())
else:
# print("size:",spp.size())
spp = torch.cat((spp, x.view(batch_size, -1)), dim=1)
# print("spp size:",spp.size())
return spp
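# Usage sketch for spatial_pyramid_pool: pool a (N, C, 32, 32) feature map into
# 1x1, 2x2 and 4x4 bins, giving a flattened output of shape (N, C * (1 + 4 + 16)):
#   feats = torch.randn(4, 64, 32, 32)
#   spp = spatial_pyramid_pool(feats, feats.size(0), [32, 32], out_bin_sizes=[1, 2, 4])
#   # spp.shape == torch.Size([4, 1344])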
'''
https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
'''
class GradientReversalFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, Lambda):
        ctx.Lambda = Lambda
        # identity in the forward pass
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # multiply the incoming gradient by -Lambda; Lambda itself receives no gradient
        return ctx.Lambda * grad_output.neg(), None


class GradientReversalLayer(nn.Module):
    def __init__(self, Lambda, use_cuda=False):
        super(GradientReversalLayer, self).__init__()
        self.Lambda = Lambda
        if use_cuda:
            self.cuda()

    def forward(self, input):
        return GradientReversalFunction.apply(input, self.Lambda)
def change_lambda(self, Lambda):
self.Lambda = Lambda
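# Usage sketch for the gradient reversal layer in a domain-adversarial setup
# (the feature extractor and domain classifier are placeholders):
#   grl = GradientReversalLayer(Lambda=1.0)
#   domain_logits = domain_classifier(grl(features))
#   # the forward pass is the identity; gradients flowing back through `grl` are
#   # negated and scaled by Lambda, pushing the features to confuse the domain classifier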
def calc_gradient_penalty(netD, lamda,real_data, fake_data,gpu=0):
from torch import autograd
# print ("real_data: ", real_data.size())
batch_size=real_data.size(0)
alpha = torch.rand(batch_size, 1)
alpha = alpha.expand(batch_size, int(real_data.nelement()/batch_size)).contiguous().view(batch_size, real_data.size(1), real_data.size(2), real_data.size(3))
if gpu is not None:
alpha = alpha.cuda(gpu)
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
if gpu is not None:
interpolates = interpolates.cuda(gpu)
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = netD(interpolates)
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu) if gpu is not None else torch.ones(
disc_interpolates.size()),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lamda
return gradient_penalty
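# Usage sketch for calc_gradient_penalty in a WGAN-GP style critic update
# (netD, real and fake are placeholders):
#   gp = calc_gradient_penalty(netD, lamda=10.0, real_data=real, fake_data=fake, gpu=0)
#   d_loss = netD(fake.detach()).mean() - netD(real).mean() + gp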
def encode(label_map,n_classes):
    '''
    convert an integer label map tensor into a one-hot label map
    :param label_map: batch_size*target_h*target_w (integer class ids)
    :param n_classes: number of classes
    :return: one-hot label, batch_size*n_classes*target_h*target_w
    '''
# create one-hot vector for label map
label_map=label_map[:,None,:,:]
size = label_map.size()
print (size)
oneHot_size = (size[0],n_classes, size[2], size[3])
input_label = torch.zeros(torch.Size(oneHot_size)).float().cuda()
label_map=Variable(label_map)
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
return input_label
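# Usage sketch for encode (runs on GPU, since the one-hot buffer is allocated with .cuda()):
#   labels = torch.randint(0, 4, (2, 224, 224))   # batch_size*H*W, integer class ids in [0, n_classes)
#   onehot = encode(labels, n_classes=4)          # 2*4*224*224 float tensor on GPU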
def gram_matrix_2D(y):
'''
give torch 4d tensor, calculate Gram Matrix
:param y:
:return:
'''
(b, ch, h, w) = y.size()
features = y.view(b, ch, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (ch * h * w)
return gram
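# Usage sketch for gram_matrix_2D, e.g. as the statistic behind a style loss:
#   feats = torch.randn(2, 64, 28, 28)
#   G = gram_matrix_2D(feats)   # shape (2, 64, 64), normalised by ch * h * w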
def adjust_learning_rate(optimizer, lr):
"""Sets the learning rate to a fixed number"""
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def cal_cls_acc(pred,gt):
    '''
    :param pred: network output tensor, N*n_classes
    :param gt: ground-truth label ids, N
    :return: (number of correct predictions, batch size)
    '''
pred_class = pred.data.max(1)[1].cpu()
sum = gt.cpu().eq(pred_class).sum()
count = gt.size(0)
return sum, count
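# Usage sketch for cal_cls_acc: accumulate classification accuracy over a batch
# (logits and targets are placeholders):
#   correct, total = cal_cls_acc(logits, targets)   # logits: N*n_classes, targets: N
#   acc = float(correct) / float(total)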
def cal_statistic_loss(featuremaps):
batch_size = featuremaps[0].size(0)
# print('bn:', batch_size)
style_loss = 0.
for f in featuremaps:
level_1 = f # batch*feature_stats
loss = 0.
std_erro=0.
mean_erro=0.
for i in range(batch_size):
instance_f=f[i] #f_n*h*w
instance_f_view = instance_f.view(1,instance_f.size(0), instance_f.size(1)*instance_f.size(2))
target_std = torch.std(instance_f_view, 2, unbiased=False).view(-1)
target_mean = torch.mean(instance_f_view, 2, keepdim=False).view(-1)