Skip to content

Commit

Permalink
Add instructions and script on how to use the ventral rootlets model (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
valosekj authored May 23, 2024
1 parent cb1dfeb commit 97a6370
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 0 deletions.
69 changes: 69 additions & 0 deletions packaging_ventral_rootlets/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
## Getting started

⚠️ This README provides instructions on how to use the model for **_ventral_** and dorsal rootlets.
Please note that this model is still under development and is not yet available in the Spinal Cord Toolbox (SCT).

⚠️ For the stable model for dorsal rootlets only, use SCT v6.2 or higher (please refer to this [README](..%2FREADME.md)).

### Dependencies

- [Spinal Cord Toolbox (SCT)](https://spinalcordtoolbox.com/user_section/installation.html)
- [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
- Python

### Step 1: Cloning the Repository

Open a terminal and clone the repository using the following command:

```
git clone https://github.com/ivadomed/model-spinal-rootlets
```

### Step 2: Setting up the Environment

The following commands show how to set up the environment.
Note that the documentation assumes that the user has `conda` installed on their system.
Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).

1. Create a conda environment with the following command:
```
conda create -n venv_nnunet python=3.9
```

2. Activate the environment with the following command:
```
conda activate venv_nnunet
```

3. Install the required packages with the following command:
```
cd model-spinal-rootlets
pip install -r packaging_ventral_rootlets/requirements.txt
```

### Step 3: Getting the Predictions

ℹ️ To temporarily suppress warnings raised by the nnUNet, you can run the following three commands in the same terminal session as the above command:

```bash
export nnUNet_raw="${HOME}/nnUNet_raw"
export nnUNet_preprocessed="${HOME}/nnUNet_preprocessed"
export nnUNet_results="${HOME}/nnUNet_results"
```

To segment a single image using the trained model, run the following command from the terminal.

This assumes that the latest model has been downloaded (https://github.com/ivadomed/model-spinal-rootlets/releases/download/r20240523/model-spinal-rootlets_ventral_D106_r20240523.zip)
and unzipped (`unzip model-spinal-rootlets_ventral_D106_r20240523.zip`).

```bash
python packaging_ventral_rootlets/run_inference_single_subject.py -i <INPUT> -o <OUTPUT> -path-model <PATH_TO_MODEL_FOLDER>
```

For example:

```bash
python packaging_ventral_rootlets/run_inference_single_subject.py -i sub-001_T2w.nii.gz -o sub-001_T2w_label-rootlets_dseg.nii.gz -path-model ~/Downloads/model-spinal-rootlets_ventral_D106_r20240523 -fold all
```

ℹ️ The script also supports getting segmentations on a GPU. To do so, simply add the flag `--use-gpu` at the end of the above commands. By default, the inference is run on the CPU. It is useful to note that obtaining the predictions from the GPU is significantly faster than the CPU.
4 changes: 4 additions & 0 deletions packaging_ventral_rootlets/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
numpy
nibabel
nnunetv2==2.2.1
torch==2.0.1
224 changes: 224 additions & 0 deletions packaging_ventral_rootlets/run_inference_single_subject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""
This script is used to run inference on a single subject using a nnUNetV2 model.
Note: conda environment with nnUNetV2 is required to run this script.
For details how to install nnUNetV2, see:
https://github.com/ivadomed/utilities/blob/main/quick_start_guides/nnU-Net_quick_start_guide.md#installation
Author: Jan Valosek
Example:
python run_inference_single_subject.py
-i sub-001_T2w.nii.gz
-o sub-001_T2w_label-rootlet.nii.gz
-path-model <PATH_TO_MODEL_FOLDER>
-tile-step-size 0.5
-fold 1
"""


import os
import shutil
import subprocess
import argparse
import datetime

import torch
import glob
import time
import tempfile

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from batchgenerators.utilities.file_and_folder_operations import join


def get_parser():
# parse command line arguments
parser = argparse.ArgumentParser(description='Segment an image using nnUNet model.')
parser.add_argument('-i', help='Input image to segment. Example: sub-001_T2w.nii.gz', required=True)
parser.add_argument('-o', help='Output filename. Example: sub-001_T2w_label-rootlet.nii.gz', required=True)
parser.add_argument('-path-model', help='Path to the model folder. This folder should contain individual '
'folders like fold_0, fold_1, etc. and dataset.json, '
'dataset_fingerprint.json and plans.json files.', required=True, type=str)
parser.add_argument('-use-gpu', action='store_true', default=False,
help='Use GPU for inference. Default: False')
parser.add_argument('-fold', type=str, required=True,
help='Fold(s) to use for inference. Example(s): 2 (single fold), 2,3 (multiple folds), '
'all (fold_all).', choices=['0', '1', '2', '3', '4', 'all'])
parser.add_argument('-use-best-checkpoint', action='store_true', default=False,
help='Use the best checkpoint (instead of the final checkpoint) for prediction. '
'NOTE: nnUNet by default uses the final checkpoint. Default: False')
parser.add_argument('-tile-step-size', default=0.5, type=float,
help='Tile step size defining the overlap between images patches during inference. '
'Default: 0.5 '
'NOTE: changing it from 0.5 to 0.9 makes inference faster but there is a small drop in '
'performance.')

return parser


def get_orientation(file):
"""
Get the original orientation of an image
:param file: path to the image
:return: orig_orientation: original orientation of the image, e.g. LPI
"""

# Fetch the original orientation from the output of sct_image
sct_command = "sct_image -i {} -header | grep -E qform_[xyz] | awk '{{printf \"%s\", substr($2, 1, 1)}}'".format(
file)
orig_orientation = subprocess.check_output(sct_command, shell=True).decode('utf-8')
return orig_orientation


def tmp_create():
"""
Create temporary folder and return its path
"""
prefix = f"sciseg_prediction_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_"
tmpdir = tempfile.mkdtemp(prefix=prefix)
print(f"Creating temporary folder ({tmpdir})")
return tmpdir


def splitext(fname):
"""
Split a fname (folder/file + ext) into a folder/file and extension.
Note: for .nii.gz the extension is understandably .nii.gz, not .gz
(``os.path.splitext()`` would want to do the latter, hence the special case).
Taken (shamelessly) from: https://github.com/spinalcordtoolbox/manual-correction/blob/main/utils.py
"""
dir, filename = os.path.split(fname)
for special_ext in ['.nii.gz', '.tar.gz']:
if filename.endswith(special_ext):
stem, ext = filename[:-len(special_ext)], special_ext
return os.path.join(dir, stem), ext
# If no special case, behaves like the regular splitext
stem, ext = os.path.splitext(filename)
return os.path.join(dir, stem), ext


def add_suffix(fname, suffix):
"""
Add suffix between end of file name and extension. Taken (shamelessly) from:
https://github.com/spinalcordtoolbox/manual-correction/blob/main/utils.py
:param fname: absolute or relative file name. Example: t2.nii.gz
:param suffix: suffix. Example: _mean
:return: file name with suffix. Example: t2_mean.nii
Examples:
- add_suffix(t2.nii, _mean) -> t2_mean.nii
- add_suffix(t2.nii.gz, a) -> t2a.nii.gz
"""
stem, ext = splitext(fname)
return os.path.join(stem + suffix + ext)


def main():
parser = get_parser()
args = parser.parse_args()

fname_file = args.i
fname_file_out = args.o
print(f'\nFound {fname_file} file.')

# Create temporary directory in the temp to store the reoriented images
tmpdir = tmp_create()
# Copy the file to the temporary directory using shutil.copyfile
fname_file_tmp = os.path.join(tmpdir, os.path.basename(fname_file))
shutil.copyfile(fname_file, fname_file_tmp)
print(f'Copied {fname_file} to {fname_file_tmp}')

# Get the original orientation of the image, for example LPI
orig_orientation = get_orientation(fname_file_tmp)

# Reorient the image to LPI orientation if not already in LPI
if orig_orientation != 'LPI':
print(f'Original orientation: {orig_orientation}')
print(f'Reorienting to LPI orientation...')
# reorient the image to LPI using SCT
os.system('sct_image -i {} -setorient LPI -o {}'.format(fname_file_tmp, fname_file_tmp))

# NOTE: for individual images, the _0000 suffix is not needed.
# BUT, the images should be in a list of lists
fname_file_tmp_list = [[fname_file_tmp]]

# Use fold_all (all train/val subjects were used for training) or specific fold(s)
folds_avail = 'all' if args.fold == 'all' else [int(f) for f in args.fold.split(',')]
print(f'Using fold(s): {folds_avail}')

# Create directory for nnUNet prediction
tmpdir_nnunet = os.path.join(tmpdir, 'nnUNet_prediction')
fname_prediction = os.path.join(tmpdir_nnunet, os.path.basename(add_suffix(fname_file_tmp, '_pred')))
os.mkdir(tmpdir_nnunet)

# Run nnUNet prediction
print('Starting inference...it may take a few minutes...\n')
start = time.time()
# directly call the predict function
predictor = nnUNetPredictor(
tile_step_size=args.tile_step_size, # changing it from 0.5 to 0.9 makes inference faster
use_gaussian=True, # applies gaussian noise and gaussian blur
use_mirroring=False, # test time augmentation by mirroring on all axes
perform_everything_on_gpu=True if args.use_gpu else False,
device=torch.device('cuda') if args.use_gpu else torch.device('cpu'),
verbose_preprocessing=False,
allow_tqdm=True
)

print('Running inference on device: {}'.format(predictor.device))

# initializes the network architecture, loads the checkpoint
predictor.initialize_from_trained_model_folder(
join(args.path_model),
use_folds=folds_avail,
checkpoint_name='checkpoint_final.pth' if not args.use_best_checkpoint else 'checkpoint_best.pth',
)
print('Model loaded successfully. Fetching data...')

# NOTE: for individual files, the image should be in a list of lists
predictor.predict_from_files(
list_of_lists_or_source_folder=fname_file_tmp_list,
output_folder_or_list_of_truncated_output_files=tmpdir_nnunet,
save_probabilities=False,
overwrite=True,
num_processes_preprocessing=4,
num_processes_segmentation_export=4,
folder_with_segs_from_prev_stage=None,
num_parts=1,
part_id=0
)

end = time.time()

print('Inference done.')
total_time = end - start
print('Total inference time: {} minute(s) {} seconds\n'.format(int(total_time // 60), int(round(total_time % 60))))

# Copy .nii.gz file from tmpdir_nnunet to tmpdir
pred_file = glob.glob(os.path.join(tmpdir_nnunet, '*.nii.gz'))[0]
shutil.copyfile(pred_file, fname_prediction)
print(f'Copied {pred_file} to {fname_prediction}')

# Reorient the image back to original orientation
# skip if already in LPI
if orig_orientation != 'LPI':
print(f'Reorienting to original orientation {orig_orientation}...')
# reorient the image to the original orientation using SCT
os.system('sct_image -i {} -setorient {} -o {}'.format(fname_prediction, orig_orientation, fname_prediction))

# Copy level-specific (i.e., non-binary) segmentation
shutil.copyfile(fname_prediction, fname_file_out)
print(f'Copied {fname_prediction} to {fname_file_out}')

print('Deleting the temporary folder...')
# Delete the temporary folder
shutil.rmtree(tmpdir)

print('-' * 50)
print(f"Input file: {fname_file}")
print(f"Rootlet segmentation: {fname_file_out}")
print('-' * 50)


if __name__ == '__main__':
main()

0 comments on commit 97a6370

Please sign in to comment.