Commit 73fa4709 authored by cc215's avatar cc215 💬
Browse files

fix bug and add mc dropout prediction

parent d8808541
......@@ -10,13 +10,13 @@ import gc
from model.init_weight import init_weights
from model.unet import UNet
from model.model_utils import makeVariable
from common_utils.loss import cross_entropy_2D
from common_utils.metrics import runningScore
from common_utils.save import save_list_results_as_png
class SegmentationModel(nn.Module):
def __init__(self, network_type, in_channels=1, num_classes=2,
encoder_dropout=None,
decoder_dropout=None, use_gpu=True, lr=0.001, resume_path=None,
):
......@@ -39,8 +39,7 @@ class SegmentationModel(nn.Module):
self.num_classes = num_classes
self.lr = lr
self.in_channels = in_channels
self.encoder_dropout = encoder_dropout if isinstance(encoder_dropout,float) else None
self.decoder_dropout = decoder_dropout if isinstance(encoder_dropout,float) else None
self.decoder_dropout = decoder_dropout if isinstance(decoder_dropout,float) else None
self.model = self.get_network_from_model_library(self.network_type)
assert not self.model is None, 'cannot find the model given the specified name'
......@@ -137,8 +136,43 @@ class SegmentationModel(nn.Module):
with torch.no_grad():
output = self.model.forward(input)
probs = torch.softmax(output,dim=1)
torch.cuda.empty_cache()
return probs
def MC_predict(self,input, n_times=5,decoder_dropout=0.1, disable_bn=False):
assert n_times>=1
## use MC dropout to get ensembled prediction
## enable dropout
if self.decoder_dropout is None or self.decoder_dropout!=decoder_dropout:
self.decoder_dropout = decoder_dropout
self.model = self.get_network_from_model_library(self.network_type)
self.init_model(self.network_type)
if self.use_gpu:
self.model.cuda()
self.model.train()
## fix batch norm
if not disable_bn:
for module in self.model.modules():
# print(module)
if isinstance(module, nn.BatchNorm2d):
if hasattr(module, 'weight'):
module.weight.requires_grad_(False)
if hasattr(module, 'bias'):
module.bias.requires_grad_(False)
module.eval()
# mc sampling
probs_list=[]
for i in range(n_times):
image = input.detach()
output = self.model.forward(image)
probs_i= torch.softmax(output,dim=1)
probs_list.append(probs_i)
torch.cuda.empty_cache()
mean_probs = sum(probs_list)/len(probs_list)
return mean_probs,probs_list
def evaluate(self, input, targets_npy):
'''
evaluate the model performance
......
......@@ -75,11 +75,11 @@ def predict(sequence_name, root_dir, image_format_name,
else:
model.load_state_dict(torch.load(model_path))
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
device = torch.device("cuda:{}".format(args.gpu) if (torch.cuda.is_available()) else "cpu")
if device.type == 'cuda':
model.cuda()
torch.cuda.set_device(0)
torch.cuda.set_device(args.gpu)
model.eval()
testset = CARDIAC_Predict_DATASET(root_dir, image_format_name=image_format_name,
......
......@@ -16,7 +16,7 @@ from common_utils.basic_operations import transform2tensor
def predict(model_path, input_image_path,
save_pred_path, batch_size=4, crop_size=256,if_resample=True,if_z_score=False):
save_pred_path=None, batch_size=4, crop_size=256,if_resample=True,if_z_score=False, gpu_id=0, mc_dropout=0,decoder_dropout_rate=0.1):
'''
:param model_path: path to the saved model parameters
......@@ -26,21 +26,25 @@ def predict(model_path, input_image_path,
:param crop_size: the size for image ROI cropping, need to be divided by 16.
:param if_resample:if resamping image to a uniform pixel-spacing before predition
:param if_z_score: if rescale images to have zero mean and std deviation, by default, we rescale it to 0-1.
:return:
:return:prediction: numpy array in 3D format NHW
'''
print('<------Loading model-------->')
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
device = torch.device("cuda:{}".format(gpu_id) if (torch.cuda.is_available()) else "cpu")
if device.type == 'cuda':
use_gpu =True
print ('use gpu')
else:
use_gpu=False
num_classes = 4
model = SegmentationModel(network_type='UNet_64',in_channels=1, num_classes=num_classes,use_gpu=use_gpu,resume_path=model_path)
decoder_dropout =None if mc_dropout ==0 else decoder_dropout_rate
model = SegmentationModel(network_type='UNet_64',in_channels=1, num_classes=num_classes,use_gpu=use_gpu, decoder_dropout=decoder_dropout,
resume_path=model_path)
model.eval()
print('<------Loading data-------->')
### read image ##
temp_image = sitk.ReadImage(input_image_path)
original_shape=sitk.GetArrayFromImage(temp_image).shape
original_im_arr=sitk.GetArrayFromImage(temp_image)
original_shape=original_im_arr.shape
temp_image = sitk.Cast(sitk.RescaleIntensity(temp_image), sitk.sitkFloat32)
origin_spacing = temp_image.GetSpacing()
......@@ -53,8 +57,8 @@ def predict(model_path, input_image_path,
new_image = temp_image
## new image shape
aft_resample_shape = sitk.GetArrayFromImage(new_image).shape
npy_data = sitk.GetArrayFromImage(new_image).astype(float)
aft_resample_shape = npy_data.shape
soft_prediction = np.zeros((aft_resample_shape[0], num_classes, aft_resample_shape[1], aft_resample_shape[2])) ##N*4*H*W
print('<------Batchwise prediction-------->')
......@@ -79,7 +83,9 @@ def predict(model_path, input_image_path,
input = Variable(input_tensor)
### predict every batch
batch_output = model(input)
if mc_dropout>0:
batch_output, batch_output_list= model.MC_predict(input,n_times=mc_dropout,decoder_dropout=decoder_dropout_rate)
else:batch_output = model.predict(input)
batch_output_npy = batch_output.data.cpu().numpy() ##batch_size*n_cls*H'*W'
temp = reversecroppad(batch_output_npy.squeeze()) ##batch_size*n_cls*H*W
soft_prediction[i * batch_size:(i + 1) * batch_size] = temp
......@@ -107,11 +113,13 @@ def predict(model_path, input_image_path,
if len(predict_result.shape) < len(original_shape):
predict_result = np.reshape(predict_result,original_shape)
print('Saving segmentation to {}'.format(save_pred_path))
post_im = sitk.GetImageFromArray(predict_result)
ref_im = temp_image
post_im.CopyInformation(ref_im)
sitk.WriteImage(post_im, save_pred_path, True)
if save_pred_path is not None:
print('Saving segmentation to {}'.format(save_pred_path))
post_im = sitk.GetImageFromArray(predict_result)
ref_im = temp_image
post_im.CopyInformation(ref_im)
sitk.WriteImage(post_im, save_pred_path, True)
return model,original_im_arr,predict_result
......@@ -126,10 +134,15 @@ if __name__ == '__main__':
parser.add_argument('-c','--crop_size', type=int, default=192, help="crop images to save memory")
parser.add_argument('-z','--z_score', action="store_true", default=False,help="normalize the images to have zero mean and std deviation.")
parser.add_argument('-g','--gpu', default=0,help='select GPU by masking shell environment variable CUDA_VISIBLE_DEVICES')
parser.add_argument('-d','--mc_dropout', default=0,help='if >0, it will apply MC dropout for d times, by default the dropout rate=0.1')
args = parser.parse_args()
### GPU CONFIG
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
gpu_id = args.gpu
predict(args.model_path, args.input_image_path,
args.output_segmentation_path, batch_size=args.batch_size, crop_size=args.crop_size,if_resample=True,if_z_score=args.z_score)
\ No newline at end of file
args.output_segmentation_path, batch_size=args.batch_size,
crop_size=args.crop_size,if_resample=True,
if_z_score=args.z_score,gpu_id=gpu_id,
mc_dropout=int(args.mc_dropout))
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment