From 46c9607145670a71f7982b69e839c1a240f3346b Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Wed, 25 Sep 2024 22:47:46 +0200
Subject: [PATCH 1/5] Refactor inference to production-quality code.

---
 mist/inference/main_inference.py | 735 ++++++++++++++++++++-----------
 1 file changed, 472 insertions(+), 263 deletions(-)

diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py
index 4c4fe53..7ef886f 100755
--- a/mist/inference/main_inference.py
+++ b/mist/inference/main_inference.py
@@ -1,174 +1,278 @@
 """Inference functions for MIST."""
+import argparse
 import os
-import gc
-import json
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import ants
-import pandas as pd
+import monai
 import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import rich
+import torch
 
-# Rich progres bar
-from rich.console import Console
-from rich.text import Text
+# MIST imports.
+from mist.models import get_model
+from mist.data_loading import dali_loader
+from mist.preprocess_data import preprocess
+from mist.runtime import utils
+from mist.postprocess_preds import postprocess
+
+
+def get_sw_prediction(
+    image: torch.Tensor,
+    model: Callable,
+    patch_size: Tuple[int, int, int],
+    overlap: float,
+    blend_mode: str,
+    tta: bool
+) -> torch.Tensor:
+    """Get sliding window prediction for a single image.
+
+    This function is used to get the sliding window prediction for a single
+    image. You can vary the patch size, overlap, blend mode, and add test time
+    augmentation. The output of the function is the prediction for the image.
 
-from monai.inferers import sliding_window_inference
+    Note that MIST models do not have a softmax layer, so the output of this
+    function requires a softmax layer to be applied to it.
 
-import torch
-from torch.nn.functional import softmax
-
-from mist.models.get_model import load_model_from_config
-
-from mist.data_loading.dali_loader import (
-    get_test_dataset
-)
-
-from mist.runtime.utils import (
-    read_json_file,
-    convert_dict_to_df,
-    get_flip_axes,
-    create_empty_dir,
-    get_fg_mask_bbox,
-    decrop_from_fg,
-    get_progress_bar,
-    npy_fix_labels,
-)
-
-from mist.preprocess_data.preprocess import (
-    convert_nifti_to_numpy,
-    preprocess_example,
-    resample_mask
-)
-
-from mist.postprocess_preds.postprocess import apply_transform
-
-
-def get_sw_prediction(image, model, patch_size, overlap, blend_mode, tta):
-    # Get model prediction
-    # Predict on original image
-    prediction = sliding_window_inference(inputs=image,
-                                          roi_size=patch_size,
-                                          sw_batch_size=1,
-                                          predictor=model,
-                                          overlap=overlap,
-                                          mode=blend_mode,
-                                          device=torch.device("cuda"))
-    prediction = softmax(prediction, dim=1)
-
-    # Test time augmentation
+    Args:
+        image: The image to predict on.
+        model: The MIST model to use for prediction.
+        patch_size: The size of the patch to use for prediction.
+        overlap: The overlap between patches.
+        blend_mode: The blending mode to use.
+        tta: Whether to use test time augmentation.
+
+    Returns:
+        prediction: The prediction for the image.
+    """
+    # Predict on original image using sliding window inference from MONAI.
+    prediction = monai.inferers.sliding_window_inference(
+        inputs=image,
+        roi_size=patch_size,
+        sw_batch_size=1,
+        predictor=model,
+        overlap=overlap,
+        mode=blend_mode,
+        device=torch.device("cuda")
+    )
+    prediction = torch.nn.functional.softmax(prediction, dim=1)
+
+    # Test time augmentation.
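+    # With TTA enabled, the loop below effectively computes
+    #     prediction = (p(x) + sum_i unflip_i(p(flip_i(x)))) / (n_flips + 1),
+    # where p is the softmaxed sliding-window prediction and the flip axes
+    # come from utils.get_flip_axes().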
     if tta:
-        flip_axes = get_flip_axes()
-        for i in range(len(flip_axes)):
-            axes = flip_axes[i]
+        flip_axes = utils.get_flip_axes()
+        for axes in flip_axes:
+            # Flip image and predict on flipped image.
             flipped_img = torch.flip(image, dims=axes)
-            flipped_pred = sliding_window_inference(inputs=flipped_img,
-                                                    roi_size=patch_size,
-                                                    sw_batch_size=1,
-                                                    predictor=model,
-                                                    overlap=overlap,
-                                                    mode=blend_mode,
-                                                    device=torch.device("cuda"))
-            flipped_pred = softmax(flipped_pred, dim=1)
+            flipped_pred = monai.inferers.sliding_window_inference(
+                inputs=flipped_img,
+                roi_size=patch_size,
+                sw_batch_size=1,
+                predictor=model,
+                overlap=overlap,
+                mode=blend_mode,
+                device=torch.device("cuda")
+            )
+
+            # Flip prediction back and add to original prediction.
+            flipped_pred = torch.nn.functional.softmax(flipped_pred, dim=1)
             prediction += torch.flip(flipped_pred, dims=axes)
 
+    # Average predictions.
     prediction /= (len(flip_axes) + 1.)
-
     return prediction
 
 
-def back_to_original_space(pred, og_ants_img, config, fg_bbox):
+def back_to_original_space(
+    pred: npt.NDArray[Any],
+    og_ants_img: ants.core.ants_image.ANTsImage,
+    config: Dict[str, Any],
+    fg_bbox: Optional[Dict[str, Any]],
+):
+    """Place prediction back into original image space.
+
+    All predictions are natively in RAI orientation, possibly cropped to the
+    foreground, and in the target spacing. This function will place the
+    prediction back into the original image space by resampling, reorienting,
+    and possibly padding back to the original size.
+
+    Args:
+        pred: The prediction to place back into the original image space. This
+            should be a numpy array.
+        og_ants_img: The original ANTs image.
+        config: The configuration dictionary.
+        fg_bbox: The foreground bounding box, or None if the image was not
+            cropped to the foreground.
+
+    Returns:
+        pred: The prediction in the original image space. This will be an ANTs
+            image.
+    """
+    # Convert prediction to ANTs image.
     pred = ants.from_numpy(data=pred)
 
     # Set spacing to target spacing.
     pred.set_spacing(config["target_spacing"])
 
-    # Resample prediction
-    # Enforce size for cropped images
+    # Enforce size for cropped images.
     if fg_bbox is not None:
-        new_size = [fg_bbox["x_end"] - fg_bbox["x_start"] + 1,
-                    fg_bbox["y_end"] - fg_bbox["y_start"] + 1,
-                    fg_bbox["z_end"] - fg_bbox["z_start"] + 1]
+        # If we have a foreground bounding box, use that to determine the size.
+        new_size = [
+            fg_bbox["x_end"] - fg_bbox["x_start"] + 1,
+            fg_bbox["y_end"] - fg_bbox["y_start"] + 1,
+            fg_bbox["z_end"] - fg_bbox["z_start"] + 1,
+        ]
     else:
+        # Otherwise, use the original image size.
         new_size = og_ants_img.shape
 
-    # Bug fix for sitk resample
-    new_size = np.array(new_size, dtype='int').tolist()
-
-    pred = resample_mask(pred,
-                         labels=list(range(len(config["labels"]))),
-                         target_spacing=og_ants_img.spacing,
-                         new_size=new_size)
+    # Resample prediction to original image space.
+    pred = preprocess.resample_mask(
+        pred,
+        labels=list(range(len(config["labels"]))),
+        target_spacing=og_ants_img.spacing,
+        new_size=np.array(new_size, dtype="int").tolist(),
+    )
 
-    # Return prediction to original image space
+    # Reorient prediction to original orientation.
+    # Get the original image orientation.
     og_orientation = ants.get_orientation(og_ants_img)
+
+    # Need to update this to be more robust. We may need to create a lookup
+    # table that applies the correct series of flips or permutations to get an
+    # RAI image back to the original orientation.
     pred = ants.reorient_image2(pred, og_orientation)
     pred.set_direction(og_ants_img.direction)
+
+    # Set origin to original origin.
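+    # Together with set_spacing and set_direction above, this restores the
+    # full physical-space metadata of the original image.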
     pred.set_origin(og_ants_img.origin)
 
-    # Appropriately pad back to original size
+    # Appropriately pad back to original size if necessary.
     if fg_bbox is not None:
-        pred = decrop_from_fg(pred, fg_bbox)
+        pred = utils.decrop_from_fg(pred, fg_bbox)
 
-    # FIX: Copy header from original image onto the prediction so they match
+    # Copy header from original image onto the prediction so they match.
     pred = og_ants_img.new_image_like(pred.numpy())
-
     return pred
 
 
-def predict_single_example(torch_img,
-                           og_ants_img,
-                           config,
-                           models,
-                           overlap,
-                           blend_mode,
-                           tta,
-                           output_std,
-                           fg_bbox):
+def predict_single_example(
+    torch_img: torch.Tensor,
+    og_ants_img: ants.core.ants_image.ANTsImage,
+    config: Dict[str, Any],
+    models_list: List[Callable],
+    fg_bbox: Optional[Dict[str, Any]]=None,
+    overlap: float=0.5,
+    blend_mode: str="gaussian",
+    tta: bool=False,
+    output_std: bool=False,
+) -> Tuple[
+    ants.core.ants_image.ANTsImage,
+    List[ants.core.ants_image.ANTsImage]
+    ]:
+    """Predict on a single example.
+
+    This function will predict on a single example using a list of models. The
+    predictions will be averaged together and placed back into the original
+    image space. If output_std is True, the standard deviation of the
+    predictions will also be computed and returned. This function uses sliding
+    window inference to predict on the image. The patch size is saved in the
+    config input. Other options include the amount of overlap between patches,
+    how predictions from different patches are blended together, and whether
+    test time augmentation is used. Test time augmentation is done by flipping
+    the image along different axes and averaging the predictions.
+
+    Args:
+        torch_img: The image to predict on. This should be a torch tensor.
+        og_ants_img: The original ANTs image.
+        config: The configuration dictionary.
+        models_list: The list of models to use for prediction.
+        fg_bbox: The foreground bounding box.
+        overlap: The overlap between patches.
+        blend_mode: The blending mode to use.
+        tta: Whether to use test time augmentation.
+        output_std: Whether to output the standard deviation of the predictions.
+
+    Returns:
+        pred: The prediction in the original image space. This will be an ANTs
+            image.
+        std_images: The standard deviation images for each class. This will be
+            a list of ANTs images.
+    """
+    # Get the number of classes.
     n_classes = len(config['labels'])
-    pred = torch.zeros(1,
-                       n_classes,
-                       torch_img.shape[2],
-                       torch_img.shape[3],
-                       torch_img.shape[4]).to("cuda")
-    std_images = list()
-
-    for model in models:
-        sw_prediction = get_sw_prediction(torch_img,
-                                          model,
-                                          config['patch_size'],
-                                          overlap,
-                                          blend_mode,
-                                          tta)
+
+    # Initialize prediction and standard deviation images. The prediction will
+    # be of shape (1, n_classes, H, W, D). The final output will be of shape
+    # (H, W, D).
+    pred = torch.zeros(
+        1,
+        n_classes,
+        torch_img.shape[2],
+        torch_img.shape[3],
+        torch_img.shape[4],
+    ).to("cuda")
+    std_images = []
+
+    # Get predictions from each model.
+    for model in models_list:
+        sw_prediction = get_sw_prediction(
+            torch_img,
+            model,
+            config['patch_size'],
+            overlap,
+            blend_mode,
+            tta
+        )
        pred += sw_prediction
-        if output_std and len(models) > 1:
+
+        # Save standard deviation images if necessary. Only save the standard
+        # deviation if the number of models is greater than 1.
+        if output_std and len(models_list) > 1:
             std_images.append(sw_prediction)
 
-    pred /= len(models)
+    # Average predictions.
+    pred /= len(models_list)
+
+    # Move prediction to CPU.
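+    # The remaining steps (argmax, squeeze, and numpy conversion) all run on
+    # the CPU, which keeps GPU memory free for the next sliding-window pass.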
pred = pred.to("cpu") + + # Get the class with the highest probability. pred = torch.argmax(pred, dim=1) + + # Remove the batch dimension and convert to float32. pred = torch.squeeze(pred, dim=0) pred = pred.to(torch.float32) + + # Convert prediction to numpy array. pred = pred.numpy() - # Get foreground mask if necessary + # Get foreground mask if necessary. if config["crop_to_fg"] and fg_bbox is None: - fg_bbox = get_fg_mask_bbox(og_ants_img) - - # Place prediction back into original image space - pred = back_to_original_space(pred, - og_ants_img, - config, - fg_bbox) + fg_bbox = utils.get_fg_mask_bbox(og_ants_img) + + # Place prediction back into original image space. + pred = back_to_original_space( + pred, + og_ants_img, + config, + fg_bbox + ) - # Fix labels if necessary + # Fix labels if necessary. In some cases the labels used for training + # may not be the same as the original labels. For example, if we have + # labels [0, 1, 2, 4] in our dataset, we will train using labels + # [0, 1, 2, 3]. In this case, we need to fix the labels in the prediction + # to match the original labels. if list(range(n_classes)) != config["labels"]: pred = pred.numpy() - pred = npy_fix_labels(pred, config["labels"]) + pred = utils.npy_fix_labels(pred, config["labels"]) pred = og_ants_img.new_image_like(data=pred) - # Cast prediction of uint8 format to reduce storage + # Cast prediction of uint8 format to reduce storage. pred = pred.astype("uint8") - # Creates standard deviation images for UQ if called for + # Creates standard deviation images for UQ if called for. if output_std: std_images = torch.stack(std_images, dim=0) std_images = torch.std(std_images, dim=0) @@ -176,96 +280,143 @@ def predict_single_example(torch_img, std_images = torch.squeeze(std_images, dim=0) std_images = std_images.to(torch.float32) std_images = std_images.numpy() - std_images = [back_to_original_space(std_image, - og_ants_img, - config, - fg_bbox) for std_image in std_images] - + std_images = [ + back_to_original_space( + std_image, + og_ants_img, + config, + fg_bbox + ) for std_image in std_images + ] return pred, std_images -def load_test_time_models(models_dir, fast): - n_files = len(os.listdir(models_dir)) - 1 - model_list = [os.path.join(models_dir, "fold_{}.pt".format(i)) for i in range(n_files)] +def load_test_time_models( + models_dir: str, + fast: bool, +) -> List[Callable]: + """Load models for test time inference. + + This function will load the models for test time inference. The models are + loaded from the models directory. The model configuration file is also + loaded. If fast is True, only the first model is loaded. + + Args: + models_dir: The directory where the models are stored. + fast: Whether to only load the first model. + + Returns: + final_model_list: The list of models for test time inference. 
+    """
+    n_files = len(utils.listdir_with_no_hidden_files(models_dir)) - 1
+    model_paths_list = [
+        os.path.join(models_dir, f"fold_{i}.pt") for i in range(n_files)
+    ]
     model_config = os.path.join(models_dir, "model_config.json")
 
     if fast:
-        model_list = [model_list[0]]
+        model_paths_list = [model_paths_list[0]]
+
+    final_model_list = [
+        get_model.load_model_from_config(
+            model_path, model_config
+        ) for model_path in model_paths_list
+    ]
+    return final_model_list
+
 
-    models = [load_model_from_config(model, model_config) for model in model_list]
-    return models
+def check_test_time_input(
+    patients: Union[str, pd.DataFrame, Dict[str, Any]],
+) -> pd.DataFrame:
+    """Check the input for test time inference and convert to pandas dataframe.
 
+    This function will check the input for test time inference and convert it
+    to a pandas dataframe. The input can be a pandas dataframe, a csv file, a
+    json file, or a dictionary. If the input is a dictionary or a json file, it
+    will be converted to a pandas dataframe.
 
-def check_test_time_input(patients):
+    Args:
+        patients: The input for test time inference. This can be a pandas
+            dataframe, a csv file, a json file, or a dictionary.
+
+    Returns:
+        patients: The input for test time inference as a pandas dataframe.
+
+    Raises:
+        ValueError: If the input format is invalid.
+    """
     # Convert input to pandas dataframe
     if isinstance(patients, pd.DataFrame):
         return patients
-    elif '.csv' in patients:
+    if '.csv' in patients:
         return pd.read_csv(patients)
-    elif type(patients) is dict:
-        return convert_dict_to_df(patients)
-    elif '.json' in patients:
-        with open(patients, 'r') as file:
-            patients = json.load(file)
-        return convert_dict_to_df(patients)
-    else:
-        raise ValueError("Invalid input format for test time")
-
+    if isinstance(patients, dict):
+        return utils.convert_dict_to_df(patients)
+    if '.json' in patients:
+        patients = utils.read_json_file(patients)
+        return utils.convert_dict_to_df(patients)
+    raise ValueError(f"Received invalid input format: {type(patients)}")
+
 
 def test_on_fold(
-    args,
-    fold_number
-):
+    args: argparse.Namespace,
+    fold_number: int,
+) -> None:
     """Run inference on the test set for a fold.
+
+    This function will run inference on the test set for a fold. The
+    predictions are saved as NIfTI files in the results directory.
 
     Args:
         args: Arguments from MIST arguments.
         fold_number: The fold number to run inference on.
+
     Returns:
-        Saves predictions to ./results/predictions/train/raw/ directory.
+        None. Saves predictions to ./results/predictions/train/raw/ directory.
+
+    Raises:
+        FileNotFoundError: If the original image is not found.
+        RuntimeError or ValueError: If the prediction fails.
     """
-    # Read config file
-    config = read_json_file(
-        os.path.join(args.results, 'config.json')
-    )
-
-    # Get dataframe with paths for test images
-    train_paths_df = pd.read_csv(
-        os.path.join(args.results, 'train_paths.csv')
-    )
+    # Read config file.
+    config = utils.read_json_file(os.path.join(args.results, 'config.json'))
+
+    # Get dataframe with paths for test images.
+    train_paths_df = pd.read_csv(os.path.join(args.results, 'train_paths.csv'))
    testing_paths_df = train_paths_df.loc[train_paths_df["fold"] == fold_number]
-
-    # Get list of numpy files of preprocessed test images
+
+    # Get list of numpy files of preprocessed test images.
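+    # These are the arrays written during preprocessing, one
+    # {args.numpy}/images/{patient_id}.npy file per test ID below.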
test_ids = list(testing_paths_df["id"]) test_images = [ - os.path.join(args.numpy, 'images', f'{patient_id}.npy') for patient_id in test_ids + os.path.join( + args.numpy, 'images', f'{patient_id}.npy' + ) for patient_id in test_ids ] - - # Get bounding box data - fg_bboxes = pd.read_csv( - os.path.join(args.results, 'fg_bboxes.csv') - ) - - # Get DALI loader for streaming preprocessed numpy files - test_dali_loader = get_test_dataset( - test_images, + + # Get bounding box data. + fg_bboxes = pd.read_csv(os.path.join(args.results, 'fg_bboxes.csv')) + + # Get DALI loader for streaming preprocessed numpy files. + test_dali_loader = dali_loader.get_test_dataset( + imgs=test_images, seed=args.seed_val, num_workers=args.num_workers, rank=0, world_size=1 ) - - # Load model - model = load_model_from_config( + + # Load model. + model = get_model.load_model_from_config( os.path.join(args.results, 'models', f'fold_{fold_number}.pt'), os.path.join(args.results, 'models', 'model_config.json') ) model.eval() model.to("cuda") - # Progress bar and error messages - progress_bar = get_progress_bar(f'Testing on fold {fold_number}') - console = Console() + # Progress bar and error messages. + progress_bar = utils.get_progress_bar(f'Testing on fold {fold_number}') + console = rich.console.Console() error_messages = '' # Define output directory @@ -276,163 +427,221 @@ def test_on_fold( 'raw' ) - # Run prediction on all test images + # Run prediction on all test images and save predictions to disk. with torch.no_grad(), progress_bar as pb: for image_index in pb.track(range(len(testing_paths_df))): patient = testing_paths_df.iloc[image_index].to_dict() try: - # Get original patient data + # Get original patient data. image_list = list(patient.values())[3:len(patient)] original_ants_image = ants.image_read(image_list[0]) - # Get preprocessed image from DALI loader + # Get preprocessed image from DALI loader. data = test_dali_loader.next()[0] preprocessed_numpy_image = data['image'] - # Get foreground mask if necessary + # Get foreground mask if necessary. if config["crop_to_fg"]: - fg_bbox = fg_bboxes.loc[fg_bboxes['id'] == patient['id']].iloc[0].to_dict() + fg_bbox = fg_bboxes.loc[ + fg_bboxes['id'] == patient['id'] + ].iloc[0].to_dict() else: fg_bbox = None # Predict with model and put back into original image space prediction, _ = predict_single_example( - preprocessed_numpy_image, - original_ants_image, - config, - [model], - args.sw_overlap, - args.blend_mode, - args.tta, - output_std=False, - fg_bbox=fg_bbox + torch_img=preprocessed_numpy_image, + og_ants_img=original_ants_image, + config=config, + models_list=[model], + fg_bbox=fg_bbox, + overlap=args.sw_overlap, + blend_mode=args.blend_mode, + tta=args.tta, + ) + except (FileNotFoundError, RuntimeError, ValueError) as e: + error_messages += ( + f"[Error] {str(e)}: Prediction failed for {patient['id']}\n" ) - except: - error_messages += f"[Inference Error] Prediction failed for {patient['id']}\n" else: - # Write prediction as .nii.gz file - ants.image_write( - prediction, - os.path.join( - output_directory, - f"{patient['id']}.nii.gz" - ) + # Write prediction as .nii.gz file. 
+                filename = os.path.join(
+                    output_directory, f"{patient['id']}.nii.gz"
                 )
-
+                ants.image_write(prediction, filename)
+
     if len(error_messages) > 0:
-        text = Text(error_messages)
+        text = rich.text.Text(error_messages)
         console.print(text)
 
-    # Clean up
-    gc.collect()
+
+def test_time_inference(
+    df: pd.DataFrame,
+    dest: str,
+    config_file: str,
+    models: List[Callable],
+    overlap: float,
+    blend_mode: str,
+    tta: bool,
+    no_preprocess: bool=False,
+    output_std: bool=False
+) -> None:
+    """Run test time inference on a dataframe of images.
+
+    This function will run test time inference on a dataframe of images. The
+    predictions will be saved to the destination directory. The configuration
+    file is used to preprocess the images before prediction. The models are
+    used to predict on the images. The remaining parameters control how the
+    predictions are made: the overlap between patches, how predictions from
+    different patches are blended, whether test time augmentation is used,
+    and whether the standard deviation of the predictions is output.
+
+    Args:
+        df: The dataframe of images to predict on.
+        dest: The destination directory to save the predictions.
+        config_file: The configuration file to use for preprocessing.
+        models: The list of models to use for prediction.
+        overlap: The overlap between patches.
+        blend_mode: The blending mode to use.
+        tta: Whether to use test time augmentation.
+        no_preprocess: Whether to skip preprocessing.
+        output_std: Whether to output the standard deviation of the predictions.
+
+    Returns:
+        None. Saves predictions to the destination directory.
+
+    Raises:
+        FileNotFoundError: If an inference image cannot be found.
+        RuntimeError or ValueError: If the prediction fails.
+    """
+    # Read configuration file.
+    config = utils.read_json_file(config_file)
 
-def test_time_inference(df,
-                        dest,
-                        config_file,
-                        models,
-                        overlap,
-                        blend_mode,
-                        tta,
-                        no_preprocess=False,
-                        output_std=False):
-    config = read_json_file(config_file)
+    # Create destination directory if it does not exist.
+    utils.create_empty_dir(dest)
 
-    create_empty_dir(dest)
+    # Set up rich progress bar.
+    testing_progress = utils.get_progress_bar("Running inference")
 
-    # Set up rich progress bar
-    testing_progress = get_progress_bar("Testing")
-    console = Console()
+    # Set up rich console for error messages.
+    console = rich.console.Console()
     error_messages = ""
 
+    # Get start column index for image paths depending on whether the dataframe
+    # has certain columns.
+    if "mask" in df.columns and "fold" in df.columns:
+        start_column_index = 3
+    elif "mask" in df.columns or "fold" in df.columns:
+        start_column_index = 2
+    else:
+        start_column_index = 1
+
     # Run prediction on all samples and compute metrics
     with testing_progress as pb:
         for ii in pb.track(range(len(df))):
             patient = df.iloc[ii].to_dict()
             try:
-                # Create individual folders for each prediction if output_std is enabled
+                # Create individual folders for each prediction if output_std is
+                # enabled.
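+                # With output_std, a hypothetical ID "case_01" is written
+                # under {dest}/case_01/ so its std images can sit beside the
+                # mask; otherwise everything goes directly into {dest}.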
                 if output_std:
-                    output_std_dest = os.path.join(dest, str(patient['id']))
-                    create_empty_dir(output_std_dest)
+                    output_std_dest = os.path.join(dest, str(patient["id"]))
+                    utils.create_empty_dir(output_std_dest)
                 else:
                     output_std_dest = dest
 
-                if "mask" in df.columns and "fold" in df.columns:
-                    image_list = list(patient.values())[3:]
-                elif "mask" in df.columns or "fold" in df.columns:
-                    image_list = list(patient.values())[2:]
-                else:
-                    image_list = list(patient.values())[1:]
-
+                # Get image list from patient dictionary.
+                image_list = list(patient.values())[start_column_index:]
                 og_ants_img = ants.image_read(image_list[0])
 
                 if no_preprocess:
-                    preprocessed_example = convert_nifti_to_numpy(image_list)
+                    preprocessed_example = preprocess.convert_nifti_to_numpy(
+                        image_list
+                    )
                 else:
-                    preprocessed_example = preprocess_example(
+                    preprocessed_example = preprocess.preprocess_example(
                         config,
                         image_list,
                     )
 
-                # Make image channels first and add batch dimension
+                # Make image channels first and add batch dimension.
                 torch_img = preprocessed_example["image"]
                 torch_img = np.transpose(torch_img, axes=(3, 0, 1, 2))
                 torch_img = np.expand_dims(torch_img, axis=0)
 
+                # Convert to torch tensor and move to GPU.
                 torch_img = torch.Tensor(torch_img.copy()).to(torch.float32)
                 torch_img = torch_img.to("cuda")
 
+                # Run prediction on single example.
                 prediction, std_images = predict_single_example(
-                    torch_img,
-                    og_ants_img,
-                    config,
-                    models,
-                    overlap,
-                    blend_mode,
-                    tta,
-                    output_std,
-                    preprocessed_example["fg_bbox"]
+                    torch_img=torch_img,
+                    og_ants_img=og_ants_img,
+                    config=config,
+                    models_list=models,
+                    fg_bbox=preprocessed_example["fg_bbox"],
+                    overlap=overlap,
+                    blend_mode=blend_mode,
+                    tta=tta,
+                    output_std=output_std,
                 )
 
-                # Apply postprocessing if required
+                # Apply postprocessing if necessary.
                 transforms = ["remove_small_objects", "top_k_cc", "fill_holes"]
                 for transform in transforms:
                     if len(config[transform]) > 0:
                         for i in range(len(config[transform])):
                             if transform == "remove_small_objects":
-                                transform_kwargs = {"small_object_threshold": config[transform][i][1]}
+                                transform_kwargs = {
+                                    "small_object_threshold": (
+                                        config[transform][i][1]
+                                    )
+                                }
                             if transform == "top_k_cc":
-                                transform_kwargs = {"morph_cleanup": config[transform][i][1],
-                                                    "morph_cleanup_iterations": config[transform][i][2],
-                                                    "top_k": config[transform][i][3]}
+                                transform_kwargs = {
+                                    "morph_cleanup": config[transform][i][1],
+                                    "morph_cleanup_iterations": (
                                        config[transform][i][2]
+                                    ),
+                                    "top_k": config[transform][i][3]
+                                }
                             if transform == "fill_holes":
-                                transform_kwargs = {"fill_label": config[transform][i][1]}
-
-                            prediction = apply_transform(prediction,
-                                                         transform,
-                                                         config["labels"],
-                                                         config[transform][i][0],
-                                                         transform_kwargs)
-
-                # Write prediction mask to nifti file and save to disk
-                prediction_filename = '{}.nii.gz'.format(str(patient['id']))
-                output = os.path.join(output_std_dest, prediction_filename)
-                ants.image_write(prediction, output)
+                                transform_kwargs = {
+                                    "fill_label": config[transform][i][1]
+                                }
+
+                            prediction = postprocess.apply_transform(
+                                prediction,
+                                transform,
+                                config["labels"],
+                                config[transform][i][0],
+                                transform_kwargs
+                            )
+            except (FileNotFoundError, RuntimeError, ValueError) as e:
+                error_messages += (
+                    f"[Error] {str(e)}: Prediction failed for {patient['id']}\n"
+                )
+            else:
+                # Write prediction mask to nifti file and save to disk.
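+                # The mask is saved as {id}.nii.gz; the std images written
+                # below follow the {id}_std_{label}.nii.gz naming pattern.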
+                ants.image_write(
+                    prediction,
+                    os.path.join(output_std_dest, f"{patient['id']}.nii.gz")
+                )
 
-                # Write standard deviation image(s) to nifti file and save to disk (only for foreground labels)
+                # Write standard deviation image(s) to nifti file and save to
+                # disk (only for foreground labels).
                 if output_std:
-                    for i in range(len(std_images)):
+                    for i, std_image in enumerate(std_images):
                         if config["labels"][i] > 0:
-                            std_image_filename = '{}_std_{}.nii.gz'.format(patient['id'], config['labels'][i])
+                            std_image_filename = (
+                                f"{patient['id']}_std_{config['labels'][i]}"
+                                ".nii.gz"
+                            )
+                            output = os.path.join(
+                                output_std_dest, std_image_filename
+                            )
+                            ants.image_write(std_image, output)
-                            output = os.path.join(output_std_dest, std_image_filename)
-                            ants.image_write(std_images[i], output)
-
-            except:
-                error_messages += f"[Inference Error] Prediction failed for {patient['id']}\n"
+    if len(error_messages) > 0:
+        text = rich.text.Text(error_messages)
+        console.print(text)
 
-    if len(error_messages) > 0:
-        text = Text(error_messages)
-        console.print(text)
-
-    # Clean up
-    gc.collect()

From d787d2d5b339df09a099cb117f96628e265becd2 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Thu, 26 Sep 2024 14:49:34 +0200
Subject: [PATCH 2/5] Readability updates for main_inference.py.

---
 mist/inference/main_inference.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py
index 7ef886f..c30512e 100755
--- a/mist/inference/main_inference.py
+++ b/mist/inference/main_inference.py
@@ -21,7 +21,7 @@
 def get_sw_prediction(
     image: torch.Tensor,
-    model: Callable,
+    model: Callable[[torch.Tensor], torch.Tensor],
     patch_size: Tuple[int, int, int],
     overlap: float,
     blend_mode: str
@@ -89,7 +89,7 @@ def back_to_original_space(
     og_ants_img: ants.core.ants_image.ANTsImage,
     config: Dict[str, Any],
     fg_bbox: Optional[Dict[str, Any]],
-):
+) -> ants.core.ants_image.ANTsImage:
     """Place prediction back into original image space.
 
     All predictions are natively in RAI orientation, possibly cropped to the
@@ -160,7 +160,7 @@ def predict_single_example(
     torch_img: torch.Tensor,
     og_ants_img: ants.core.ants_image.ANTsImage,
     config: Dict[str, Any],
-    models_list: List[Callable],
+    models_list: List[Callable[[torch.Tensor], torch.Tensor]],
     fg_bbox: Optional[Dict[str, Any]]=None,
     overlap: float=0.5,
     blend_mode: str="gaussian",
@@ -294,7 +294,7 @@ def predict_single_example(
 def load_test_time_models(
     models_dir: str,
     fast: bool,
-) -> List[Callable]:
+) -> List[Callable[[torch.Tensor], torch.Tensor]]:
     """Load models for test time inference.
 
     This function loads the models for test time inference from the models
@@ -479,7 +479,7 @@ def test_time_inference(
     df: pd.DataFrame,
     dest: str,
     config_file: str,
-    models: List[Callable],
+    models: List[Callable[[torch.Tensor], torch.Tensor]],
     overlap: float,
     blend_mode: str,
     tta: bool,

From 08caa28f817267197116c42b4c3e1644a9fea87c Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Thu, 26 Sep 2024 14:53:59 +0200
Subject: [PATCH 3/5] Update docstring in get_sw_prediction in
 main_inference.py.

---
 mist/inference/main_inference.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py
index c30512e..8dd7eb1 100755
--- a/mist/inference/main_inference.py
+++ b/mist/inference/main_inference.py
@@ -33,8 +33,8 @@ def get_sw_prediction(
     image. You can vary the patch size, overlap, blend mode, and add test time
     augmentation. The output of the function is the prediction for the image.
 
-    Note that MIST models do not have a softmax layer, so the output of this
-    function requires a softmax layer to be applied to it.
+    Note that MIST models do not have a softmax layer, so we apply softmax to
+    the output of the model in this function.
 
     Args:
         image: The image to predict on.
@@ -57,6 +57,8 @@ def get_sw_prediction(
         mode=blend_mode,
         device=torch.device("cuda")
     )
+
+    # Apply softmax to prediction.
     prediction = torch.nn.functional.softmax(prediction, dim=1)
 
     # Test time augmentation.

From 13f223b3cbfaa87d7cb3ac81f9126c1177e893b7 Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Thu, 26 Sep 2024 15:29:42 +0200
Subject: [PATCH 4/5] Clear extra white space in main_inference.py.

---
 mist/inference/main_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py
index 8dd7eb1..9888f35 100755
--- a/mist/inference/main_inference.py
+++ b/mist/inference/main_inference.py
@@ -415,7 +415,7 @@ def test_on_fold(
     )
     model.eval()
     model.to("cuda")
-    
+
     # Progress bar and error messages.
     progress_bar = utils.get_progress_bar(f'Testing on fold {fold_number}')
     console = rich.console.Console()

From 830af4ef9fed361b7a314202c42af518ca00889d Mon Sep 17 00:00:00 2001
From: aecelaya
Date: Thu, 26 Sep 2024 15:32:35 +0200
Subject: [PATCH 5/5] Readability updates for preprocessing.

---
 mist/preprocess_data/preprocess.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mist/preprocess_data/preprocess.py b/mist/preprocess_data/preprocess.py
index 021bd28..07ee196 100755
--- a/mist/preprocess_data/preprocess.py
+++ b/mist/preprocess_data/preprocess.py
@@ -394,7 +394,9 @@ def preprocess_example(
 
     # Put mask into standard space
     mask = ants.reorient_image2(mask, "RAI")
-    mask.set_direction(np.eye(3))
+    mask.set_direction(
+        preprocessing_constants.PreprocessingConstants.RAI_ANTS_DIRECTION
+    )
     mask = resample_mask(
         mask,
         labels=config["labels"],