Commit c74d1f3a authored by cc215's avatar cc215 💬
Browse files

update readme

parent 295b870b
......@@ -31,7 +31,32 @@ author: Chen Chen (cc215@ic.ac.uk)
- LVOT segmentation
- `python predict.py --sequence LVOT`
Results will be saved under `test_results` by default
Results will be saved under `test_results` by default.
## Model performance across intra-domain and public out-of-domain data
- For cardiac segmentation on short-axis images, we provide the model trained on UKBB (~4000 training subjects), with extensive data augmentations. This model achieves extraordinary performance across various unseen challenge datasets, [ACDC](), [M&Ms](). We report the segmentation performance in terms of Dice score.
| | UKBB test (600) | | | ACDC (100) | | | M&Ms (150) | | |
|- |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |:-: |
| configurations | LV | MYO | RV | LV | MYO | RV | LV | MYO | RV |
| bs=1, roi size = 256, z_score | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| bs=1, roi size = 192, z_score | 0.9371 | 0.8775 | 0.8962 | 0.8891 | 0.7981 | 0.8103 | 0.8725 | 0.7716 | 0.7954 |
| bs=-1, roi size = 192, z_score | 0.9139 | 0.8388 | 0.8779 | 0.8790 | 0.7818 | 0.8069 | 0.8679 | 0.7675 | 0.7926 |
The model params are stored in `./checkpoints/Unet_LVSA_trained_from_UKBB.pkl`
For optimal performance at deployment, please set it to instance normalization (--batch_size 1) and crop images to (--roi_size 256).
e.g.
- If we want to predict test images from UKBB, we can simply run
`python predict.py --sequence LVSA \
--model_path './checkpoints/Unet_LVSA_trained_from_UKBB.pkl' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' \
--save_folder_path "./result/predict/Unet_LVSA_trained_from_UKBB/UKBB_test" \
--save_name_format 'pred_{}.nii.gz' \
--roi_size 256 \
--batch_size 1 \
--z_score
`
- Predictions for different frames (e.g. ED,ES) from a list of patients {pid} will be saved under the project dir `./result/predict/Unet_LVSA_trained_from_UKBB/UKBB_test/LVSA/{pid}/pred_{frame}.nii.gz`
## Customize for your needs
- python predict.py --sequence {sequence_name} --root_dir {root_dir} --image_format {image_format_name} --save_folder_path {dir to save the result} --save_name_format {name of nifti file}
- In this demo, our test image paths are structured like:
......
......@@ -6,13 +6,13 @@
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv2",
"data_aug_policy" :"UKBB_advancedv3",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
......@@ -22,7 +22,7 @@
},
"segmentation_model": {
"network_type": "UNet_64",
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
......@@ -31,7 +31,7 @@
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.2
"decoder_dropout": 0.1
},
"adversarial_augmentation":
......
......@@ -22,11 +22,11 @@
},
"segmentation_model": {
"network_type": "UNet_64",
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "SGD",
"optimizer_name": "adam",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
......
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advanced_z_score",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 1000,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv3",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[256,256,1],
"label_size":[256,256],
"pad_size": [256,256,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "IN_UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"transformation_type":"composite"
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
\ No newline at end of file
......@@ -241,7 +241,7 @@ class MyNormalizeMedicPercentile(object):
def __init__(self,
min_val=0.0,
max_val=1.0,
perc_threshold=(1.0, 98.0),
perc_threshold=(1.0, 99.0),
norm_flag=True):
"""
Normalize a tensor between a min and max value
......@@ -310,7 +310,7 @@ class MyNormalizeMedic(object):
# scale the intensity values to be unit norm
std_val = np.std(_input.numpy().flatten())
if np.abs(std_val)<1e-6:
if np.abs(std_val)<1e-20:
_input = _input
else:_input = _input.div(float(std_val))
......
......@@ -194,14 +194,13 @@ class CardiacUKBBDataset(BaseSegDataset):
label_path = os.path.join(root_dir, *[str(pid), self.label_format_name.format(frame=frame)])
if os.path.exists(img_path) and os.path.exists(label_path):
ndarray = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
if self.ignore_black_slices:
nlabel = sitk.GetArrayFromImage(sitk.ReadImage(label_path))
num_slices = ndarray.shape[0]
for cnt in range(num_slices):
if self.ignore_black_slices:
label=nlabel[cnt,:,:]
if np.abs(np.sum(label)-0)<1e-6:
img_slice_data=ndarray[cnt,:,:]
img_slice_data -= np.mean(img_slice_data)
if np.abs(np.sum(img_slice_data)-0)<1e-6:
## ignore black images
continue
index2img_path_dict[cur_ind] = img_path
index2label_path_dict[cur_ind] = label_path
......@@ -322,7 +321,8 @@ if __name__ == '__main__':
pad_size = (256, 256, 1)
crop_size = (192, 192, 1)
tr = Transformations(data_aug_policy_name='UKBB_advanced_z_score', pad_size=pad_size, crop_size=crop_size).get_transformation()
dataset = CardiacUKBBDataset(debug=True,transform=tr['train'], no_aug_transform=tr['validate'],formalized_label_dict={0: 'BG', 1: 'LV',2:'MYO',3:'RV'})
dataset = CardiacUKBBDataset(debug=True,transform=tr['train'], if_resample=True,
no_aug_transform=tr['validate'],formalized_label_dict={0: 'BG', 1: 'LV',2:'MYO',3:'RV'})
train_loader = DataLoader(dataset=dataset, num_workers=0, batch_size=1, shuffle=True, drop_last=True)
for i, item in enumerate(train_loader):
......
......@@ -131,7 +131,7 @@ class CARDIAC_Predict_DATASET(data.Dataset):
def get_name(self):
print('dataset loader')
def transform2tensor(self,cPader, img_slice):
def transform2tensor(self,cPader, img_slice,eps=1e-20):
'''
transform npy data to torch tensor
......@@ -146,19 +146,35 @@ class CARDIAC_Predict_DATASET(data.Dataset):
## normalize data
new_img_slice = new_img_slice * 1.0 ##N*H*W
new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
if self.if_z_score:
print ('z score')
new_img_slice -= new_input_mean
new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
if new_img_slice.shape[0]>1:
new_std[abs(new_std-0.)<1e-6]=1
new_std[abs(new_std-0.)<eps]=1
else:
if abs(new_std-0)<1e-3: new_std=1
new_img_slice /= (new_std)
if abs(new_std)<eps: new_std=1
new_img_slice /= new_std
else:
print ('0-1 rescaling')
min_val = np.min(new_img_slice,axis=(1, 2), keepdims=True)
max_val = np.max(new_img_slice,axis=(1, 2), keepdims=True)
new_img_slice =(new_img_slice-min_val)/(max_val-min_val+1e-10)
if new_img_slice.shape[0]>1:
min_val, max_val = np.percentile(new_img_slice, (1,99))
new_img_slice[new_img_slice>max_val]=max_val
new_img_slice[new_img_slice<min_val]=min_val
new_img_slice =(new_img_slice-min_val)/(max_val-min_val+eps)
else:
for i in range(new_img_slice.shape[0]):
a_slice = new_img_slice[i]
min_val, max_val = np.percentile(a_slice, (1,99))
a_slice[a_slice>max_val]=max_val
a_slice[a_slice<min_val]=min_val
a_slice =(a_slice-min_val)/(max_val-min_val+eps)
new_img_slice[i] = a_slice
new_img_slice = new_img_slice[:, np.newaxis, :, :]
......
......@@ -115,16 +115,16 @@ class Transformations:
MyRandomFlip(h=config['flip_flag'][0], v=config['flip_flag'][1],
p=config['flip_flag'][2]),
## intensity transform
RandomBrightnessFluctuation(p=config['intensity_prob'],flag=[True, False]),
MyElasticTransform(is_labelmap=[False, True], p_thresh=config['elastic_prob']),
ts.RandomAffine(rotation_range=config['rotate_val'],
translation_range=config['shift_val'],
shear_range=config['shear_val'],
zoom_range=config['scale_val'], interp=('bilinear', 'nearest')),
ts.RandomCrop(size=self.crop_size),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(0,100)),
RandomBrightnessFluctuation(p=config['intensity_prob'],flag=[True, False]),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(1,99)),
ts.TypeCast(['float', 'long'])
])
......@@ -133,7 +133,7 @@ class Transformations:
ts.ChannelsFirst(),
ts.TypeCast(['float', 'float']),
MySpecialCrop(size=self.crop_size, crop_type=0),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(0,100)),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(1,99)),
ts.TypeCast(['float', 'long'])
])
aug_valid_transform = ts.Compose([ts.PadNumpy(size=self.pad_size),
......@@ -153,7 +153,7 @@ class Transformations:
zoom_range=config['scale_val'], interp=('bilinear', 'nearest')),
MySpecialCrop(size=self.crop_size,crop_type=0),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(0,100)),
MyNormalizeMedicPercentile(norm_flag=(True, False),min_val=0,max_val=1,perc_threshold=(1,99)),
ts.TypeCast(['float', 'long'])
])
......
## olds model
## UKBB
roi_size=256
method_name='baseline_256'
model_path ='/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_trained_from_UKBB.pkl'
python predict.py --sequence LVSA \
--model_path $model_path \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' \
--roi_size $roi_size \
--batch_size 256 \
--save_folder_path "/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/${method_name}/UKBB_test" \
--save_name_format 'pred_{}.nii.gz' --z_score
python predict.py --sequence LVSA \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_trained_from_UKBB.pkl' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' \
--roi_size 256 \
--batch_size 1 \
--save_folder_path "/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_192_IN/UKBB_test" \
--save_name_format 'pred_{}.nii.gz' --z_score
# acdc
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_trained_from_UKBB.pkl' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 192 \
--batch_size -1 \
--z_score \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_BN/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
## mm
python predict.py --sequence LVSA \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_trained_from_UKBB.pkl' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--batch_size -1 --z_score \
--image_format 'sa_{}.nrrd' \
--roi_size 192 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_192/MM' \
--save_name_format 'pred_{}.nrrd'
## b
## baseline model
## UKBB
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size -1 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_SGD/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1
--image_format 'sa_{}.nii.gz' --roi_size 256 --batch_size -1 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_SGD/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1
# acdc
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' --root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
......@@ -40,9 +90,123 @@ python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Ca
## mm
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/composite_train_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' --batch_size -1 \
--image_format 'sa_{}.nrrd' --roi_size 192 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/composite_train_SGD/MM' --save_name_format 'pred_{}.nrrd'
--image_format 'sa_{}.nrrd' --roi_size 192 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/composite_train_SGD/MM' --save_name_format 'pred_{}.nrrd'
## baseline_Adam_z_score
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_z_score/best/checkpoints/UNet_64$SAX$_Segmentation.pth' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 256 --batch_size -1 --z_score --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score_256/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1 --z_score
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_z_score/best/checkpoints/UNet_64$SAX$_Segmentation.pth' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size -1 --z_score --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1
--image_format 'sa_{}.nii.gz' --roi_size -1 --batch_size -1 --z_score --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score_192/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1 --z_score
## baseline with IN_UNet_64
##UKBB
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_z_score_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nrrd' --roi_size 192 --batch_size 1 --z_score --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score_IN_UNet_64/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 0
## ACDC
python predict.py --sequence LVSA --model_path './result/baseline_Adam_z_score_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 192 --batch_size 1 --z_score \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score_IN_UNet_64/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
## MM
python predict.py --sequence LVSA --model_path './result/baseline_Adam_z_score_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--image_format 'sa_{}.nrrd' \
--roi_size -1 --batch_size 1 --z_score \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_z_score_IN_UNet_64/MM' \
--save_name_format 'pred_{}.nrrd'
## adv compose with IN_UNet_64 (no z score)
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/composite_train_Adam_finetune_aug_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projcts/Cardiac_Multi_View_Segmentation/result/predict/composite_train_Adam_finetune_aug_v3_IN_UNet_64/UKBB_test' \
--save_name_format 'pred_{}.nii.gz' --gpu 0
## ACDC
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/composite_train_Adam_finetune_aug_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/composite_train_Adam_finetune_aug_v3_IN_UNet_64/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
## MM
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/composite_train_Adam_finetune_aug_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--image_format 'sa_{}.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/composite_train_Adam_finetune_aug_v3_IN_UNet_64/MM' \
--save_name_format 'pred_{}.nrrd'
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projcts/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_v3_IN_UNet_64/UKBB_test' \
--save_name_format 'pred_{}.nii.gz' --gpu 1
## ACDC
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_v3_IN_UNet_64/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
## MM
python predict.py --sequence LVSA --model_arch 'IN_UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_v3_IN_UNet_64/best/checkpoints/IN_UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--image_format 'sa_{}.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_v3_IN_UNet_64/MM' \
--save_name_format 'pred_{}.nrrd'
## bias train SGD
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/bias_train_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projcts/Cardiac_Multi_View_Segmentation/result/predict/bias_train_SGD/UKBB_test' \
--save_name_format 'pred_{}.nii.gz' --gpu 1
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/bias_train_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--image_format 'sa_{}.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/bias_train_SGD/MM' \
--save_name_format 'pred_{}.nrrd'
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/bias_train_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 192 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/bias_train_SGD/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
## baseline v2
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_trained_from_UKBB.pkl' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 256 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_v2/UKBB_test' --save_name_format 'pred_{}.nii.gz' --z_score
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/checkpoints/Unet_LVSA_best.pkl' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 256 --batch_size -1 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_v2/UKBB_test' --save_name_format 'pred_{}.nii.gz' --z_score
## new model
......@@ -18,5 +18,5 @@ python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Ca
## bias_train_SGD
python predict.py --sequence LVSA --model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/bias_train_SGD/best/checkpoints/UNet_64$SAX$_Segmentation.pth' --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size -1 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/bias_train_SGD/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1
--image_format 'sa_{}.nii.gz' --roi_size 192 --batch_size 1 --save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/bias_train_SGD/UKBB_test' --save_name_format 'pred_{}.nii.gz' --gpu 1
......@@ -17,3 +17,10 @@ python train.py --json_config_path 'configs/composite_independent_train_SGD.json
python train.py --json_config_path 'configs/baseline_SGD_z_score.json' --log --gpu 1
python train.py --json_config_path 'configs/baseline_Adam_z_score.json' --log --gpu 0
## lr =0.00001, finetuning z_score
python train.py --json_config_path 'configs/baseline_Adam_z_score_IN_UNet_64.json' --log --gpu 0
## lr =0.00001, finetuning min max with 1-99 percentile
python train.py --json_config_path 'configs/composite_train_Adam_finetune_aug_v3_IN_UNet_64.json' --log --gpu 1 --adv_training
## lr =0.00001, finetuning z_score
python train.py --json_config_path 'configs/baseline_Adam_v3_IN_UNet_64.json' --log --gpu 0
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
......@@ -9,7 +9,7 @@ import torch.optim as optim
import gc
from model.init_weight import init_weights
from model.unet import UNet
from model.unet import UNet,INUnet
from model.model_utils import makeVariable
from common_utils.loss import cross_entropy_2D
from common_utils.metrics import runningScore
......@@ -43,7 +43,6 @@ class SegmentationModel(nn.Module):
self.decoder_dropout = decoder_dropout if isinstance(decoder_dropout,float) else None
self.model = self.get_network_from_model_library(self.network_type)
assert not self.model is None, 'cannot find the model given the specified name'
## print number of paramters
self.resume_path=resume_path
self.init_model(network_type)
......@@ -61,13 +60,20 @@ class SegmentationModel(nn.Module):
def get_network_from_model_library(self, network_type):
model = None
model_candidates = ['UNet_64','IN_UNet_64']
assert network_type in model_candidates, 'currently, we only support network types: {}, but found {}'.format(str(model_candidates),network_type)
if network_type =='UNet_64':
model = UNet(input_channel=self.in_channels, num_classes=self.num_classes, feature_scale=1,
norm=nn.BatchNorm2d,
dropout=self.decoder_dropout)
print ('init {}'.format(network_type))
elif network_type =='IN_UNet_64':
model = INUnet(input_channel=self.in_channels, num_classes=self.num_classes, feature_scale=1,
norm=nn.BatchNorm2d,
dropout=self.decoder_dropout)
print ('init {}'.format(network_type))
else:
print ('currently, we only support network types: [UNet_64], but found {}'.format(network_type))
raise NotImplementedError
return model
......@@ -76,15 +82,23 @@ class SegmentationModel(nn.Module):
resume_path= self.resume_path
init_weights(self.model, init_type='kaiming')
print('init ', network_type)
# print('init ', network_type)
if not resume_path is None:
if not resume_path == '':
assert os.path.exists(resume_path), 'path: {} must exist'.format(resume_path)
if '.pkl' in resume_path:
self.model.load_state_dict(torch.load(resume_path)['model_state'], strict=True)
try:
self.model.load_state_dict(torch.load(resume_path)['model_state'], strict=True)
except:
print ('fail to load, loose the constraint')
self.model.load_state_dict(torch.load(resume_path)['model_state'], strict=False)
print ('load params from ',resume_path)
elif '.pth' in resume_path:
self.model.load_state_dict(torch.load(resume_path), strict=True)
try:
self.model.load_state_dict(torch.load(resume_path), strict=True)