Commit 36f6a2d5 authored by cc215

update adv chain information

parent 84efc713
@@ -3,158 +3,160 @@
author: Chen Chen (cc215@ic.ac.uk)
## Features
- LV/MYO/RV segmentation on short axis view (LVSA):
  - LV: left ventricle cavity, MYO: myocardium of left ventricle, RV: right ventricle
- MYO segmentation on the cardiac long axis views
  - 4CH: 4-chamber view
  - VLA: Vertical Long Axis
  - LVOT: Left Ventricular Outflow Tract
## Environment
- Python 3.5
- PyTorch 1.6 (*please upgrade your PyTorch to the latest version; otherwise it may raise errors when loading weights from the saved checkpoints.*)
- CUDA 10.0
## Dependencies
- see `requirements.txt`
- install them by running `pip install -r requirements.txt`
## Test segmentation
- For a quick start, we provide test samples under the `test_data` subdirectory
- LVSA segmentation
  - `python predict.py --sequence LVSA`
- VLA segmentation
  - `python predict.py --sequence VLA`
- 4CH segmentation
  - `python predict.py --sequence 4CH`
- LVOT segmentation
  - `python predict.py --sequence LVOT`
Results will be saved under `test_results` by default.
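As a quick sanity check after running these commands, you can load a predicted mask with SimpleITK and inspect it. This is a minimal sketch; the exact output filename and the label convention (0=background, 1=LV, 2=MYO, 3=RV) are assumptions based on the 4-class setup, not confirmed by the repository:

```python
import SimpleITK as sitk
import numpy as np

# Hypothetical output path; adjust to the file predict.py actually wrote for you.
pred = sitk.ReadImage("test_results/LVSA/patientX01/seg_sa_ED.nii.gz")
arr = sitk.GetArrayFromImage(pred)  # (slices, H, W) integer label map

print("shape:", arr.shape, "spacing:", pred.GetSpacing())
# Assumed label convention: 0=background, 1=LV, 2=MYO, 3=RV (4 classes in the config).
for label, name in enumerate(["background", "LV", "MYO", "RV"]):
    print(f"{name}: {np.sum(arr == label)} voxels")
```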
## Model performance across intra-domain and public out-of-domain data
- For cardiac segmentation on short-axis images, we provide a model trained on UKBB (~4000 training subjects) with extensive data augmentation. This model achieves extraordinary performance across various unseen challenge datasets: [ACDC](https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html) and [M&Ms](https://www.ub.edu/mnms/). We report the segmentation performance on these datasets in terms of the Dice score.

| | UKBB test (600 subjects, 1200 frames: ED+ES) | | | ACDC (100 subjects, 200 frames: ED+ES) | | | M&Ms (150 subjects, 300 frames: ED+ES) | | |
| --- | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| configurations | LV | MYO | RV | LV | MYO | RV | LV | MYO | RV |
| batch_size = 1, roi size = 256, z_score | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| batch_size = 1, roi size = 192, z_score | 0.9371 | 0.8775 | 0.8962 | 0.8891 | 0.7981 | 0.8103 | 0.8725 | 0.7716 | 0.7954 |
| batch_size = -1 (use the whole volume as one batch), roi size = 192, z_score | 0.9139 | 0.8388 | 0.8779 | 0.8790 | 0.7818 | 0.8069 | 0.8679 | 0.7675 | 0.7926 |
The model parameters are stored in `./checkpoints/Unet_LVSA_trained_from_UKBB.pkl`.
For optimal performance at deployment, use instance normalization (`--batch_size 1`) and crop images to 256x256 (`--roi_size 256`), e.g.:
- If we want to predict test images from UKBB, we can simply run the following command:

```
python predict.py --sequence LVSA \
    --model_path './checkpoints/Unet_LVSA_trained_from_UKBB.pkl' \
    --root_dir '/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/test' \
    --image_format 'sa_{}.nii.gz' \
    --save_folder_path "./result/predict/Unet_LVSA_trained_from_UKBB/UKBB_test" \
    --save_name_format 'pred_{}.nii.gz' \
    --roi_size 256 \
    --batch_size 1 \
    --z_score
```
- Predictions for different frames (e.g. ED, ES) of each patient `{pid}` will be saved under the project directory as `./result/predict/Unet_LVSA_trained_from_UKBB/UKBB_test/LVSA/{pid}/pred_{frame}.nii.gz` (a Dice-evaluation sketch follows below).
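To reproduce Dice numbers like those reported above, you can compare a prediction against a reference label map. A minimal sketch (file paths and label ids are illustrative assumptions):

```python
import numpy as np
import SimpleITK as sitk

def dice(pred: np.ndarray, gt: np.ndarray, label: int) -> float:
    """Dice overlap for one integer label."""
    p, g = (pred == label), (gt == label)
    denom = p.sum() + g.sum()
    return 2.0 * np.logical_and(p, g).sum() / denom if denom > 0 else np.nan

pred = sitk.GetArrayFromImage(sitk.ReadImage("pred_ED.nii.gz"))    # model output
gt = sitk.GetArrayFromImage(sitk.ReadImage("label_sa_ED.nii.gz"))  # reference segmentation
# Assumed label ids: 1=LV, 2=MYO, 3=RV.
for label, name in zip([1, 2, 3], ["LV", "MYO", "RV"]):
    print(name, dice(pred, gt, label))
```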
## Customize for your needs
- `python predict.py --sequence {sequence_name} --root_dir {root_dir} --image_format {image_format_name} --save_folder_path {dir_to_save_the_result} --save_name_format {name_of_nifti_file}`
- In this demo, our test image paths are structured like:
  - test_data
    - patientX01
      - LVSA
        - lvsa_img_ED.nii.gz
        - lvsa_img_ES.nii.gz
      - 4CH
        - 4CH_img_ED.nii.gz
      - LVOT
        - LVOT_img_ED.nii.gz
      - VLA
        - VLA_img_ED.nii.gz
- e.g.
  - predict LVSA: `python predict.py --sequence LVSA --root_dir 'test_data/' --image_format 'LVSA/LVSA_img_{}.nii.gz' --save_folder_path 'test_results/' --save_name_format 'seg_sa_{}.nii.gz'`
  - please read `predict.py` for advanced settings (batch size, image resampling); a sketch of how the path template resolves follows below.
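For reference, `predict.py` presumably builds each input path from `--root_dir`, the patient id, and `--image_format`, filling the `{}` placeholder with the frame name. A minimal illustrative sketch (the helper below is not the repository's actual code):

```python
import os

def resolve_image_path(root_dir: str, patient_id: str, image_format: str, frame: str) -> str:
    # The '{}' in --image_format is filled with the frame name, e.g. 'ED' or 'ES'.
    return os.path.join(root_dir, patient_id, image_format.format(frame))

print(resolve_image_path("test_data/", "patientX01", "LVSA/LVSA_img_{}.nii.gz", "ED"))
# -> test_data/patientX01/LVSA/LVSA_img_ED.nii.gz
```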
## Training
- run `pip install -r requirements.txt`
  - pandas==0.22.0
  - matplotlib==2.2.2
  - nipy==0.4.2
  - MedPy==0.3.0
  - scipy==1.0.1
  - tqdm==4.23.0
  - numpy==1.14.2
  - SimpleITK==1.1.0
  - scikit-image
  - tensorboardX==1.4
- and then install an adapted version of torchsample via `pip install git+https://github.com/ozan-oktay/torchsample/`
- test the environment:
  - run `python train.py`
- open the config file `configs/basic_opt.json` and change the dataset configuration (a quick sanity check for these settings is sketched below):
  - "train_dir": training dataset directory
  - "validate_dir": validation dataset directory
  - "readable_frames": list of cardiac frames to be trained on, e.g. ["ED","ES"]
  - "image_format_name": the file name of image data, e.g. "sa_{frame}.nii.gz" for loading sa_ED.nii.gz and sa_ES.nii.gz
  - "label_format_name": the file name of label data, e.g. "label_sa_{frame}.nii.gz" for loading label_sa_ED.nii.gz and label_sa_ES.nii.gz
- run `python train.py --json_config_path {config_file_path}`
- e.g. `python train.py --json_config_path configs/basic_opt.json`
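Before launching a long training run, it can help to verify that the configured templates actually match files on disk. A small illustrative check, assuming `basic_opt.json` follows the same schema as the Composite config reproduced later in this commit (a top-level "data" section) and that the training directory contains one subdirectory per subject:

```python
import json
import os

with open("configs/basic_opt.json") as f:
    cfg = json.load(f)["data"]

for subject in sorted(os.listdir(cfg["train_dir"]))[:5]:  # spot-check a few subjects
    for frame in cfg["readable_frames"]:
        for template in (cfg["image_format_name"], cfg["label_format_name"]):
            path = os.path.join(cfg["train_dir"], subject, template.format(frame=frame))
            if not os.path.exists(path):
                print("missing:", path)
```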
## Finetuning
- open the config file (`configs/basic_opt.json`), change the model resume path, and adjust the learning rate to 0.0001 or 0.00001 (a scripted version of this edit is sketched below):
  - "resume_path": "./checkpoints/Unet_LVSA_trained_from_UKBB.pkl" (model trained on the UKBB dataset)
  - "lr": 0.0001
## Output
- By default, all models and internal outputs will be stored under `result`
- The best model can be found under this dir, e.g. `result/best/checkpoints/UNet_64$SAX$_Segmentation.pth`
## Advanced
- You can change the data augmentation strategy by changing the value of "data_aug_policy" in the config file.
- For details about the data augmentation strategies, please refer to `dataset_loader/mytransform.py`.
## Model update (2021.3):
- A model trained on UKBB data (SAX slices) with adversarial data augmentation (adversarial noise, adversarial bias field, adversarial morphological deformation, and adversarial affine transformation) is available.
- This model is expected to have improved robustness, especially for right ventricle segmentation on cross-domain data. See the test results below on the intra-domain test set (UKBB) and the *unseen* cross-domain sets ACDC and M&Ms.

| Testing config <br> (batch_size=1, roi=256) | UKBB test (600) | | | ACDC (100) | | | M&Ms (150) | | |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| model: UNet_64 | LV | MYO | RV | LV | MYO | RV | LV | MYO | RV |
| baseline | 0.9383 | 0.8780 | 0.8979 | 0.8940 | 0.8034 | 0.8237 | 0.8862 | 0.7889 | 0.8168 |
| Finetune w. random DA | 0.9378 | 0.8768 | 0.8975 | 0.8884 | 0.7981 | 0.8295 | 0.8846 | 0.7893 | 0.8158 |
| Finetune w. random DA + adv Bias:<br>UNet_LVSA_Adv_bias_(epochs=20).pth | 0.9326 | 0.8722 | 0.8973 | 0.8809 | 0.7912 | 0.8395 | 0.8794 | 0.7812 | 0.8228 |
| Finetune w. random DA + adv Composite DA:<br>UNet_LVSA_Adv_Compose_(epochs=20).pth | 0.9360 | 0.8726 | 0.8966 | 0.8984 | 0.7973 | 0.8440 | 0.8873 | 0.7859 | 0.8343 |
| Finetune w. random DA + adv chain:<br>(checkpoints/LVSA/Unet_adv_chain.pth) | 0.9360 | 0.8732 | 0.8965 | 0.9060 | 0.8087 | 0.8404 | 0.8929 | 0.7987 | 0.8245 |
- To deploy the model for segmentation, please run the following command first as a test:
  - run `source ./demo_scripts/predict_test.sh`
- this script will perform the following steps:
  - 1. load images from `test_data/` on disk and load the model from `./checkpoints/UNet_LVSA_Adv_Compose_(epochs=20).pth`
  - 2. perform image resampling to obtain a uniform pixel spacing of 1.25 x 1.25 mm
  - 3. centrally crop images to a size of 256x256
  - 4. apply standard intensity normalization so that the intensity distribution of each slice has zero mean and unit standard deviation
  - 5. predict the segmentation map
  - 6. recover the image size and resample the prediction back to its original image space
  - 7. save the predicted segmentation maps for `test_data/patient_id/LVSA/LVSA_img_{}.nii.gz` to `test_results/LVSA/patient_id/Adv_Compose_pred_ED.nii.gz`
- we also provide a script to predict a single image at a time.
  - before use, please run the following command as a test:
    run `source ./demo_scripts/predict_single.sh`
  - then you can modify the command to process your own data (XXX.nii.gz); a segmentation mask will be saved at 'YYY.nii.gz':
  - run `python predict_single_LVSA.py -m './checkpoints/UNet_LVSA_Adv_Compose_(epochs=20).pth' -i 'XXX.nii.gz' -o 'YYY.nii.gz' -c 256 -g 0 -b 1`
- notes:
  - m: model path
  - i: input image path
  - o: output path for the prediction
  - c: crop size, used to save memory. You can change it to any size as long as it is divisible by 16 and the target structures remain within the image region after cropping. When set to -1, each image is cropped to its largest rectangle whose height and width are multiples of 16. Default: 256.
  - g: int, GPU id
  - z: boolean. If set to true, min-max intensity normalization is used to preprocess images, mapping intensities to the 0-1 range. By default, this is deactivated: we found standard-score (z-score) normalization yields better cross-domain segmentation performance than min-max rescaling (both schemes are sketched after these notes).
  - b: int, batch size (>=1). For optimal performance, we found segmentation with instance normalization (b=1) to be more robust than with batch normalization (b>1). However, it slows down inference due to the slice-by-slice prediction scheme.
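For reference, the two per-slice normalization schemes described above look roughly like this. This is a minimal NumPy sketch mirroring the logic of `transform2tensor` in the dataset loader shown later on this page, not a verbatim extract:

```python
import numpy as np

def z_score_slices(x: np.ndarray, eps: float = 1e-20) -> np.ndarray:
    """Zero-mean, unit-std normalization per slice; x has shape (N, H, W)."""
    mean = x.mean(axis=(1, 2), keepdims=True)
    std = x.std(axis=(1, 2), keepdims=True)
    std[np.abs(std) < eps] = 1.0  # guard against constant slices
    return (x - mean) / std

def min_max_slices(x: np.ndarray, eps: float = 1e-20) -> np.ndarray:
    """Clip to the 1st-99th percentile, then rescale to [0, 1]."""
    lo, hi = np.percentile(x, (1, 99))
    x = np.clip(x, lo, hi)
    return (x - lo) / (hi - lo + eps)
```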
{ "name": "Composite",
"data": {
"dataset_name":"UKBB" ,
"readable_frames": ["ED", "ES"],
"train_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/train",
"validate_dir": "/vol/medic02/users/wbai/data/cardiac_atlas/UKBB_2964/sa/validation",
"image_format_name": "sa_{frame}.nii.gz",
"label_format_name": "label_sa_{frame}.nii.gz",
"data_aug_policy" :"UKBB_advancedv4",
"if_resample": true,
"new_spacing": [1.25, 1.25, 10],
"keep_z_spacing": true,
"image_size":[224,224,1],
"label_size":[224,224],
"pad_size": [192,192,1],
"crop_size" :[192,192,1],
"num_classes": 4,
"use_cache": false,
"myocardium_only": false,
"ignore_black_slices": true
},
"segmentation_model": {
"network_type": "UNet_64",
"num_classes": 4,
"resume_path":"./checkpoints/Unet_LVSA_trained_from_UKBB.pkl",
"lr": 0.00001,
"optimizer_name": "adam",
"n_epochs": 50,
"max_iteration": 500000,
"batch_size": 20,
"use_gpu": true,
"decoder_dropout": 0.1
},
"adversarial_augmentation":
{
"policy_name": "advbias",
"transformation_type":"composite",
"divergence_types":["mse","contour"],
"divergence_weights":[1,0.5],
"n_iter":1
}
,
"output":
{
"save_epoch_every_num_epochs":1,
"save_dir":"./result/Composite"
}
}
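A quick illustrative way to sanity-check a config like the one above before training (the file name is hypothetical, and the expected sections are inferred from this config alone):

```python
import json

with open("configs/Composite.json") as f:  # hypothetical filename for the config above
    cfg = json.load(f)

# The training entry point presumably expects these top-level sections.
for section in ("data", "segmentation_model", "adversarial_augmentation", "output"):
    assert section in cfg, f"missing config section: {section}"

print(cfg["segmentation_model"]["network_type"],
      "with", cfg["segmentation_model"]["num_classes"], "classes;",
      "adversarial policy:", cfg["adversarial_augmentation"]["policy_name"])
```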
import os
import logging

import numpy as np
import SimpleITK as sitk
import torch
from torch.utils import data

from common_utils.basic_operations import rescale_intensity
from dataset_loader.utils import resample_by_spacing


class CARDIAC_Predict_DATASET(data.Dataset):
    def __init__(self,
@@ -15,7 +17,7 @@ class CARDIAC_Predict_DATASET(data.Dataset):
                 if_resample=True,
                 new_spacing=[1.25, 1.25, 10],
                 keep_z_spacing=True,
                 if_z_score=False,
                 ):
        '''
@@ -33,7 +35,6 @@
        self.readable_frames = readable_frames
        dataset_dir = root_dir
        self.dataset_dir = dataset_dir
        p_list, p_path_list = self.get_p_list(self.dataset_dir)
        self.patient_list = p_list
@@ -43,10 +44,10 @@
        self.image_format_name = image_format_name
        self.new_spacing = new_spacing
        self.if_z_score = if_z_score
        self.if_resample = if_resample
        self.pid = 0
        self.not_found = []  # record all missing data paths
        self.keep_z_spacing = keep_z_spacing
    def get_p_list(self, dir):
@@ -79,8 +80,9 @@
        patient_data = {}
        if not self.pid == len(self.patient_path_list):
            for frame in self.readable_frames:
                temp_image_path = os.path.join(
                    self.patient_path_list[self.pid], self.image_format_name.format(frame))
                # check that the path exists
                print('try to read {}'.format(temp_image_path))
                if not os.path.exists(temp_image_path):
                    print('not found, ignore it')
@@ -90,30 +92,31 @@
                ### read image ##
                temp_image = sitk.ReadImage(temp_image_path)
                temp_image = sitk.Cast(sitk.RescaleIntensity(
                    temp_image), sitk.sitkFloat32)
                temp_image_arr = sitk.GetArrayFromImage(temp_image)
                original_shape = temp_image_arr.shape
                # convert to float format
                origin_spacing = temp_image.GetSpacing()
                if self.if_resample:
                    # image resampling
                    new_image = resample_by_spacing(im=temp_image, new_spacing=self.new_spacing,
                                                    interpolator=sitk.sitkLinear,
                                                    keep_z_spacing=self.keep_z_spacing)
                else:
                    new_image = temp_image
                # new image shape
                data = sitk.GetArrayFromImage(new_image).astype(float)
                patient_id = self.patient_path_list[self.pid].split('/')[-1]
                self.aft_resample_shape = data.shape

                # save frame data
                patient_data[frame] = {'image': data,  # npy data
                                       'origin_itk_image': temp_image,  # original data
                                       'temp_image_path': temp_image_path,
                                       'new_spacing': self.new_spacing,
                                       'original_shape': original_shape,
                                       'aft_resample_shape': self.aft_resample_shape,
                                       'origin_spacing': origin_spacing,
                                       'after_resampled_image': new_image,
@@ -131,8 +134,7 @@
    def get_name(self):
        print('dataset loader')

    def transform2tensor(self, cPader, img_slice, eps=1e-20):
        '''
        transform npy data to torch tensor
        :param cPader: pads the image so that its size is divisible by 16
@@ -143,47 +145,45 @@
        ###
        new_img_slice = cPader(img_slice)

        # normalize data
        new_img_slice = new_img_slice * 1.0  # N*H*W
        new_input_mean = np.mean(new_img_slice, axis=(1, 2), keepdims=True)
        if self.if_z_score:
            logging.info('z score')
            new_img_slice -= new_input_mean
            new_std = np.std(new_img_slice, axis=(1, 2), keepdims=True)
            if new_img_slice.shape[0] > 1:
                new_std[abs(new_std - 0.) < eps] = 1
            else:
                if abs(new_std) < eps:
                    new_std = 1
            new_img_slice /= new_std
        else:
            logging.info('0-1 rescaling')
            if new_img_slice.shape[0] > 1:
                min_val, max_val = np.percentile(new_img_slice, (1, 99))
                new_img_slice[new_img_slice > max_val] = max_val
                new_img_slice[new_img_slice < min_val] = min_val
                new_img_slice = (new_img_slice - min_val) / (max_val - min_val + eps)
            else:
                for i in range(new_img_slice.shape[0]):
                    a_slice = new_img_slice[i]
                    min_val, max_val = np.percentile(a_slice, (1, 99))
                    a_slice[a_slice > max_val] = max_val
                    a_slice[a_slice < min_val] = min_val
                    a_slice = (a_slice - min_val) / (max_val - min_val + eps)
                    new_img_slice[i] = a_slice

        new_img_slice = new_img_slice[:, np.newaxis, :, :]
        # transform to tensor
        new_image_tensor = torch.from_numpy(new_img_slice).float()
        return new_image_tensor


if __name__ == '__main__':
    import torch
@@ -223,4 +223,3 @@
            fail_cases.append(str(data['patient_id']) + frame)
            continue
        n += 1
@@ -91,7 +91,6 @@ python predict.py --sequence LVSA --model_arch 'UNet_64' \
--save_name_format 'pred_{}.nrrd'
# baseline_Adam_finetune_v4_composite_50_chain_mse_random_random_select
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_finetune_v4_composite_50_chain_mse_random_random_select/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
@@ -99,3 +98,14 @@ python predict.py --sequence LVSA --model_arch 'UNet_64' \
--roi_size 256 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_finetune_v4_composite_50_chain_mse_random_random_select/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
# baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/ACDC/bias_corrected_and_normalized/patient_wise/' \
--image_format '{}_img.nrrd' \
--roi_size 256 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power/ACDC_all' \
--save_name_format 'pred_{}.nrrd'
@@ -75,3 +75,10 @@ python predict.py --sequence LVSA --model_arch 'UNet_64' \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_random_select/MM' \
--save_name_format 'pred_{}.nrrd' --gpu 1
python predict.py --sequence LVSA --model_arch 'UNet_64' \
--model_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power/best/checkpoints/UNet_64$SAX$_Segmentation.pth' \
--root_dir '/vol/biomedic3/cc215/data/cardiac_MMSeg_challenge/Training-corrected/Labeled' \
--image_format 'sa_{}.nrrd' \
--roi_size 256 --batch_size 1 \
--save_folder_path '/vol/bitbucket/cc215/Projects/Cardiac_Multi_View_Segmentation/result/predict/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power/MM' \
--save_name_format 'pred_{}.nrrd' --gpu 1
@@ -2,5 +2,4 @@
# taskset -c 4,5,6,7 python train.py --json_config_path 'configs/ACDC/supervised/baseline.json' --log --intensity_norm_type z_score --gpu 1
taskset -c 0,1,2,3 python train.py --json_config_path 'configs/ACDC/supervised/adv_bias.json' --log --intensity_norm_type z_score --gpu 1 --adv_training
# taskset -c 8,9,10,11 python train.py --json_config_path 'configs/ACDC/supervised/adv_bias_ce.json' --log --intensity_norm_type z_score --gpu 0 --adv_training
@@ -46,3 +46,7 @@ taskset -c 24,25,26,27 python train.py --json_config_path 'configs/baseline_Adam
taskset -c 24,25,26,27 python train.py --json_config_path 'configs/baseline_Adam_finetune_v4_composite_50_chain_mse_random_random_select.json' --log --intensity_norm_type z_score --adv_training --gpu 6
taskset -c 12,13,14,15 python train.py --json_config_path 'configs/baseline_Adam_finetune_v4_composite_50_chain_mse_adv_no_power.json' --log --intensity_norm_type z_score --adv_training --gpu 4
taskset -c 12,13,14,15 python train.py --json_config_path 'configs/baseline_Adam_finetune_v4_composite_50_bias_mse_adv_no_power.json' --log --intensity_norm_type z_score --adv_training --gpu 4
@@ -14,55 +14,99 @@ def get_default_augmentor(
        data_size,
        divergence_types=['mse', 'contour'],  # you can also change it to 'kl'.
        divergence_weights=[1.0, 0.5],
        policy_name='advchain',
        debug=False,
        use_gpu=True):
    '''
    return a data augmentor and a list of flags indicating the component of the data augmentation
    e.g. [1,1,1,1] -> [bias, noise, morph, affine]
    '''
    if policy_name == 'advchain':
        augmentor_bias = AdvBias(
            config_dict={'epsilon': 0.3,
                         'control_point_spacing': [data_size[2]//2, data_size[3]//2],
                         'downscale': 2,
                         'data_size': data_size,
                         'interpolation_order': 3,
                         'init_mode': 'random',
                         'space': 'log'}, debug=debug, use_gpu=use_gpu)

        augmentor_affine = AdvAffine(config_dict={
            'rot': 15/180,
            'scale_x': 0.2,
            'scale_y': 0.2,
            'shift_x': 0.1,
            'shift_y': 0.1,
            'data_size': data_size,
            'forward_interp': 'bilinear',
            'backward_interp': 'bilinear'},
            debug=debug, use_gpu=use_gpu)

        augmentor_morph = AdvMorph(
            config_dict={'epsilon': 1.5,
                         'data_size': data_size,
                         'vector_size': [data_size[2]//8, data_size[3]//8],
                         'interpolator_mode': 'bilinear'
                         },
            debug=debug, use_gpu=use_gpu)

        augmentor_noise = AdvNoise(config_dict={'epsilon': 1,
                                                'xi': 1e-6,
                                                'data_size': data_size},
                                   debug=debug)

        transformation_family = [augmentor_affine,
                                 augmentor_noise, augmentor_bias, augmentor_morph]
        [one_chain] = random_chain(transformation_family)