# Add instructions and script on how to use the ventral rootlets model (#47)
Showing 3 changed files with 297 additions and 0 deletions.
## Getting started

⚠️ This README provides instructions on how to use the model for **_ventral_** and dorsal rootlets.
Please note that this model is still under development and is not yet available in the Spinal Cord Toolbox (SCT).

⚠️ For the stable model for dorsal rootlets only, use SCT v6.2 or higher (please refer to this [README](../README.md)).

### Dependencies

- [Spinal Cord Toolbox (SCT)](https://spinalcordtoolbox.com/user_section/installation.html)
- [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
- Python
### Step 1: Cloning the Repository

Open a terminal and clone the repository using the following command:

```
git clone https://github.com/ivadomed/model-spinal-rootlets
```
### Step 2: Setting up the Environment

The following commands show how to set up the environment.
Note that the documentation assumes that the user has `conda` installed on their system.
Instructions on installing `conda` can be found [here](https://conda.io/projects/conda/en/latest/user-guide/install/index.html).

1. Create a conda environment with the following command:
```
conda create -n venv_nnunet python=3.9
```

2. Activate the environment with the following command:
```
conda activate venv_nnunet
```

3. Install the required packages with the following command:
```
cd model-spinal-rootlets
pip install -r packaging_ventral_rootlets/requirements.txt
```
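After step 3, you can quickly confirm that the packages from `packaging_ventral_rootlets/requirements.txt` are importable in the active environment. A minimal sketch (the helper name `missing_packages` is illustrative, not part of the repository):

```python
import importlib.util

def missing_packages(names=("numpy", "nibabel", "torch", "nnunetv2")):
    """Return the subset of `names` that cannot be imported in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages()
if missing:
    print(f"Missing packages: {', '.join(missing)} -- re-run the pip install step")
else:
    print("All required packages are importable.")
```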
### Step 3: Getting the Predictions

ℹ️ To temporarily suppress warnings raised by nnUNet, run the following three commands in the same terminal session before running the inference command below:

```bash
export nnUNet_raw="${HOME}/nnUNet_raw"
export nnUNet_preprocessed="${HOME}/nnUNet_preprocessed"
export nnUNet_results="${HOME}/nnUNet_results"
```
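If you drive nnUNet from Python rather than the shell, the same three variables can be set with `os.environ` before `nnunetv2` is imported (nnUNet reads them when its paths module is imported). A sketch of the equivalent of the exports above:

```python
import os

# Equivalent of the three `export` commands; must run before importing nnunetv2,
# which reads these variables at import time. setdefault() keeps any values
# already exported in the shell.
for var in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results"):
    os.environ.setdefault(var, os.path.join(os.path.expanduser("~"), var))

print(os.environ["nnUNet_raw"])
```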
To segment a single image using the trained model, run the following command from the terminal.

This assumes that the latest model has been downloaded (https://github.com/ivadomed/model-spinal-rootlets/releases/download/r20240523/model-spinal-rootlets_ventral_D106_r20240523.zip)
and unzipped (`unzip model-spinal-rootlets_ventral_D106_r20240523.zip`).

```bash
python packaging_ventral_rootlets/run_inference_single_subject.py -i <INPUT> -o <OUTPUT> -path-model <PATH_TO_MODEL_FOLDER> -fold <FOLD>
```

Note that `-fold` is required by the script.
For example:

```bash
python packaging_ventral_rootlets/run_inference_single_subject.py -i sub-001_T2w.nii.gz -o sub-001_T2w_label-rootlets_dseg.nii.gz -path-model ~/Downloads/model-spinal-rootlets_ventral_D106_r20240523 -fold all
```
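The download-and-unzip step can also be scripted with the Python standard library. A sketch (the URL is the release asset quoted above; `download_and_extract` is an illustrative helper, not part of the repository):

```python
import os
import urllib.request
import zipfile

MODEL_URL = ("https://github.com/ivadomed/model-spinal-rootlets/releases/download/"
             "r20240523/model-spinal-rootlets_ventral_D106_r20240523.zip")

def download_and_extract(url=MODEL_URL, dest_dir="."):
    """Download the model archive (if not already present) and extract it into
    `dest_dir` -- the stdlib equivalent of `wget` + `unzip`."""
    zip_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    return zip_path
```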
ℹ️ The script also supports running inference on a GPU. To do so, simply add the flag `-use-gpu` at the end of the above commands. By default, inference runs on the CPU; running on a GPU is significantly faster.
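To process several subjects, the single-subject script can be invoked in a loop. A hedged sketch that only builds the command lines (pass each to `subprocess.run` to execute; `build_commands` is an illustrative helper, and the output naming follows the example above):

```python
import glob
import os

def build_commands(input_dir, model_dir, fold="all"):
    """Build one inference command per *_T2w.nii.gz image in `input_dir` (dry run)."""
    commands = []
    for image in sorted(glob.glob(os.path.join(input_dir, "*_T2w.nii.gz"))):
        output = image.replace("_T2w.nii.gz", "_T2w_label-rootlets_dseg.nii.gz")
        commands.append([
            "python", "packaging_ventral_rootlets/run_inference_single_subject.py",
            "-i", image,
            "-o", output,
            "-path-model", model_dir,
            "-fold", fold,
        ])
    return commands
```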
### packaging_ventral_rootlets/requirements.txt (4 additions)

```
numpy
nibabel
nnunetv2==2.2.1
torch==2.0.1
```
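The version pins above can be checked at runtime with `importlib.metadata`. A minimal sketch (`installed_version` is an illustrative helper, not part of the repository):

```python
from importlib import metadata

def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

# e.g. compare installed_version("torch") against the "2.0.1" pin before running inference
```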
### packaging_ventral_rootlets/run_inference_single_subject.py (224 additions)
"""
This script is used to run inference on a single subject using a nnUNetV2 model.
Note: a conda environment with nnUNetV2 is required to run this script.
For details on how to install nnUNetV2, see:
https://github.com/ivadomed/utilities/blob/main/quick_start_guides/nnU-Net_quick_start_guide.md#installation
Author: Jan Valosek
Example:
    python run_inference_single_subject.py
        -i sub-001_T2w.nii.gz
        -o sub-001_T2w_label-rootlet.nii.gz
        -path-model <PATH_TO_MODEL_FOLDER>
        -tile-step-size 0.5
        -fold 1
"""

import os
import shutil
import subprocess
import argparse
import datetime
import glob
import time
import tempfile

import torch

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from batchgenerators.utilities.file_and_folder_operations import join
def get_parser():
    # parse command line arguments
    parser = argparse.ArgumentParser(description='Segment an image using nnUNet model.')
    parser.add_argument('-i', help='Input image to segment. Example: sub-001_T2w.nii.gz', required=True)
    parser.add_argument('-o', help='Output filename. Example: sub-001_T2w_label-rootlet.nii.gz', required=True)
    parser.add_argument('-path-model', help='Path to the model folder. This folder should contain individual '
                                            'folders like fold_0, fold_1, etc. and dataset.json, '
                                            'dataset_fingerprint.json and plans.json files.', required=True, type=str)
    parser.add_argument('-use-gpu', action='store_true', default=False,
                        help='Use GPU for inference. Default: False')
    # NOTE: no `choices` list here -- comma-separated values such as '2,3' must be accepted,
    # and a `choices` list of single folds would reject them
    parser.add_argument('-fold', type=str, required=True,
                        help='Fold(s) to use for inference. Example(s): 2 (single fold), 2,3 (multiple folds), '
                             'all (fold_all).')
    parser.add_argument('-use-best-checkpoint', action='store_true', default=False,
                        help='Use the best checkpoint (instead of the final checkpoint) for prediction. '
                             'NOTE: nnUNet by default uses the final checkpoint. Default: False')
    parser.add_argument('-tile-step-size', default=0.5, type=float,
                        help='Tile step size defining the overlap between images patches during inference. '
                             'Default: 0.5 '
                             'NOTE: changing it from 0.5 to 0.9 makes inference faster but there is a small drop in '
                             'performance.')

    return parser
def get_orientation(file):
    """
    Get the original orientation of an image
    :param file: path to the image
    :return: orig_orientation: original orientation of the image, e.g. LPI
    """
    # Fetch the original orientation from the output of sct_image
    sct_command = "sct_image -i {} -header | grep -E qform_[xyz] | awk '{{printf \"%s\", substr($2, 1, 1)}}'".format(
        file)
    orig_orientation = subprocess.check_output(sct_command, shell=True).decode('utf-8')
    return orig_orientation
def tmp_create():
    """
    Create temporary folder and return its path
    """
    prefix = f"sciseg_prediction_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_"
    tmpdir = tempfile.mkdtemp(prefix=prefix)
    print(f"Creating temporary folder ({tmpdir})")
    return tmpdir
def splitext(fname):
    """
    Split a fname (folder/file + ext) into a folder/file and extension.
    Note: for .nii.gz the extension is understandably .nii.gz, not .gz
    (``os.path.splitext()`` would want to do the latter, hence the special case).
    Taken (shamelessly) from: https://github.com/spinalcordtoolbox/manual-correction/blob/main/utils.py
    """
    dirname, filename = os.path.split(fname)
    for special_ext in ['.nii.gz', '.tar.gz']:
        if filename.endswith(special_ext):
            stem, ext = filename[:-len(special_ext)], special_ext
            return os.path.join(dirname, stem), ext
    # If no special case, behaves like the regular splitext
    stem, ext = os.path.splitext(filename)
    return os.path.join(dirname, stem), ext
def add_suffix(fname, suffix):
    """
    Add suffix between end of file name and extension. Taken (shamelessly) from:
    https://github.com/spinalcordtoolbox/manual-correction/blob/main/utils.py
    :param fname: absolute or relative file name. Example: t2.nii.gz
    :param suffix: suffix. Example: _mean
    :return: file name with suffix. Example: t2_mean.nii
    Examples:
    - add_suffix(t2.nii, _mean) -> t2_mean.nii
    - add_suffix(t2.nii.gz, a) -> t2a.nii.gz
    """
    stem, ext = splitext(fname)
    return stem + suffix + ext
def main():
    parser = get_parser()
    args = parser.parse_args()

    fname_file = args.i
    fname_file_out = args.o
    print(f'\nFound {fname_file} file.')

    # Create a temporary directory to store the reoriented images
    tmpdir = tmp_create()
    # Copy the input file to the temporary directory
    fname_file_tmp = os.path.join(tmpdir, os.path.basename(fname_file))
    shutil.copyfile(fname_file, fname_file_tmp)
    print(f'Copied {fname_file} to {fname_file_tmp}')

    # Get the original orientation of the image, for example LPI
    orig_orientation = get_orientation(fname_file_tmp)

    # Reorient the image to LPI orientation if not already in LPI
    if orig_orientation != 'LPI':
        print(f'Original orientation: {orig_orientation}')
        print('Reorienting to LPI orientation...')
        # reorient the image to LPI using SCT
        os.system('sct_image -i {} -setorient LPI -o {}'.format(fname_file_tmp, fname_file_tmp))

    # NOTE: for individual images, the _0000 suffix is not needed.
    # BUT, the images should be in a list of lists
    fname_file_tmp_list = [[fname_file_tmp]]

    # Use fold_all (all train/val subjects were used for training) or specific fold(s)
    folds_avail = 'all' if args.fold == 'all' else [int(f) for f in args.fold.split(',')]
    print(f'Using fold(s): {folds_avail}')
    # Create directory for nnUNet prediction
    tmpdir_nnunet = os.path.join(tmpdir, 'nnUNet_prediction')
    fname_prediction = os.path.join(tmpdir_nnunet, os.path.basename(add_suffix(fname_file_tmp, '_pred')))
    os.mkdir(tmpdir_nnunet)

    # Run nnUNet prediction
    print('Starting inference...it may take a few minutes...\n')
    start = time.time()
    # directly call the predict function
    predictor = nnUNetPredictor(
        tile_step_size=args.tile_step_size,  # changing it from 0.5 to 0.9 makes inference faster
        use_gaussian=True,                   # apply Gaussian weighting when aggregating overlapping patches
        use_mirroring=False,                 # disable test-time augmentation by mirroring (faster inference)
        perform_everything_on_gpu=args.use_gpu,
        device=torch.device('cuda') if args.use_gpu else torch.device('cpu'),
        verbose_preprocessing=False,
        allow_tqdm=True
    )

    print('Running inference on device: {}'.format(predictor.device))
    # initializes the network architecture, loads the checkpoint
    predictor.initialize_from_trained_model_folder(
        join(args.path_model),
        use_folds=folds_avail,
        checkpoint_name='checkpoint_final.pth' if not args.use_best_checkpoint else 'checkpoint_best.pth',
    )
    print('Model loaded successfully. Fetching data...')

    # NOTE: for individual files, the image should be in a list of lists
    predictor.predict_from_files(
        list_of_lists_or_source_folder=fname_file_tmp_list,
        output_folder_or_list_of_truncated_output_files=tmpdir_nnunet,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=4,
        num_processes_segmentation_export=4,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0
    )
    end = time.time()

    print('Inference done.')
    total_time = end - start
    print('Total inference time: {} minute(s) {} seconds\n'.format(int(total_time // 60), int(round(total_time % 60))))

    # Copy .nii.gz file from tmpdir_nnunet to tmpdir
    pred_file = glob.glob(os.path.join(tmpdir_nnunet, '*.nii.gz'))[0]
    shutil.copyfile(pred_file, fname_prediction)
    print(f'Copied {pred_file} to {fname_prediction}')

    # Reorient the image back to original orientation
    # skip if already in LPI
    if orig_orientation != 'LPI':
        print(f'Reorienting to original orientation {orig_orientation}...')
        # reorient the image to the original orientation using SCT
        os.system('sct_image -i {} -setorient {} -o {}'.format(fname_prediction, orig_orientation, fname_prediction))

    # Copy level-specific (i.e., non-binary) segmentation
    shutil.copyfile(fname_prediction, fname_file_out)
    print(f'Copied {fname_prediction} to {fname_file_out}')

    print('Deleting the temporary folder...')
    # Delete the temporary folder
    shutil.rmtree(tmpdir)

    print('-' * 50)
    print(f"Input file: {fname_file}")
    print(f"Rootlet segmentation: {fname_file_out}")
    print('-' * 50)


if __name__ == '__main__':
    main()