Skip to content

Commit

Permalink
Merge pull request #45 from aecelaya/main
Browse files Browse the repository at this point in the history
Improved target spacing computation and error handling.
  • Loading branch information
aecelaya authored Oct 8, 2024
2 parents d11cfe9 + 9771ff5 commit 7b47e03
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 37 deletions.
29 changes: 22 additions & 7 deletions mist/analyze_data/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,20 @@ def compute_class_weights(self):
"""Compute class weights on original data."""

# Either compute class weights or use user provided weights.
# Check that number of class weights matches the number of labels if
# provided.
n_labels = len(self.dataset_information["labels"])
if (
self.mist_arguments.class_weights and
len(self.mist_arguments.class_weights) != n_labels
):
raise ValueError(
"Number of class weights must match number of labels."
)

if self.mist_arguments.class_weights is None:
# Initialize class weights if not provided.
class_weights = [
0. for i in range(len(self.dataset_information["labels"]))
]
class_weights = [0. for i in range(n_labels)]
progress = utils.get_progress_bar("Computing class weights")

with progress as pb:
Expand Down Expand Up @@ -176,13 +185,19 @@ def get_target_spacing(self):
for i in pb.track(range(len(self.paths_dataframe))):
patient = self.paths_dataframe.iloc[i].to_dict()

# Read mask image. This is faster to load.
spacing = ants.image_header_info(patient["mask"])["spacing"]
# Reorient masks to RAI to collect target spacing. We do this
# to make sure that all of the axes in the spacings match up.
# We load the masks because they are smaller and faster to load.
mask = ants.image_read(patient["mask"])
mask = ants.reorient_image2(mask, "RAI")
mask.set_direction(
analyzer_constants.AnalyzeConstants.RAI_ANTS_DIRECTION
)

# Get voxel spacing.
original_spacings[i, :] = spacing
original_spacings[i, :] = mask.spacing

# Initialize target spacing
# Initialize target spacing.
target_spacing = list(np.median(original_spacings, axis=0))

# If anisotropic, adjust the coarsest resolution to bring ratio down.
Expand Down
4 changes: 4 additions & 0 deletions mist/analyze_data/analyzer_constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Constants for the Analyzer class."""

import dataclasses
import numpy as np

@dataclasses.dataclass(frozen=True)
class AnalyzeConstants:
"""Dataclass for constants used in the analyze_data module."""

# RAI orientation direction for ANTs.
RAI_ANTS_DIRECTION = np.eye(3)

# Maximum memory in bytes for each image and mask pair.
MAX_MEMORY_PER_IMAGE_MASK_PAIR_BYTES = 2e9

Expand Down
62 changes: 35 additions & 27 deletions mist/inference/main_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def get_sw_prediction(


def back_to_original_space(
pred: npt.NDArray[Any],
og_ants_img: ants.core.ants_image.ANTsImage,
config: Dict[str, Any],
fg_bbox: Optional[Dict[str, Any]],
prediction_npy: npt.NDArray[Any],
original_image_ants: ants.core.ants_image.ANTsImage,
mist_configuration: Dict[str, Any],
fg_bounding_box: Optional[Dict[str, Any]],
) -> ants.core.ants_image.ANTsImage:
"""Place prediction back into original image space.
Expand All @@ -101,52 +101,60 @@ def back_to_original_space(
header to the prediction's header.
Args:
pred: The prediction to place back into the original image space. This
should be a numpy array.
og_ants_img: The original ANTs image.
config: The configuration dictionary.
fg_bbox: The foreground bounding box.
prediction_npy: The prediction to place back into the original image
space. This should be a numpy array.
original_image_ants: The original ANTs image.
mist_configuration: The configuration dictionary.
fg_bounding_box: The foreground bounding box.
Returns:
pred: The prediction in the original image space. This will be an ANTs
image.
"""
# Convert prediction to ANTs image.
pred = ants.from_numpy(data=pred, spacing=config["target_spacing"])
prediction_ants = ants.from_numpy(
data=prediction_npy,
spacing=mist_configuration["target_spacing"]
)

# Reorient prediction.
pred = ants.reorient_image2(pred, ants.get_orientation(og_ants_img))
pred.set_direction(og_ants_img.direction)
prediction_ants = ants.reorient_image2(
prediction_ants,
ants.get_orientation(original_image_ants)
)
prediction_ants.set_direction(original_image_ants.direction)

# Enforce size for cropped images.
if fg_bbox is not None:
if fg_bounding_box is not None:
# If we have a foreground bounding box, use that to determine the size.
new_size = [
fg_bbox["x_end"] - fg_bbox["x_start"] + 1,
fg_bbox["y_end"] - fg_bbox["y_start"] + 1,
fg_bbox["z_end"] - fg_bbox["z_start"] + 1,
fg_bounding_box["x_end"] - fg_bounding_box["x_start"] + 1,
fg_bounding_box["y_end"] - fg_bounding_box["y_start"] + 1,
fg_bounding_box["z_end"] - fg_bounding_box["z_start"] + 1,
]
else:
# Otherwise, use the original image size.
new_size = og_ants_img.shape
new_size = original_image_ants.shape

# Resample prediction to original image space.
pred = preprocess.resample_mask(
pred,
labels=list(range(len(config["labels"]))),
target_spacing=og_ants_img.spacing,
prediction_ants = preprocess.resample_mask(
prediction_ants,
labels=list(range(len(mist_configuration["labels"]))),
target_spacing=original_image_ants.spacing,
new_size=np.array(new_size, dtype="int").tolist(),
)

# Appropriately pad back to original size if necessary.
if fg_bbox is not None:
pred = utils.decrop_from_fg(pred, fg_bbox)
if fg_bounding_box is not None:
prediction_ants = utils.decrop_from_fg(prediction_ants, fg_bounding_box)

# Copy header from original image onto the prediction so they match. This
# will take care of other details in the header like the origin and the
# image bounding box.
pred = og_ants_img.new_image_like(pred.numpy())
return pred
prediction_ants = original_image_ants.new_image_like(
prediction_ants.numpy()
)
return prediction_ants


def predict_single_example(
Expand Down Expand Up @@ -341,11 +349,11 @@ def check_test_time_input(
# Convert input to pandas dataframe
if isinstance(patients, pd.DataFrame):
return patients
if '.csv' in patients:
if '.csv' in patients and isinstance(patients, str):
return pd.read_csv(patients)
if isinstance(patients, dict):
return utils.convert_dict_to_df(patients)
if '.json' in patients:
if '.json' in patients and isinstance(patients, str):
patients = utils.read_json_file(patients)
return utils.convert_dict_to_df(patients)
raise ValueError(f"Received invalid input format: {type(patients)}")
Expand Down
6 changes: 3 additions & 3 deletions mist/preprocess_data/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Preprocessing functions for medical images and masks."""
import os
import argparse
from typing import Dict, List, Tuple, Any, Optional
from typing import Dict, List, Tuple, Any, Optional, Union

import ants
import numpy as np
Expand Down Expand Up @@ -321,7 +321,7 @@ def preprocess_example(
fg_bbox: Optional[Dict[str, int]]=None,
use_dtm: bool=False,
normalize_dtm: bool=False,
) -> Dict[str, npt.NDArray[Any]]:
) -> Dict[str, Union[npt.NDArray[Any], Dict[str, int], None]]:
"""Preprocessing function for a single example.
Args:
Expand Down Expand Up @@ -439,7 +439,7 @@ def preprocess_example(
def convert_nifti_to_numpy(
image_list: List[str],
mask: Optional[str]=None,
) -> Dict[str, Any]:
) -> Dict[str, Union[npt.NDArray[Any], None]]:
"""Convert NIfTI images to numpy arrays.
Args:
Expand Down

0 comments on commit 7b47e03

Please sign in to comment.