MLI-CW-2.ipynb 717 KiB
Ben Glocker committed
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CO416 - Machine Learning for Imaging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Coursework 2 - Age regression from brain MRI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predicting age from a brain MRI scan can have diagnostic value for a number of diseases that cause structural changes and damage to the brain. Discrepancy between the predicted, biological age and the real, chronological age of a patient might indicate the presence of disease and abnormal changes to the brain. For this we need an accurate predictor of brain age which may be learned from a set of healthy reference subjects.\n",
    "\n",
    "The objective of the coursework is to implement two different supervised learning approaches for age regression from brain MRI. Data from 600 healthy subjects will be provided. Each approach will require a processing pipeline with different components that you will need to implement using methods that were discussed in the lectures and tutorials. There are dedicated sections in the Jupyter notebook for each approach which contain detailed instructions, hints and notes.\n",
    "\n",
    "You may find useful ideas and implementations in the tutorial notebooks. Make sure to add documentation to your code. Markers will find it easier to understand your reasoning when sufficiently detailed comments are provided in your implementations.\n",
    "\n",
    "#### Read the descriptions and provided code cells carefully and look out for the cells marked with 'TASK'."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Getting started: familiarising ourselves with the data\n",
    "\n",
    "The following cells provide some helper functions to load the data and give an overview and visualisation of the statistics over the population of 600 subjects. Let's start by loading the meta data, that is, the data containing each subject's ID, age, and gender."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>age</th>\n",
       "      <th>gender_code</th>\n",
       "      <th>gender_text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CC110033</td>\n",
       "      <td>24</td>\n",
       "      <td>1</td>\n",
       "      <td>MALE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>CC110037</td>\n",
       "      <td>18</td>\n",
       "      <td>1</td>\n",
       "      <td>MALE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CC110045</td>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "      <td>FEMALE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>CC110056</td>\n",
       "      <td>22</td>\n",
       "      <td>2</td>\n",
       "      <td>FEMALE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CC110062</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>MALE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         ID  age  gender_code gender_text\n",
       "0  CC110033   24            1        MALE\n",
       "1  CC110037   18            1        MALE\n",
       "2  CC110045   24            2      FEMALE\n",
       "3  CC110056   22            2      FEMALE\n",
       "4  CC110062   20            1        MALE"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the meta data using pandas\n",
    "import pandas as pd\n",
    "\n",
    "data_dir = \"./data/brain/\"\n",
    "\n",
    "meta_data = pd.read_csv(data_dir + 'meta/clean_participant_data.csv')\n",
    "meta_data.head() # show the first five data entries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's have a look at some population statistics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "sns.catplot(x=\"gender_text\", data=meta_data, kind=\"count\")\n",
    "plt.title('Gender distribution')\n",
    "plt.xlabel('Gender')\n",
    "plt.show()\n",
    "\n",
    "sns.distplot(meta_data['age'], bins=[10,20,30,40,50,60,70,80,90])\n",
    "plt.title('Age distribution')\n",
    "plt.xlabel('Age')\n",
    "plt.show()\n",
    "\n",
    "plt.scatter(range(len(meta_data['age'])),meta_data['age'], marker='.')\n",
    "plt.grid()\n",
    "plt.xlabel('Subject')\n",
    "plt.ylabel('Age')\n",
    "plt.show()"
   ]
  },
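  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A numerical summary can complement the plots above. The cell below is a minimal sketch that uses only standard pandas calls (`groupby`, `describe`, `pd.cut`, `pd.crosstab`). The function name `summarise_population` is hypothetical, and the age bins simply mirror the histogram bins used above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "\n",
    "def summarise_population(df):\n",
    "    # Age statistics (count, mean, std, min, quartiles, max) for each gender\n",
    "    by_gender = df.groupby('gender_text')['age'].describe()\n",
    "    # Number of subjects per age bin [10, 20), [20, 30), ..., split by gender\n",
    "    age_bins = pd.cut(df['age'], bins=[10, 20, 30, 40, 50, 60, 70, 80, 90], right=False)\n",
    "    per_bin = pd.crosstab(age_bins, df['gender_text'])\n",
    "    return by_gender, per_bin\n",
    "\n",
    "# Usage on the meta data loaded above\n",
    "by_gender, per_bin = summarise_population(meta_data)\n",
    "display(by_gender)\n",
    "display(per_bin)"
   ]
  },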
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up a simple medical image viewer and import SimpleITK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from ipywidgets import interact, fixed\n",
    "from IPython.display import display\n",
    "\n",
    "from utils.image_viewer import display_image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imaging data\n",
    "\n",
    "Let's check out the imaging data that is available for each subject. This cell also shows how to retrieve data given a particular subject ID from the meta data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Imaging data of subject CC110033 with age 24\n",
      "\n",
      "MR Image (used in part A)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Brain mask (used in part A)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Spatially normalised grey matter maps (used in part B)\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import glob\n",
    "\n",
    "# Subject with index 0\n",
    "ID = meta_data['ID'][0]\n",
    "age = meta_data['age'][0]\n",
    "\n",
    "# Data folders\n",
    "image_dir = data_dir + 'images/'\n",
    "image_filenames = glob.glob(image_dir + '*.nii.gz')\n",
    "\n",
    "mask_dir = data_dir + 'masks/'\n",
    "mask_filenames = glob.glob(mask_dir + '*.nii.gz')\n",
    "\n",
    "greymatter_dir = data_dir + 'greymatter/'\n",
    "greymatter_filenames = glob.glob(greymatter_dir + '*.nii.gz')\n",
    "\n",
    "\n",
    "image_filename = [f for f in image_filenames if ID in f][0]\n",
    "img = sitk.ReadImage(image_filename)\n",
    "\n",
    "mask_filename = [f for f in mask_filenames if ID in f][0]\n",
    "msk = sitk.ReadImage(mask_filename)\n",
    "\n",
    "greymatter_filename = [f for f in greymatter_filenames if ID in f][0]\n",
    "gm = sitk.ReadImage(greymatter_filename)\n",
    "\n",
    "print('Imaging data of subject ' + ID + ' with age ' + str(age))\n",
    "\n",
    "print('\\nMR Image (used in part A)')\n",
    "display_image(img, window=400, level=200)\n",
    "\n",
    "print('Brain mask (used in part A)')\n",
    "display_image(msk)\n",
    "\n",
    "print('Spatially normalised grey matter maps (used in part B)')\n",
    "display_image(gm)"
   ]
  },
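  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above selects a file by taking the first filename that contains the subject ID. If the file is missing, this fails with an uninformative `IndexError`; if several files match, it silently picks one. A minimal sketch of a stricter lookup follows. `find_subject_file` is a hypothetical helper, not part of the provided `utils`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "def find_subject_file(filenames, subject_id):\n",
    "    # Collect every filename that contains the subject ID\n",
    "    matches = [f for f in filenames if subject_id in f]\n",
    "    # Fail loudly on missing or ambiguous matches instead of guessing\n",
    "    if len(matches) != 1:\n",
    "        raise ValueError('Expected one file for {}, found {}: {}'.format(subject_id, len(matches), matches))\n",
    "    return matches[0]\n",
    "\n",
    "# Usage: same lookups as above, with the extra check\n",
    "img = sitk.ReadImage(find_subject_file(image_filenames, ID))\n",
    "msk = sitk.ReadImage(find_subject_file(mask_filenames, ID))\n",
    "gm = sitk.ReadImage(find_subject_file(greymatter_filenames, ID))"
   ]
  },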
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part A: Volume-based regression using brain structure segmentation\n",
    "\n",
    "The first approach aims to regress the age of a subject using the volumes of brain tissues as features. The structures include grey matter (GM), white matter (WM), and cerebrospinal fluid (CSF). It is known that the ventricles, which are filled with CSF, enlarge with increasing age, while it is assumed that grey and white matter volumes may decrease over time. However, as overall brain volume varies across individuals, absolute tissue volumes might not be predictive. Instead, relative volumes need to be computed as the ratio between each tissue volume and the overall brain volume. To this end, a four-class (GM, WM, CSF, and background) brain segmentation needs to be implemented and applied to the 600 brain scans. Brain masks are provided which have been generated with a state-of-the-art neuroimaging brain extraction tool.\n",
    "\n",
    "Different regression techniques should be explored, and it might be beneficial to investigate which set of features works best for this task. Are all volume features equally useful, or is it even better to combine some of them to create new features? How does a simple linear regression perform compared to a model with higher-order polynomials? Do you need regularisation? How about other regression methods such as regression trees or neural networks? The accuracy of different methods should be evaluated using two-fold cross-validation, and the average age prediction accuracy should be compared and reported appropriately.\n",
    "\n",
    "*Note:* For part A, only the MR images and the brain masks should be used from the imaging data. The spatially normalised grey matter maps are used in part B only. If you struggle with task A-1, you can continue with A-2 using the provided reference segmentations in subfolder `segs_refs`."
   ]
  },
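  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a starting point for task A-2, the sketch below shows one way to turn a four-class segmentation into relative tissue volumes and to compare regressors with two-fold cross-validation. It rests on several assumptions:\n",
    "\n",
    "- The label values in `TISSUE_LABELS` are guesses and must be checked against the segmentations in `segs_refs`.\n",
    "- Overall brain volume is taken as the number of voxels in the brain mask.\n",
    "- scikit-learn is installed.\n",
    "- The two models are only examples of the plain linear and polynomial-plus-regularisation variants mentioned above.\n",
    "\n",
    "The random toy data at the end only demonstrates the calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "from sklearn.linear_model import LinearRegression, Ridge\n",
    "from sklearn.model_selection import KFold, cross_val_predict\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
    "\n",
    "# ASSUMPTION: label values of the tissue classes (check against segs_refs)\n",
    "TISSUE_LABELS = {'CSF': 1, 'GM': 2, 'WM': 3}\n",
    "\n",
    "def relative_volumes(seg, msk, labels=TISSUE_LABELS):\n",
    "    # seg and msk must be SimpleITK images on the same voxel grid.\n",
    "    # Voxel counts suffice because the voxel volume cancels in the ratio.\n",
    "    seg_arr = sitk.GetArrayFromImage(seg)\n",
    "    brain_voxels = np.count_nonzero(sitk.GetArrayFromImage(msk))\n",
    "    return np.array([np.count_nonzero(seg_arr == lab) / brain_voxels for lab in labels.values()])\n",
    "\n",
    "def two_fold_cv_mae(X, y, model, seed=42):\n",
    "    # Every subject is predicted by a model trained on the other half of the data\n",
    "    cv = KFold(n_splits=2, shuffle=True, random_state=seed)\n",
    "    y_pred = cross_val_predict(model, X, y, cv=cv)\n",
    "    return np.mean(np.abs(y_pred - y)), y_pred\n",
    "\n",
    "models = {\n",
    "    'linear': make_pipeline(StandardScaler(), LinearRegression()),\n",
    "    'poly2 + ridge': make_pipeline(PolynomialFeatures(degree=2, include_bias=False), StandardScaler(), Ridge(alpha=1.0)),\n",
    "}\n",
    "\n",
    "# With real data, X would stack relative_volumes(seg, msk) for all 600 subjects\n",
    "# and y would be meta_data['age'].values\n",
    "# Toy demonstration with random numbers (NOT real data), only to show the calls\n",
    "rng = np.random.default_rng(0)\n",
    "X_toy = rng.random((20, 3))\n",
    "y_toy = 20 + 60 * X_toy[:, 0] + rng.normal(0, 2, 20)\n",
    "for name, model in models.items():\n",
    "    mae, _ = two_fold_cv_mae(X_toy, y_toy, model)\n",
    "    print(name, 'MAE (years):', round(mae, 2))"
   ]
  },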
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TASK A-1: Brain tissue segmentation\n",
    "\n",
    "Implement a CNN model for brain tissue segmentation which can provide segmentations of GM, WM, and CSF. For this task (and only for this task), we provide a separate dataset of 52 subjects, split into 47 for training and 5 for validation. The template code below already implements the data handling and main training routines, so you can focus on implementing a suitable CNN model. A simple model is provided, but it won't perform very well.\n",
    "\n",
    "Once your model is trained and you are happy with the results on the validation data, you should apply it to the 600 test images. We provide reference segmentations in a subfolder `segs_refs` for all subjects. Calculate Dice similarity coefficients per tissue by comparing your predicted segmentations for the 600 test images to the reference segmentations. Summarise the statistics of the 600 Dice scores for each tissue class in [box-and-whisker plots](https://matplotlib.org/api/_as_gen/matplotlib.pyplot.boxplot.html).\n",
    "\n",
    "*Note:* Implementing a full-fledged machine learning pipeline with training and testing procedures in Jupyter notebooks is a bit cumbersome and a pain to debug. Also, running bigger training tasks can be unstable. The code below should work as is on your VM. However, if you want to get a bit more serious about implementing an advanced CNN approach for image segmentation, you may want to move code into separate Python scripts and run them from the terminal."
   ]
  },
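  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Dice similarity coefficient between a predicted mask $P$ and a reference mask $R$ is $2|P \\cap R| / (|P| + |R|)$. The sketch below computes it per tissue class from two label arrays and summarises the per-subject scores in box-and-whisker plots. The function names and the class-to-label mapping are hypothetical; check the label values against `segs_refs`. The random label maps only demonstrate the calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def dice_per_class(pred, ref, labels):\n",
    "    # pred, ref: integer label arrays of the same shape\n",
    "    scores = []\n",
    "    for lab in labels:\n",
    "        p = pred == lab\n",
    "        r = ref == lab\n",
    "        denom = p.sum() + r.sum()\n",
    "        # Convention: a class absent from both arrays counts as perfect agreement\n",
    "        scores.append(2.0 * np.logical_and(p, r).sum() / denom if denom > 0 else 1.0)\n",
    "    return scores\n",
    "\n",
    "def plot_dice_boxplots(dice_scores, class_names):\n",
    "    # dice_scores: array of shape (num_subjects, num_classes), one box per class\n",
    "    plt.boxplot(dice_scores)\n",
    "    plt.xticks(range(1, len(class_names) + 1), class_names)\n",
    "    plt.ylabel('Dice similarity coefficient')\n",
    "    plt.show()\n",
    "\n",
    "# Toy demonstration with random label maps (NOT real segmentations)\n",
    "rng = np.random.default_rng(0)\n",
    "class_labels = {'CSF': 1, 'GM': 2, 'WM': 3}  # ASSUMPTION: check against segs_refs\n",
    "scores = []\n",
    "for _ in range(10):\n",
    "    ref = rng.integers(0, 4, size=(16, 16, 16))\n",
    "    # Keep about 80% of the reference labels, randomise the rest\n",
    "    pred = np.where(rng.random(ref.shape) < 0.8, ref, rng.integers(0, 4, size=ref.shape))\n",
    "    scores.append(dice_per_class(pred, ref, class_labels.values()))\n",
    "plot_dice_boxplots(np.array(scores), list(class_labels.keys()))"
   ]
  },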
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from utils.data_helper import ImageSegmentationDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check that the GPU is up and running"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda:0\n",
      "GPU: Tesla K80\n"
     ]
    }
   ],
   "source": [
    "cuda_dev = '0' #GPU device 0 (can be changed if multiple GPUs are available)\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda:\" + cuda_dev if use_cuda else \"cpu\")\n",
    "\n",
    "print('Device: ' + str(device))\n",
    "if use_cuda:\n",
    "    print('GPU: ' + str(torch.cuda.get_device_name(int(cuda_dev))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Config and hyper-parameters\n",
    "\n",
    "Here we set some default hyper-parameters and a starting configuration for the image resolution and others.\n",
    "\n",
    "**This needs to be revisited to optimise these values. In particular, you may want to run your final model on higher resolution images.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [],
   "source": [
    "rnd_seed = 42 #fixed random seed\n",
    "\n",
    "img_size = [64, 64, 64]\n",
    "img_spacing = [3, 3, 3]\n",
    "\n",
    "learning_rate = 0.001\n",
    "batch_size = 2\n",
    "val_interval = 10\n",
    "\n",
    "num_classes = 4\n",
    "\n",
    "out_dir = './output'\n",
    "\n",
    "# Create output directory\n",
    "if not os.path.exists(out_dir):\n",
    "    os.makedirs(out_dir)"
   ]
  },
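  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When moving to higher-resolution images, it helps to keep the physical field of view fixed. Assuming the data helper resamples each image to `img_size` voxels of `img_spacing` (in mm), the covered extent per axis is size x spacing, i.e. 64 x 3 mm = 192 mm with the defaults. Keeping 192 mm at 2 mm spacing needs 96 voxels per axis. That is (96/64)^3 = 3.375 times as many voxels, with a roughly corresponding increase in activation memory per sample. The function names below are hypothetical; the sketch just makes this arithmetic explicit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def field_of_view_mm(size, spacing):\n",
    "    # Physical extent per axis covered by `size` voxels of the given spacing\n",
    "    return [n * s for n, s in zip(size, spacing)]\n",
    "\n",
    "def size_for_spacing(fov_mm, spacing):\n",
    "    # Voxels per axis needed to cover the same field of view at a new spacing\n",
    "    return [int(np.ceil(f / s)) for f, s in zip(fov_mm, spacing)]\n",
    "\n",
    "fov = field_of_view_mm(img_size, img_spacing)\n",
    "print('Field of view (mm):', fov)\n",
    "for new_spacing in ([2, 2, 2], [1.5, 1.5, 1.5]):\n",
    "    new_size = size_for_spacing(fov, new_spacing)\n",
    "    print(new_spacing, '->', new_size, '=', int(np.prod(new_size)), 'voxels per image')"
   ]
  },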
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading and pre-processing of training and validation data\n",
    "\n",
    "We apply some standard pre-processing on the data such as intensity normalization (zero mean unit variance) and downsampling according to the configuration above.\n",
    "\n",
    "**We provide a 'debug' csv file pointing to just a few images for training. Replace this with the full training dataset when you train your full model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LOADING TRAINING DATA...\n",
      "+ reading image msub-CC110319_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC110319.nii.gz\n",
      "+ reading mask sub-CC110319_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC120208_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC120208.nii.gz\n",
      "+ reading mask sub-CC120208_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC120462_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC120462.nii.gz\n",
      "+ reading mask sub-CC120462_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC121144_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC121144.nii.gz\n",
      "+ reading mask sub-CC121144_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC122405_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC122405.nii.gz\n",
      "+ reading mask sub-CC122405_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC210422_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC210422.nii.gz\n",
      "+ reading mask sub-CC210422_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC220203_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC220203.nii.gz\n",
      "+ reading mask sub-CC220203_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC220518_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC220518.nii.gz\n",
      "+ reading mask sub-CC220518_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC221220_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC221220.nii.gz\n",
      "+ reading mask sub-CC221220_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC221595_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC221595.nii.gz\n",
      "+ reading mask sub-CC221595_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC222120_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC222120.nii.gz\n",
      "+ reading mask sub-CC222120_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC222956_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC222956.nii.gz\n",
      "+ reading mask sub-CC222956_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC310203_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC310203.nii.gz\n",
      "+ reading mask sub-CC310203_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC310407_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC310407.nii.gz\n",
      "+ reading mask sub-CC310407_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC320089_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC320089.nii.gz\n",
      "+ reading mask sub-CC320089_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC320336_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC320336.nii.gz\n",
      "+ reading mask sub-CC320336_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC320574_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC320574.nii.gz\n",
      "+ reading mask sub-CC320574_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC321069_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC321069.nii.gz\n",
      "+ reading mask sub-CC321069_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC321428_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC321428.nii.gz\n",
      "+ reading mask sub-CC321428_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC321899_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC321899.nii.gz\n",
      "+ reading mask sub-CC321899_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC410113_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC410113.nii.gz\n",
      "+ reading mask sub-CC410113_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC410243_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC410243.nii.gz\n",
      "+ reading mask sub-CC410243_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC410432_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC410432.nii.gz\n",
      "+ reading mask sub-CC410432_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC420137_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC420137.nii.gz\n",
      "+ reading mask sub-CC420137_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC420202_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC420202.nii.gz\n",
      "+ reading mask sub-CC420202_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC420286_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC420286.nii.gz\n",
      "+ reading mask sub-CC420286_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC420888_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC420888.nii.gz\n",
      "+ reading mask sub-CC420888_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC510226_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC510226.nii.gz\n",
      "+ reading mask sub-CC510226_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC510329_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC510329.nii.gz\n",
      "+ reading mask sub-CC510329_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC510474_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC510474.nii.gz\n",
      "+ reading mask sub-CC510474_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC520002_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC520002.nii.gz\n",
      "+ reading mask sub-CC520002_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC520134_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC520134.nii.gz\n",
      "+ reading mask sub-CC520134_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC520253_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC520253.nii.gz\n",
      "+ reading mask sub-CC520253_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC520503_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC520503.nii.gz\n",
      "+ reading mask sub-CC520503_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC520775_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC520775.nii.gz\n",
      "+ reading mask sub-CC520775_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC610288_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC610288.nii.gz\n",
      "+ reading mask sub-CC610288_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC610575_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC610575.nii.gz\n",
      "+ reading mask sub-CC610575_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC620073_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC620073.nii.gz\n",
      "+ reading mask sub-CC620073_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC620262_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC620262.nii.gz\n",
      "+ reading mask sub-CC620262_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC620444_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC620444.nii.gz\n",
      "+ reading mask sub-CC620444_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC620557_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC620557.nii.gz\n",
      "+ reading mask sub-CC620557_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC620821_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC620821.nii.gz\n",
      "+ reading mask sub-CC620821_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC621642_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC621642.nii.gz\n",
      "+ reading mask sub-CC621642_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC710416_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC710416.nii.gz\n",
      "+ reading mask sub-CC710416_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC720103_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC720103.nii.gz\n",
      "+ reading mask sub-CC720103_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC720511_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC720511.nii.gz\n",
      "+ reading mask sub-CC720511_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC721291_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC721291.nii.gz\n",
      "+ reading mask sub-CC721291_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "\n",
      "LOADING VALIDATION DATA...\n",
      "+ reading image msub-CC220901_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC220901.nii.gz\n",
      "+ reading mask sub-CC220901_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC320698_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC320698.nii.gz\n",
      "+ reading mask sub-CC320698_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC420454_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC420454.nii.gz\n",
      "+ reading mask sub-CC420454_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC610058_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC610058.nii.gz\n",
      "+ reading mask sub-CC610058_T1w_rigid_to_mni_brain_mask.nii.gz\n",
      "+ reading image msub-CC710679_T1w_rigid_to_mni.nii.gz\n",
      "+ reading segmentation CC710679.nii.gz\n",
      "+ reading mask sub-CC710679_T1w_rigid_to_mni_brain_mask.nii.gz\n"
     ]
    }
   ],
   "source": [
    "# USE THIS FOR TRAINING ON ALL 47 SUBJECTS\n",
    "train_data = data_dir + 'train/csv/train.csv'\n",
    "\n",
    "# USE THIS FOR DEBUGGING WITH JUST 2 SUBJECTS\n",
    "#train_data = data_dir + 'train/csv/train_debug.csv'\n",
    "\n",
    "val_data = data_dir + 'train/csv/val.csv'\n",
    "\n",
    "print('LOADING TRAINING DATA...')\n",
    "dataset_train = ImageSegmentationDataset(train_data, img_spacing, img_size)\n",
    "dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "print('\\nLOADING VALIDATION DATA...')\n",
    "dataset_val = ImageSegmentationDataset(val_data, img_spacing, img_size)\n",
    "dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False)"
   ]
  },
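  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The intensity normalisation happens inside `ImageSegmentationDataset`, whose implementation is not shown here. For reference, one common variant computes the mean and standard deviation over brain voxels only, so that the large background region does not dominate the statistics. The function below is a hypothetical sketch of that idea with SimpleITK and NumPy. It does not describe what the provided helper actually does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "def zscore_in_mask(img, msk):\n",
    "    # Zero mean, unit variance using statistics from brain voxels only\n",
    "    arr = sitk.GetArrayFromImage(img).astype(np.float32)\n",
    "    brain = sitk.GetArrayFromImage(msk) > 0\n",
    "    mean, std = arr[brain].mean(), arr[brain].std()\n",
    "    out = sitk.GetImageFromArray((arr - mean) / max(std, 1e-8))\n",
    "    out.CopyInformation(img)  # keep spacing, origin and direction\n",
    "    return out\n",
    "\n",
    "# Toy check on a random volume (NOT real data): brain voxels end up with\n",
    "# mean close to 0 and standard deviation close to 1\n",
    "arr = np.random.rand(8, 8, 8).astype(np.float32) * 100\n",
    "toy_img = sitk.GetImageFromArray(arr)\n",
    "toy_msk = sitk.GetImageFromArray((arr > 20).astype(np.uint8))\n",
    "norm = sitk.GetArrayFromImage(zscore_in_mask(toy_img, toy_msk))\n",
    "print(norm[arr > 20].mean(), norm[arr > 20].std())"
   ]
  },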
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Visualise training example\n",
    "\n",
    "A quick check of what a training image looks like after pre-processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image: msub-CC110319_T1w_rigid_to_mni.nii.gz\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Segmentation: CC110319.nii.gz\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mask\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sample = dataset_train.get_sample(0)\n",
    "img_name = dataset_train.get_img_name(0)\n",
    "seg_name = dataset_train.get_seg_name(0)\n",
    "print('Image: ' + img_name)\n",
    "display_image(sample['img'], window=5, level=0)\n",
    "print('Segmentation: ' + seg_name)\n",
    "display_image(sitk.LabelToRGB(sample['seg']))\n",
    "print('Mask')\n",
    "display_image(sample['msk'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The Model\n",
    "\n",
    "This is the **key part of task A-1**, where you have to design a suitable CNN model for brain segmentation. The simple model provided below works to some degree (it lets you run through the upcoming cells), but it will not perform very well. Use what you learned in the lectures to come up with a good architecture. Start with a simple, shallow model and only increase complexity (e.g., number of layers) if needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleNet3D(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes):\n",
    "        super(SimpleNet3D, self).__init__()\n",
    "        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)\n",
    "        self.conv3 = nn.Conv3d(32, 16, kernel_size=3, padding=1)\n",
    "        self.conv4 = nn.Conv3d(16, num_classes, kernel_size=3, padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = self.conv4(x)\n",
    "        return F.softmax(x, dim=1)"
   ]
  },
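  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common way to go beyond `SimpleNet3D` is an encoder-decoder with skip connections in the spirit of U-Net. Pooling enlarges the receptive field, while the skip connections preserve the fine spatial detail needed for voxel-wise labels. The sketch below is a hypothetical two-level variant, not a tuned or recommended architecture. It returns softmax probabilities to keep the same output interface as `SimpleNet3D`, on the assumption that the training code expects that. The input size must be divisible by 4 (e.g. 64^3)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class ConvBlock3D(nn.Module):\n",
    "    # Two 3x3x3 convolutions, each followed by batch norm and ReLU\n",
    "    def __init__(self, in_ch, out_ch):\n",
    "        super().__init__()\n",
    "        self.block = nn.Sequential(\n",
    "            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm3d(out_ch),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),\n",
    "            nn.BatchNorm3d(out_ch),\n",
    "            nn.ReLU(inplace=True))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.block(x)\n",
    "\n",
    "class SmallUNet3D(nn.Module):\n",
    "    # Two downsampling levels with skip connections (U-Net style)\n",
    "    def __init__(self, num_classes, base=16):\n",
    "        super().__init__()\n",
    "        self.enc1 = ConvBlock3D(1, base)\n",
    "        self.enc2 = ConvBlock3D(base, 2 * base)\n",
    "        self.bottleneck = ConvBlock3D(2 * base, 4 * base)\n",
    "        self.up2 = nn.ConvTranspose3d(4 * base, 2 * base, kernel_size=2, stride=2)\n",
    "        self.dec2 = ConvBlock3D(4 * base, 2 * base)\n",
    "        self.up1 = nn.ConvTranspose3d(2 * base, base, kernel_size=2, stride=2)\n",
    "        self.dec1 = ConvBlock3D(2 * base, base)\n",
    "        self.out = nn.Conv3d(base, num_classes, kernel_size=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        e1 = self.enc1(x)                          # full resolution\n",
    "        e2 = self.enc2(F.max_pool3d(e1, 2))        # 1/2 resolution\n",
    "        b = self.bottleneck(F.max_pool3d(e2, 2))   # 1/4 resolution\n",
    "        # Upsample, concatenate the matching encoder features, refine\n",
    "        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))\n",
    "        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))\n",
    "        # Softmax over classes, matching the output of SimpleNet3D\n",
    "        return F.softmax(self.out(d1), dim=1)\n",
    "\n",
    "# Shape check with a small random input (batch 1, one channel, 32^3 voxels)\n",
    "net = SmallUNet3D(num_classes=4)\n",
    "print(net(torch.randn(1, 1, 32, 32, 32)).shape)  # torch.Size([1, 4, 32, 32, 32])"
   ]
  },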
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [],
   "source": [
    "def passthrough(x, **kwargs):\n",
    "    return x\n",
    "\n",
    "def ELUCons(elu, nchan):\n",
    "    if elu:\n",
    "        return nn.ELU(inplace=True)\n",
    "    else:\n",
    "        return nn.PReLU(nchan)\n",
    "\n",
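    "# Batch norm that always normalises with the statistics of the current batch:\n",
    "# training=True is passed to F.batch_norm even in eval mode (running statistics\n",
    "# are updated but never used), as normalisation between sub-volumes is necessary\n",
    "# for good performance\n",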
    "class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):\n",
    "    def _check_input_dim(self, x):\n",
    "        if x.dim() != 5:\n",
    "            raise ValueError('expected 5D input (got {}D input)'\n",
    "                             .format(x.dim()))\n",
    "        #super(ContBatchNorm3d, self)._check_input_dim(x)\n",
    "\n",
    "    def forward(self, x):\n",
    "        self._check_input_dim(x)\n",
    "        return F.batch_norm(\n",
    "            x, self.running_mean, self.running_var, self.weight, self.bias,\n",
    "            True, self.momentum, self.eps)\n",
    "\n",
    "\n",
    "class LUConv(nn.Module):\n",
    "    def __init__(self, nchan, elu):\n",
    "        super(LUConv, self).__init__()\n",
    "        self.relu1 = ELUCons(elu, nchan)\n",
    "        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)\n",
    "        self.bn1 = ContBatchNorm3d(nchan)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.relu1(self.bn1(self.conv1(x)))\n",
    "        return out\n",
    "\n",
    "\n",
    "def _make_nConv(nchan, depth, elu):\n",
    "    layers = []\n",
    "    for _ in range(depth):\n",
    "        layers.append(LUConv(nchan, elu))\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "class InputTransition(nn.Module):\n",
    "    def __init__(self, outChans, elu):\n",
    "        super(InputTransition, self).__init__()\n",
    "        self.outChans = outChans\n",
    "        self.conv1 = nn.Conv3d(1, outChans, kernel_size=5, padding=2)\n",
    "        self.bn1 = ContBatchNorm3d(outChans)\n",
    "        self.relu1 = ELUCons(elu, outChans)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.bn1(self.conv1(x))\n",
    "        # residual connection: repeat the single-channel input outChans times\n",
    "        # along the channel axis so it can be added to the convolution output\n",
    "        x_rep = x.repeat(1, self.outChans, 1, 1, 1)\n",
    "        out = self.relu1(torch.add(out, x_rep))\n",
    "        return out\n",
    "\n",
    "\n",
    "class DownTransition(nn.Module):\n",
    "    def __init__(self, inChans, nConvs, elu, dropout=False):\n",
    "        super(DownTransition, self).__init__()\n",
    "        outChans = 2*inChans\n",
    "        # 2x2x2 convolution with stride 2: halves each spatial axis and doubles the channels\n",
    "        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)\n",
    "        self.bn1 = ContBatchNorm3d(outChans)\n",
    "        self.do1 = passthrough\n",
    "        self.relu1 = ELUCons(elu, outChans)\n",
    "        self.relu2 = ELUCons(elu, outChans)\n",
    "        if dropout:\n",
    "            self.do1 = nn.Dropout3d()\n",
    "        self.ops = _make_nConv(outChans, nConvs, elu)\n",
    "\n",
    "    def forward(self, x):\n",
    "        down = self.relu1(self.bn1(self.down_conv(x)))\n",
    "        out = self.do1(down)\n",
    "        out = self.ops(out)\n",
    "        out = self.relu2(torch.add(out, down))\n",
    "        return out\n",
    "\n",
    "\n",
    "class UpTransition(nn.Module):\n",
    "    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):\n",
    "        super(UpTransition, self).__init__()\n",
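    "        # transposed convolution doubles each spatial axis; concatenated with a skip\n",
    "        # connection of outChans // 2 channels, this gives outChans channels for self.ops\n",
    "        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)\n",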
    "        self.bn1 = ContBatchNorm3d(outChans // 2)\n",
    "        self.do1 = passthrough\n",
    "        self.do2 = nn.Dropout3d()\n",
    "        self.relu1 = ELUCons(elu, outChans // 2)\n",
    "        self.relu2 = ELUCons(elu, outChans)\n",
    "        if dropout:\n",
    "            self.do1 = nn.Dropout3d()\n",
    "        self.ops = _make_nConv(outChans, nConvs, elu)\n",
    "\n",
    "    def forward(self, x, skipx):\n",
    "        out = self.do1(x)\n",
    "        skipxdo = self.do2(skipx)\n",
    "        out = self.relu1(self.bn1(self.up_conv(out)))\n",
    "        xcat = torch.cat((out, skipxdo), 1)\n",
    "        out = self.ops(xcat)\n",
    "        out = self.relu2(torch.add(out, xcat))\n",
    "        return out\n",
    "\n",
    "\n",
    "class OutputTransition(nn.Module):\n",
    "    def __init__(self, inChans, elu, num_classes):\n",
    "        super(OutputTransition, self).__init__()\n",
    "        self.conv1 = nn.Conv3d(inChans, num_classes, kernel_size=5, padding=2)\n",
    "        self.bn1 = ContBatchNorm3d(num_classes)\n",
    "        self.conv2 = nn.Conv3d(num_classes, num_classes, kernel_size=1)\n",
    "        self.relu1 = ELUCons(elu, num_classes)\n",
    "        self.softmax = F.softmax\n",
    "\n",
    "    def forward(self, x):\n",
    "        # map the inChans feature channels to num_classes channels\n",
    "        out = self.relu1(self.bn1(self.conv1(x)))\n",
    "        out = self.conv2(out)\n",
    "        # softmax over the class axis: per-voxel class probabilities of shape\n",
    "        # (batch, num_classes, depth, height, width)\n",
    "        out = self.softmax(out, dim=1)\n",
    "        return out\n",
    "\n",
    "\n",
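    "class SimpleNet3D(nn.Module):\n",
    "    # V-Net-style encoder-decoder. It reuses the name SimpleNet3D, so it replaces the\n",
    "    # baseline defined above in the training code. Four stride-2 down transitions mean\n",
    "    # every spatial input dimension must be divisible by 16.\n",
    "    # The number of convolutions in each stage corresponds to the actual prototxt, not the intent.\n",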
    "    def __init__(self, num_classes, elu=True):\n",
    "        super(SimpleNet3D, self).__init__()\n",
    "        self.in_tr = InputTransition(16, elu)\n",
    "        self.down_tr32 = DownTransition(16, 1, elu)\n",
    "        self.down_tr64 = DownTransition(32, 2, elu)\n",
    "        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)\n",
    "        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)\n",
    "        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)\n",
    "        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)\n",
    "        self.up_tr64 = UpTransition(128, 64, 1, elu)\n",
    "        self.up_tr32 = UpTransition(64, 32, 1, elu)\n",
    "        self.out_tr = OutputTransition(32, elu, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out16 = self.in_tr(x)\n",
    "        out32 = self.down_tr32(out16)\n",
    "        out64 = self.down_tr64(out32)\n",
    "        out128 = self.down_tr128(out64)\n",
    "        out256 = self.down_tr256(out128)\n",
    "        out = self.up_tr256(out256, out128)\n",
    "        out = self.up_tr128(out, out64)\n",
    "        out = self.up_tr64(out, out32)\n",
    "        out = self.up_tr32(out, out16)\n",
    "        out = self.out_tr(out)\n",
    "        return out"
   ]
  },
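The V-Net-style model above has four `DownTransition` stages, each a `kernel_size=2, stride=2` convolution that halves every spatial axis (rounding down), while each `UpTransition` doubles the size again before `torch.cat` with the skip connection. If an axis is not divisible by 2**4 = 16, the upsampled tensor and the skip connection end up with different sizes and the concatenation fails. The plain-Python helpers below (hypothetical names, not part of the coursework code) list the spatial shape at every level and compute the smallest padded shape that satisfies the constraint; the example input shape is made up. In PyTorch the padding could then be applied with `torch.nn.functional.pad`, whose `pad` argument lists the last dimension first.

```python
import math


def stage_shapes(spatial_shape, num_down=4):
    """Spatial shape at each resolution level when every axis is halved num_down times."""
    shapes = [tuple(spatial_shape)]
    for _ in range(num_down):
        # a kernel_size=2, stride=2 convolution without padding floors odd sizes
        shapes.append(tuple(d // 2 for d in shapes[-1]))
    return shapes


def is_compatible(spatial_shape, num_down=4):
    """True if upsampling by 2 at every level restores the skip-connection sizes exactly."""
    return all(d % (2 ** num_down) == 0 for d in spatial_shape)


def pad_to_multiple(spatial_shape, num_down=4):
    """Smallest shape >= spatial_shape divisible by 2**num_down, plus the padding per axis."""
    m = 2 ** num_down
    padded = tuple(math.ceil(d / m) * m for d in spatial_shape)
    padding = tuple(p - d for p, d in zip(padded, spatial_shape))
    return padded, padding


shape = (70, 90, 80)  # hypothetical (depth, height, width)
print(is_compatible(shape))  # False
padded, padding = pad_to_multiple(shape)
print(padded, padding)       # (80, 96, 80) (10, 6, 0)
print(stage_shapes(padded))  # [(80, 96, 80), (40, 48, 40), (20, 24, 20), (10, 12, 10), (5, 6, 5)]
```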
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### TRAINING\n",
    "\n",
    "Below is an implementation of a full training procedure including a loop for intermediate evaluation of the model on the validation data. Feel free to modify this procedure. For example, in addition to the loss you may want to monitor precision, recall and Dice scores (or others)."
   ]
  },
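For monitoring precision, recall and Dice scores alongside the loss, a minimal per-class sketch in NumPy follows. It assumes integer label maps of identical shape; for the models above, a label map can be obtained from the softmax output with `torch.argmax(prd, dim=1)` and moved to NumPy with `.cpu().numpy()`. The function name `per_class_metrics` is hypothetical, and reporting 0 for a class absent from both maps is a convention chosen for this sketch (some pipelines report 1 or skip the class instead).

```python
import numpy as np


def per_class_metrics(pred, target, num_classes, eps=1e-8):
    """Dice, precision and recall for each class of two integer label maps.

    pred, target: arrays of identical shape holding class indices 0..num_classes-1.
    Returns {class_index: {"dice": ..., "precision": ..., "recall": ...}}.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    metrics = {}
    for c in range(num_classes):
        p, t = pred == c, target == c
        tp = np.count_nonzero(p & t)   # voxels correctly labelled c
        fp = np.count_nonzero(p & ~t)  # voxels wrongly labelled c
        fn = np.count_nonzero(~p & t)  # voxels of class c that were missed
        metrics[c] = {
            "dice": 2 * tp / (2 * tp + fp + fn + eps),
            "precision": tp / (tp + fp + eps),
            "recall": tp / (tp + fn + eps),
        }
    return metrics


pred = np.array([[0, 1], [1, 2]])
target = np.array([[0, 1], [2, 2]])
for c, m in per_class_metrics(pred, target, num_classes=3).items():
    print(c, {k: round(v, 3) for k, v in m.items()})
# 0 {'dice': 1.0, 'precision': 1.0, 'recall': 1.0}
# 1 {'dice': 0.667, 'precision': 0.5, 'recall': 1.0}
# 2 {'dice': 0.667, 'precision': 1.0, 'recall': 0.5}
```

Averaging the per-class Dice over the foreground classes gives a single number that can be logged next to the validation loss.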
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "START TRAINING...\n",
      "+ TRAINING \tEpoch: 1 \tLoss: 1.382734\n",
      "--------------------------------------------------\n",
      "+ VALIDATE \tEpoch: 1 \tLoss: 1.378221\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 720x288 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------\n",
      "+ TRAINING \tEpoch: 2 \tLoss: 1.363415\n",
      "+ TRAINING \tEpoch: 3 \tLoss: 1.326473\n",
      "+ TRAINING \tEpoch: 4 \tLoss: 1.306049\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-181-ec987b20196a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0;31m# Training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_samples\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataloader_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m         \u001b[0mimg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mseg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'img'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_samples\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'seg'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m         \u001b[0mprd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model_dir = os.path.join(out_dir, 'model')\n",
    "if not os.path.exists(model_dir):\n",
    "    os.makedirs(model_dir)\n",
    "\n",
    "torch.manual_seed(rnd_seed) #fix random seed\n",
    "\n",
    "model = SimpleNet3D(num_classes).to(device)\n",
    "model.train()\n",
    "    \n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "loss_train_log = []\n",
    "loss_val_log = []\n",
    "epoch_val_log = []\n",
    "    \n",