Commit 91bad048 authored by cc215

add training code for kings

parent 9a41b011
@@ -8,3 +8,4 @@ __pycache__/
# C extensions
*.so
test_results/
result/
\ No newline at end of file
@@ -14,7 +14,7 @@ author: Chen Chen (cc215@ic.ac.uk)
## Environment
- Python 3.5
- Pytorch 1.0
- CUDA(cuda 9.0)
- CUDA(cuda 10.0)
## Dependencies
- see requirements.txt
@@ -53,3 +53,45 @@ Results will be saved under `test_results` by default
## Training
- run `pip install -r requirements.txt` to install the dependencies:
  - pandas==0.22.0
  - matplotlib==2.2.2
  - nipy==0.4.2
  - MedPy==0.3.0
  - scipy==1.0.1
  - tqdm==4.23.0
  - numpy==1.14.2
  - SimpleITK==1.1.0
  - scikit-image
  - tensorboardX==1.4
- then install an adapted version of torchsample via: `pip install git+https://github.com/ozan-oktay/torchsample/`
- test the environment:
  - run `python train.py`
- open the config file `configs/basic_opt.json` and change the dataset configuration (see the example sketch after this list):
  - "train_dir": training dataset directory
  - "validate_dir": validation dataset directory
  - "readable_frames": list of cardiac frames to train on, e.g. ["ED","ES"]
  - "image_format_name": file name pattern of the image data, e.g. "sa_{frame}.nii.gz" for loading sa_ED.nii.gz and sa_ES.nii.gz
  - "label_format_name": file name pattern of the label data, e.g. "label_sa_{frame}.nii.gz" for loading label_sa_ED.nii.gz and label_sa_ES.nii.gz
- run `python train.py --json_config_path {config_file_path}`
- e.g. `python train.py --json_config_path configs/basic_opt.json`
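For reference, a minimal sketch of what the dataset fields in `configs/basic_opt.json` might look like (the paths below are placeholders, and the shipped config contains further options such as "lr", "resume_path" and "data_aug_policy" described in the following sections):

```json
{
    "train_dir": "/path/to/training_data",
    "validate_dir": "/path/to/validation_data",
    "readable_frames": ["ED", "ES"],
    "image_format_name": "sa_{frame}.nii.gz",
    "label_format_name": "label_sa_{frame}.nii.gz"
}
```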
## Finetuning
- open the config file (`configs/basic_opt.json`), set the model resume path and lower the learning rate to 0.0001 or 0.00001:
  - "resume_path": "./checkpoints/Unet_LVSA_best.pkl"
  - "lr": 0.0001
## Output
- By default, all models and intermediate outputs are stored under `result`
- The best model can be found under this directory, e.g. `result/best/checkpoints/UNet_64$SAX$_Segmentation.pth`
## Advanced
- You can change the data augmentation strategy by changing the value of "data_aug_policy" in the config file.
- For details about the available data augmentation strategies, please refer to `dataset_loader/mytransform.py`
# Created by cc215 at 27/12/19
# Enter feature description here
# Enter scenario name here
# Enter steps here
def switch_kv_in_dict(mydict):
    """Swap keys and values of a dict, e.g. {'LV': 1, 'MYO': 2} -> {1: 'LV', 2: 'MYO'}."""
    switched_dict = {y: x for x, y in mydict.items()}
    return switched_dict
\ No newline at end of file
from medpy.metric.binary import dc
import SimpleITK as sitk
import os
import numpy as np
import pandas as pd
from common_utils.measure import hd,hd_2D_stack
def compute_score(img_pred, img_gt,measure_vol=False,voxel_spacing=(1.0,1.0,1.0),classes=[1,2,3]):
'''3D measurements.'''
##lv=1,myo=2,rv=3
# img_pred=img_pred[1:img_pred.shape[0]-2] ## exclude apical/basal slices.
# img_gt=img_gt[1:img_gt.shape[0]-2]
n, h, w = img_gt.shape
res = []
for c in classes:
        # Copy the gt image so as not to alter the input
gt_c_i = np.copy(img_gt)
gt_c_i[gt_c_i != c] = 0
        # Copy the pred image so as not to alter the input
pred_c_i = np.copy(img_pred)
pred_c_i[pred_c_i != c] = 0
# Clip the value to compute the volumes
gt_c_i = np.clip(gt_c_i, 0, 1)
pred_c_i = np.clip(pred_c_i, 0, 1)
# Compute the Dice
if np.sum(gt_c_i) == 0:
print('zero gt')
if np.sum(pred_c_i) == 0:
print('zero pred')
dice = dc(gt_c_i, pred_c_i)
#hd_value = hd(gt_c_i, pred_c_i, voxelspacing=voxel_spacing, connectivity=2) ##connectivity=2 for 8-neighborhood
hd_value=hd_2D_stack(pred_c_i.reshape(n,h,w),gt_c_i.reshape(n,h,w),pixelspacing=voxel_spacing[1],connectivity=1)
if measure_vol:
## add hd,error_volume_LV,mass
# Compute volume
volpred = pred_c_i.sum() * np.prod(voxel_spacing) / 1000.
volgt = gt_c_i.sum() * np.prod(voxel_spacing) / 1000.
volerror = np.abs(volpred - volgt)
            if c==2:
                ## for the myocardium, scale the volume error by 1.05 g/ml to
                ## approximate the mass error (myocardial mass = 1.05 * volume)
                volerror=volerror*1.05
res += [dice,hd_value,volpred,volgt, volerror]
else:
res += [dice,hd_value]
return res
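# Usage sketch (illustrative only; the toy arrays below are placeholders):
#     import numpy as np
#     gt = np.zeros((3, 16, 16), dtype=np.int64)               # 3-slice toy volume, labels: 1=LV, 2=MYO, 3=RV
#     gt[:, 2:6, 2:6] = 1; gt[:, 6:10, 2:6] = 2; gt[:, 2:6, 6:10] = 3
#     pred = gt.copy(); pred[0, 2, 2] = 0                      # introduce a small prediction error
#     # with measure_vol=False this returns [lv_dice, lv_hd, myo_dice, myo_hd, rv_dice, rv_hd]
#     print(compute_score(pred, gt, voxel_spacing=(10.0, 1.25, 1.25)))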
def evaluate_patient_wise(root_dir,label_format_name,pred_format_name,frames=['ED','ES'],dataset='ACDC',measure_vol=False):
result = []
for p_id in sorted(os.listdir(root_dir)):
print(p_id)
patient_dir=os.path.join(root_dir,p_id)
        if not os.path.isdir(patient_dir): continue  # skip non-directory entries
for frame in frames:
if frame=='ED':
if p_id in ['10AM02216','14DN01375','14EB01736', '10MW00126', '10WP00714', '14DW01572','12DH01153','14DN01375','14EC03291']:
print('ignore:', p_id)
continue
if frame=='ES':
if p_id in ['14EB01736','14DN01375','12DS00630','12DH01153','10MW00126', '10WP00714', '14DW01572','12DH01153','14DN01375','14EC03291']:
print('ignore:', p_id)
continue
pred_path=os.path.join(patient_dir,pred_format_name.format(frame))
gt_path=os.path.join(patient_dir,label_format_name.format(frame))
if not os.path.exists(pred_path) or not os.path.exists(gt_path):
continue
pred=sitk.GetArrayFromImage(sitk.ReadImage(pred_path))
gt_im=sitk.ReadImage(gt_path)
spacing=gt_im.GetSpacing()
spacing=spacing[::-1]
spacing=np.array(spacing)
print ('spacing:',spacing)
gt=sitk.GetArrayFromImage(gt_im)
            ## convert GT labels when the dataset's labeling protocol differs from the UKBB one (LV=1, MYO=2, RV=3)
if dataset=='UKBB' or dataset=='UCL' or 'LVSC' in dataset:
pass
            elif dataset=='ACDC':
                # ACDC labels (RV=1, MYO=2, LV=3) -> UKBB convention (LV=1, MYO=2, RV=3)
                gt=(gt==3)*1+(gt==2)*2+(gt==1)*3
elif dataset=='ACDC_ACDC':
gt=(gt==3)*1+(gt==2)*2+(gt==1)*3
pred=(pred==3)*1+(pred==2)*2+(pred==1)*3
elif dataset == 'HB':
gt = (gt == 1) * 1 + (gt == 2) * 2 + (gt >= 3) * 3
elif dataset == 'RVSC':
gt = (gt == 1) * 3 + (gt == 2) * 0
elif dataset == 'Carlo_Pathology_LVSA':
gt = (gt == 4) * 3 + (gt == 2) * 2+(gt == 1) * 1
else:
raise NotImplementedError
temp=[]
temp+=[str(p_id)]
temp+=[str(frame)]
            res_1=compute_score(pred, gt, measure_vol=measure_vol, voxel_spacing=spacing)
temp+=res_1
print (temp)
result.append(temp)
return result
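# Usage sketch (paths are placeholders; the format strings use a positional '{}' so that .format(frame) works):
#     rows = evaluate_patient_wise(root_dir='/path/to/predictions',
#                                  label_format_name='label_sa_{}.nii.gz',
#                                  pred_format_name='seg_sa_{}.nii.gz',
#                                  frames=['ED', 'ES'], dataset='ACDC')
#     # each row is [patient_id, frame, lv_dice, lv_hd, myo_dice, myo_hd, rv_dice, rv_hd] when measure_vol=False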
def measure_prediction_result(result,save_path=None, header = ['patient_id', 'frame', 'lv_dice_score','lv_hd','myo_dice_score','myo_hd', 'rv_dice_score','rv_hd']):
    import time
    df = pd.DataFrame(result, columns=header)
    print(save_path)
    new_save_path = None  # stays None when no save_path is given
    if save_path is not None:
        new_save_path = save_path + "_{}.csv".format(time.strftime("%Y%m%d_%H%M%S"))
        df.to_csv(new_save_path, index=False)
    print(df.describe())
    return df, new_save_path
def run_statistic( root_dir, dataset,header,save_analysis_dir,measure_vol=False):
    from dataset.cardiac_dataset import CARDIAC_DATASET  # needed below for the dataset-specific label file name
    _, _, label_format_name, _ = CARDIAC_DATASET.get_dataset_config(dataset)
label_format_name=label_format_name.split('/')[-1]
print (label_format_name)
pred_format_name = 'seg_sa_{}.nii.gz'
ED_result = evaluate_patient_wise(root_dir, label_format_name, pred_format_name, frames=['ED'], dataset=dataset,measure_vol=measure_vol)
ES_result = evaluate_patient_wise(root_dir, label_format_name, pred_format_name, frames=['ES'], dataset=dataset,measure_vol=measure_vol)
model_name=root_dir.split('/')[-1]
if not os.path.exists('/vol/medic01/users/cc215/data/DA/experiments/'):
os.mkdir('/vol/medic01/users/cc215/data/DA/experiments/')
if not os.path.exists(save_analysis_dir):
os.makedirs(save_analysis_dir)
#ed_result_path='/vol/medic01/users/cc215/data/DA/experiments/'+dataset+'_'+model_name+'_ED'
#es_result_path='/vol/medic01/users/cc215/data/DA/experiments/'+dataset+'_'+model_name+'_ES'
df1,ed_result_path= measure_prediction_result(ED_result, save_path=os.path.join(save_analysis_dir,dataset+'_'+model_name+'_ED'),
header=header)
df2,es_result_path = measure_prediction_result(ES_result, save_path=os.path.join(save_analysis_dir,dataset+'_'+model_name+'_ES'),
header=header)
df = pd.DataFrame(ED_result + ES_result, columns=header)
    def print_mean_std(df):
        info = '{:.3f} ({:.3f}) '.format(df["lv_dice_score"].mean(), df["lv_dice_score"].std()) \
               + ',{:.3f} ({:.3f}) '.format(df["myo_dice_score"].mean(), df["myo_dice_score"].std()) \
               + ',{:.3f} ({:.3f}) '.format(df["rv_dice_score"].mean(), df["rv_dice_score"].std())
        print(info)
        return info
    def print_hd_mean_std(df):
        info = '{:.3f} ({:.3f}) '.format(df["lv_hd"].mean(), df["lv_hd"].std()) \
               + ',{:.3f} ({:.3f}) '.format(df["myo_hd"].mean(), df["myo_hd"].std()) \
               + ',{:.3f} ({:.3f}) '.format(df["rv_hd"].mean(), df["rv_hd"].std())
        print(info)
        return info
print('pred_path:', root_dir)
print('pred_dataset:', dataset)
print('==DICE==')
print('ED/ES/Overall: LV , MYO , RV ')
txt_path=os.path.join(save_analysis_dir,'dice_hd.txt')
file=open(txt_path,'w')
ed_dice_result=print_mean_std(df1)
es_dice_result=print_mean_std(df2)
total_dice_result= print_mean_std(df)
dice=[save_analysis_dir,'\n','dice\n','ED,',ed_dice_result,'\n','ES,',es_dice_result,'\n','total,',total_dice_result,'\n']
file.writelines(dice)
print('==HD==')
print('ED/ES/Overall: LV , MYO , RV ')
ed_hd_result=print_hd_mean_std(df1)
es_hd_result=print_hd_mean_std(df2)
total_hd_result=print_hd_mean_std(df)
hd=['hd \n','ED,',ed_hd_result,'\n','ES,',es_hd_result,'\n','total,',total_hd_result,'\n']
file.writelines(hd)
file.close()
print ('save ED result csv to:',ed_result_path)
print ('save ES result csv to:',es_result_path)
if __name__=='__main__':
# header = ['patient_id', 'frame', 'lv_dice_score', 'myo_dice_score', 'rv_dice_score']
# root_dir = '/vol/medic01/users/cc215/data/ACDC_2017/UNetresample' # /vol/medic01/users/cc215/data/ACDC_2017/SA_UNetresample/'#'/vol/medic01/users/cc215/data/ACDC_2017/UNetresample/'##'/vol/medic01/users/cc215/data/ACDC_2017/UNET_ACDC_temp_advresample' #UNET_UKBBresample/'#'/vol/medic01/users/cc215/data/ACDC_2017/UNET_ACDCresample'#'#'/vol/medic01/users/cc215/data/Biobank/UKBB_Unet'
# dataset = 'ACDC'
# run_statistic(root_dir,dataset,header)
measure_vol=True
    if not measure_vol:
header = ['patient_id', 'frame', 'lv_dice_score','lv_hd','myo_dice_score','myo_hd', 'rv_dice_score','rv_hd']
else:
header = ['patient_id', 'frame', 'lv_dice_score', 'lv_hd','lv_vol','lv_vol_gt','lv_vol_error','myo_dice_score', 'myo_hd','myo_vol','myo_vol_gt','myo_vol_error', 'rv_dice_score', 'rv_hd','rv_vol',
'rv_vol_gt','rv_vol_error']
# root_dir = '/vol/medic01/users/cc215/data/DA/experiments_UKBB2UKBB/predict/UNetpredict_testresample_new'
# #root_dir='/vol/medic01/users/cc215/data/ACDC_2017/UNET_ACDCresample' #'/vol/medic01/users/cc215/data/ACDC_2017/UNET_ACDC_temp_advresample' #UNET_UKBBresample/'#'/vol/medic01/users/cc215/data/ACDC_2017/UNET_ACDCresample'#'#'/vol/medic01/users/cc215/data/Biobank/UKBB_Unet'
# dataset = 'UKBB'
pred_results={
'UKBB2_ACDC_test':'/vol/medic01/users/cc215/data/DA/experiments_UKBB2ACDC/predict/UNetpredict_testresample_new',
'UKBB2_ACDC_all':'/vol/medic01/users/cc215/data/DA/experiments_UKBB2ACDC/predict/UNetpredict_allresample_new/',
'UKBB2_UCL_test':'/vol/medic01/users/cc215/data/DA/experiments_UKBB2UCL/predict/UNetpredict_testresample_new',
'UKBB2_UCL_all':'/vol/medic01/users/cc215/data/DA/experiments_UKBB2UCL/predict/UNetpredict_allresample_new',
'ACDC2_ACDC_test':'/vol/medic01/users/cc215/data/DA/experiments_ACDC2ACDC/predict/UNetpredict_testresample_new',
'UCL2_UCL_test':'/vol/medic01/users/cc215/data/DA/experiments_UCL2UCL/predict/UNetpredict_testresample_new',
'UKBB2_UKBB_test':'/vol/medic01/users/cc215/data/DA/experiments_UKBB2UKBB/predict/UNetpredict_testresample_new',
}
for k,v in pred_results.items():
root_dir=v
dataset=k.split('_')[1]
run_statistic(root_dir, dataset, header,measure_vol=measure_vol,save_analysis_dir='/vol/medic01/users/cc215/new_Drop/Dropbox/PhD_2018/cardiac_data_augmentation/paper_DA/data_analysis/all_metrics/'+k)
# Created by cc215 at 27/12/19
# Enter feature description here
# Enter scenario name here
# Enter steps here
import json
import logging
import os
import shutil
import torch
import matplotlib.pyplot as plt
class Params():
"""Class that loads hyperparameters from a json file.
Example:
```
params = Params(json_path)
print(params.learning_rate)
params.learning_rate = 0.5 # change the value of learning_rate in params
```
"""
def __init__(self, json_path):
with open(json_path) as f:
params = json.load(f)
self.__dict__.update(params)
def save(self, json_path):
with open(json_path, 'w') as f:
json.dump(self.__dict__, f, indent=4)
def update(self, json_path):
"""Loads parameters from json file"""
with open(json_path) as f:
params = json.load(f)
self.__dict__.update(params)
@property
def dict(self):
"""Gives dict-like access to Params instance by `params.dict['learning_rate']"""
return self.__dict__
def save_dict_to_json(d, json_path):
"""Saves dict of floats in json file
Args:
d: (dict) of float-castable values (np.float, int, float, etc.)
json_path: (string) path to json file
"""
with open(json_path, 'w') as f:
# We need to convert the values to float for json (it doesn't accept np.array, np.float, )
d = {k: float(v) for k, v in d.items()}
json.dump(d, f, indent=4)
def plot_training_results(model_dir, plot_history):
    """
    Plot training curves during training.
    Args:
        model_dir: (string) directory where the plot image will be saved
        plot_history: (dict) a dictionary containing historical values of what
                      we want to plot
    """
# tr_losses = plot_history['train_loss']
# val_losses = plot_history['val_loss']
# te_losses = plot_history['test_loss']
# tr_accs = plot_history['train_acc']
val_accs = plot_history['val_acc']
te_accs = plot_history['test_acc']
# plt.figure(0)
# plt.plot(list(range(len(tr_losses))), tr_losses, label='train_loss')
# plt.plot(list(range(len(val_losses))), val_losses, label='val_loss')
# plt.plot(list(range(len(te_losses))), te_losses, label='test_loss')
# plt.title('Loss trend')
# plt.xlabel('episode')
# plt.ylabel('ce loss')
# plt.legend()
# plt.savefig(os.path.join(model_dir, 'loss_trend'), dpi=200)
# plt.clf()
plt.figure(1)
# plt.plot(list(range(len(tr_accs))), tr_accs, label='train_acc')
plt.plot(list(range(len(val_accs))), val_accs, label='val_acc')
plt.plot(list(range(len(te_accs))), te_accs, label='test_acc')
plt.title('Accuracy trend')
plt.xlabel('iter / 1000')
plt.ylabel('accuracy')
plt.legend()
plt.savefig(os.path.join(model_dir, 'accuracy_trend'), dpi=200)
plt.clf()
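# Usage sketch (values are placeholders; `model_dir` must already exist):
#     history = {'val_acc': [0.71, 0.78, 0.82], 'test_acc': [0.69, 0.75, 0.80]}
#     plot_training_results('result', history)   # writes result/accuracy_trend.png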
if __name__ =='__main__':
params = Params('/vol/medic01/users/cc215/Dropbox/projects/DeformADA/configs/gat_loss.json')
print (params.dict)
\ No newline at end of file
import torch
import os
def resume_model_from_file(file_path):
start_epoch=1
optimizer_state=None
state_dict=None
checkpoint=None
assert os.path.isfile(file_path)
if '.pkl' in file_path:
print("Loading models and optimizer from checkpoint '{}'".format(file_path))
checkpoint = torch.load(file_path)
for k,v in checkpoint.items():
if k=='model_state':
state_dict=checkpoint['model_state']
if k=='optimizer_state':
optimizer_state=checkpoint['optimizer_state']
if k=='epoch':
start_epoch = int(checkpoint['epoch'])
print("Loaded checkpoint '{}' (epoch {})"
.format(file_path, checkpoint['epoch']))
elif '.pth' in file_path:
print("Loading models and optimizer from checkpoint '{}'".format(file_path))
state_dict = torch.load(file_path)
start_epoch=int(file_path.split('.')[0].split('_')[-1]) ##restore training procedure.
else:
raise NotImplementedError
return {'start_epoch':start_epoch,
'optimizer_state':optimizer_state,
'state_dict':state_dict,
'checkpoint':checkpoint
}
def restoreOmega(path,model,optimizer=None):
checkpoint = resume_model_from_file(file_path=path)
state_dict = checkpoint['state_dict']
start_epoch = checkpoint['start_epoch']
model.load_state_dict(state_dict, strict=False)
optimizer_state = checkpoint['optimizer_state']
    if optimizer_state is not None and optimizer is not None:
        try:
            optimizer.load_state_dict(optimizer_state)
        except Exception:
            # the saved optimizer state may not match the current optimizer
            # (e.g. different parameter groups); ignore it and keep the fresh state
            pass
return model,optimizer,start_epoch
def save_model_to_file(model_name,model, epoch, optimizer,save_path):
state_dict= model.module.state_dict() if isinstance(model,torch.nn.DataParallel) else model.state_dict()
state = {'model_name': model_name,
'epoch': epoch + 1,
'model_state': state_dict,
'optimizer_state': optimizer.state_dict()
}
torch.save(state, save_path)
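# Round-trip usage sketch (the toy model and path below are placeholders):
#     model = torch.nn.Linear(4, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
#     save_model_to_file('toy', model, epoch=9, optimizer=optimizer, save_path='./toy.pkl')
#     model, optimizer, start_epoch = restoreOmega('./toy.pkl', model, optimizer)
#     # start_epoch == 10, since save_model_to_file stores epoch + 1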
def gen_overlay(img,attention):
    '''
    Overlay a 2D attention map on a 2D grayscale image as a JET heatmap.
    :param img: 2D image array
    :param attention: 2D attention/saliency map
    :return: 2D RGB overlay image (uint8)
    '''
    import numpy as np
    import cv2
img = img[:, :, np.newaxis]
# img=(img*255.0)
img = np.repeat(img, axis=2, repeats=3)
height, width, _ = img.shape
img=(img-img.min())/(img.max()-img.min())*255
img=img.astype(np.uint8)
#attention= np.expand_dims(attention, axis=2)
# attention= np.repeat(attention, axis=2, repeats=3)
attention=(attention-attention.min())/(attention.max()-attention.min())*255
attention=attention.astype(np.uint8)
heatmap = cv2.applyColorMap(attention,cv2.COLORMAP_JET)
cam=cv2.addWeighted(img, 1, heatmap, 1, 0)
# cam = heatmap + np.float32(img)
#cam = cam / np.max(cam)
return cam#np.uint8(cam*255)
def save_npy2image(data,file_dir,name):
if not os.path.exists(file_dir):
os.makedirs(file_dir)
filepath=os.path.join(file_dir,name+'.png')
    import scipy.misc  # scipy.misc.imsave requires scipy<1.2 (scipy==1.0.1 is pinned in requirements.txt)
scipy.misc.imsave(filepath, data)
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
def cross_entropy_2D(input, target, weight=None, size_average=True):
    '''Pixel-wise cross entropy for segmentation logits of shape (N, C, H, W) against integer targets of shape (N, H, W).'''
    n, c, h, w = input.size()
log_p = F.log_softmax(input, dim=1)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
target = target.view(target.numel())
loss = F.nll_loss(log_p, target, weight=weight, size_average=False)
if size_average:
loss /= float(target.numel()+1e-10)
return loss
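# Usage sketch: logits of shape (N, C, H, W) against integer class targets of shape (N, H, W)
#     logits = torch.randn(2, 4, 8, 8)
#     target = torch.randint(0, 4, (2, 8, 8))
#     loss = cross_entropy_2D(logits, target)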
\ No newline at end of file
# Adapted from score written by wkentaro
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
import numpy as np
from medpy.metric.binary import dc
from common_utils.measure import hd, hd_2D_stack, asd, volumesimilarity
import pandas as pd
from IPython.display import display, HTML
class runningScore(object):
def __init__(self, n_classes):
self.n_classes = n_classes
self.confusion_matrix = np.zeros((n_classes, n_classes))
def _fast_hist(self, label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def update(self, label_trues, label_preds):
for lt, lp in zip(label_trues, label_preds):
self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
def get_scores(self):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
"""
hist = self.confusion_matrix
acc = np.diag(hist).sum() / hist.sum()
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / hist.sum()
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
cls_iu = dict(zip(range(self.n_classes), iu))
return {'Overall Acc: \t': acc,
'Mean Acc : \t': acc_cls,
'FreqW Acc : \t': fwavacc,
'Mean IoU : \t': mean_iu, }, cls_iu
def reset(self):
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
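# Usage sketch (gt_map / pred_map stand for 2D numpy arrays of integer class labels):
#     score = runningScore(n_classes=4)
#     score.update(label_trues=[gt_map], label_preds=[pred_map])
#     metrics, per_class_iou = score.get_scores()
#     score.reset()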
class runningCustomScore(object):
def __init__(self, n_classes, add_hd=False):
self.n_classes = n_classes
assert self.n_classes <= 2, 'only support binary segmentation for now'
self.confusion_matrix = np.zeros((n_classes, n_classes))
self.dice_score = []
self.hd_score = []
self.add_hd = add_hd
def _fast_hist(self, label_true, label_pred, n_class):
mask = (label_true >= 0) & (label_true < n_class)
hist = np.bincount(
n_class * label_true[mask].astype(int) +
label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
return hist
def update(self, label_trues, label_preds, voxel_spacing=None):
for lt, lp in zip(label_trues, label_preds):
self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
# Clip the value to compute the volumes
gt = np.clip(label_trues, 0, 1)
pred = np.clip(label_preds, 0, 1)
self.dice_score.append(dc(result=pred, reference=gt))
if self.add_hd:
            assert voxel_spacing is not None, 'please provide voxel_spacing when add_hd is True'
if np.sum(gt) > 0 and np.sum(pred) > 0:
print(voxel_spacing)
self.hd_score.append(hd(result=pred, reference=gt, voxelspacing=voxel_spacing, connectivity=1))
def get_scores(self):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
"""
hist = self.confusion_matrix
acc = np.diag(hist).sum() / hist.sum()
acc_cls = np.diag(hist) / hist.sum(axis=1)