diff --git a/mist/analyze_data/analyze.py b/mist/analyze_data/analyze.py index 1c96f1a..f3afaf3 100755 --- a/mist/analyze_data/analyze.py +++ b/mist/analyze_data/analyze.py @@ -112,11 +112,20 @@ def compute_class_weights(self): """Compute class weights on original data.""" # Either compute class weights or use user provided weights. + # Check that number of class weights matches the number of labels if + # provided. + n_labels = len(self.dataset_information["labels"]) + if ( + self.mist_arguments.class_weights and + len(self.mist_arguments.class_weights) != n_labels + ): + raise ValueError( + "Number of class weights must match number of labels." + ) + if self.mist_arguments.class_weights is None: # Initialize class weights if not provided. - class_weights = [ - 0. for i in range(len(self.dataset_information["labels"])) - ] + class_weights = [0. for i in range(n_labels)] progress = utils.get_progress_bar("Computing class weights") with progress as pb: @@ -176,13 +185,19 @@ def get_target_spacing(self): for i in pb.track(range(len(self.paths_dataframe))): patient = self.paths_dataframe.iloc[i].to_dict() - # Read mask image. This is faster to load. - spacing = ants.image_header_info(patient["mask"])["spacing"] + # Reorient masks to RAI to collect target spacing. We do this + # to make sure that all of the axes in the spacings match up. + # We load the masks because they are smaller and faster to load. + mask = ants.image_read(patient["mask"]) + mask = ants.reorient_image2(mask, "RAI") + mask.set_direction( + analyzer_constants.AnalyzeConstants.RAI_ANTS_DIRECTION + ) # Get voxel spacing. - original_spacings[i, :] = spacing + original_spacings[i, :] = mask.spacing - # Initialize target spacing + # Initialize target spacing. target_spacing = list(np.median(original_spacings, axis=0)) # If anisotropic, adjust the coarsest resolution to bring ratio down. diff --git a/mist/analyze_data/analyzer_constants.py b/mist/analyze_data/analyzer_constants.py index 993b936..c49be37 100644 --- a/mist/analyze_data/analyzer_constants.py +++ b/mist/analyze_data/analyzer_constants.py @@ -1,11 +1,15 @@ """Constants for the Analyzer class.""" import dataclasses +import numpy as np @dataclasses.dataclass(frozen=True) class AnalyzeConstants: """Dataclass for constants used in the analyze_data module.""" + # RAI orientation direction for ANTs. + RAI_ANTS_DIRECTION = np.eye(3) + # Maximum memory in bytes for each image and mask pair. MAX_MEMORY_PER_IMAGE_MASK_PAIR_BYTES = 2e9 diff --git a/mist/inference/main_inference.py b/mist/inference/main_inference.py index fba5e9b..408db09 100755 --- a/mist/inference/main_inference.py +++ b/mist/inference/main_inference.py @@ -87,10 +87,10 @@ def get_sw_prediction( def back_to_original_space( - pred: npt.NDArray[Any], - og_ants_img: ants.core.ants_image.ANTsImage, - config: Dict[str, Any], - fg_bbox: Optional[Dict[str, Any]], + prediction_npy: npt.NDArray[Any], + original_image_ants: ants.core.ants_image.ANTsImage, + mist_configuration: Dict[str, Any], + fg_bounding_box: Optional[Dict[str, Any]], ) -> ants.core.ants_image.ANTsImage: """Place prediction back into original image space. @@ -101,52 +101,60 @@ def back_to_original_space( header to the prediction's header. Args: - pred: The prediction to place back into the original image space. This - should be a numpy array. - og_ants_img: The original ANTs image. - config: The configuration dictionary. - fg_bbox: The foreground bounding box. + prediction_npy: The prediction to place back into the original image + space. This should be a numpy array. + original_image_ants: The original ANTs image. + mist_configuration: The configuration dictionary. + fg_bounding_box: The foreground bounding box. Returns: pred: The prediction in the original image space. This will be an ANTs image. """ # Convert prediction to ANTs image. - pred = ants.from_numpy(data=pred, spacing=config["target_spacing"]) + prediction_ants = ants.from_numpy( + data=prediction_npy, + spacing=mist_configuration["target_spacing"] + ) # Reorient prediction. - pred = ants.reorient_image2(pred, ants.get_orientation(og_ants_img)) - pred.set_direction(og_ants_img.direction) + prediction_ants = ants.reorient_image2( + prediction_ants, + ants.get_orientation(original_image_ants) + ) + prediction_ants.set_direction(original_image_ants.direction) # Enforce size for cropped images. - if fg_bbox is not None: + if fg_bounding_box is not None: # If we have a foreground bounding box, use that to determine the size. new_size = [ - fg_bbox["x_end"] - fg_bbox["x_start"] + 1, - fg_bbox["y_end"] - fg_bbox["y_start"] + 1, - fg_bbox["z_end"] - fg_bbox["z_start"] + 1, + fg_bounding_box["x_end"] - fg_bounding_box["x_start"] + 1, + fg_bounding_box["y_end"] - fg_bounding_box["y_start"] + 1, + fg_bounding_box["z_end"] - fg_bounding_box["z_start"] + 1, ] else: # Otherwise, use the original image size. - new_size = og_ants_img.shape + new_size = original_image_ants.shape # Resample prediction to original image space. - pred = preprocess.resample_mask( - pred, - labels=list(range(len(config["labels"]))), - target_spacing=og_ants_img.spacing, + prediction_ants = preprocess.resample_mask( + prediction_ants, + labels=list(range(len(mist_configuration["labels"]))), + target_spacing=original_image_ants.spacing, new_size=np.array(new_size, dtype="int").tolist(), ) # Appropriately pad back to original size if necessary. - if fg_bbox is not None: - pred = utils.decrop_from_fg(pred, fg_bbox) + if fg_bounding_box is not None: + prediction_ants = utils.decrop_from_fg(prediction_ants, fg_bounding_box) # Copy header from original image onto the prediction so they match. This # will take care of other details in the header like the origin and the # image bounding box. - pred = og_ants_img.new_image_like(pred.numpy()) - return pred + prediction_ants = original_image_ants.new_image_like( + prediction_ants.numpy() + ) + return prediction_ants def predict_single_example( @@ -341,11 +349,11 @@ def check_test_time_input( # Convert input to pandas dataframe if isinstance(patients, pd.DataFrame): return patients - if '.csv' in patients: + if '.csv' in patients and isinstance(patients, str): return pd.read_csv(patients) if isinstance(patients, dict): return utils.convert_dict_to_df(patients) - if '.json' in patients: + if '.json' in patients and isinstance(patients, str): patients = utils.read_json_file(patients) return utils.convert_dict_to_df(patients) raise ValueError(f"Received invalid input format: {type(patients)}") diff --git a/mist/preprocess_data/preprocess.py b/mist/preprocess_data/preprocess.py index 07ee196..6e058a1 100755 --- a/mist/preprocess_data/preprocess.py +++ b/mist/preprocess_data/preprocess.py @@ -1,7 +1,7 @@ """Preprocessing functions for medical images and masks.""" import os import argparse -from typing import Dict, List, Tuple, Any, Optional +from typing import Dict, List, Tuple, Any, Optional, Union import ants import numpy as np @@ -321,7 +321,7 @@ def preprocess_example( fg_bbox: Optional[Dict[str, int]]=None, use_dtm: bool=False, normalize_dtm: bool=False, -) -> Dict[str, npt.NDArray[Any]]: +) -> Dict[str, Union[npt.NDArray[Any], Dict[str, int], None]]: """Preprocessing function for a single example. Args: @@ -439,7 +439,7 @@ def preprocess_example( def convert_nifti_to_numpy( image_list: List[str], mask: Optional[str]=None, - ) -> Dict[str, Any]: + ) -> Dict[str, Union[npt.NDArray[Any], None]]: """Convert NIfTI images to numpy arrays. Args: