Skip to content

Commit

Permalink
Merge pull request #63 from aecelaya/main
Browse files Browse the repository at this point in the history
Add MedNeXt implementation to MIST. Minor improvements to data loading and analyzer tools.
  • Loading branch information
aecelaya authored Nov 13, 2024
2 parents 67576cb + fce4608 commit 93ee56b
Show file tree
Hide file tree
Showing 9 changed files with 889 additions and 59 deletions.
185 changes: 155 additions & 30 deletions mist/analyze_data/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# MIST imports.
from mist.runtime import utils
from mist.analyze_data import analyzer_constants
from mist.analyze_data import analyzer_constants as constants

# Set up console for rich text.
console = rich.console.Console()
Expand All @@ -34,6 +34,7 @@ def __init__(self, mist_arguments):
self.dataset_information = utils.read_json_file(
self.mist_arguments.data
)
self._check_dataset_information()
self.config = {}
self.file_paths = {
"configuration": (
Expand All @@ -50,6 +51,123 @@ def __init__(self, mist_arguments):
mist_arguments.data, "train"
)

def _check_dataset_information(self):
"""Check if the dataset description file is in the correct format."""
required_fields = [
"task",
"modality",
"train-data",
"mask",
"images",
"labels",
"final_classes",
]
for field in required_fields:
# Check that the required fields are in the JSON file.
if field not in self.dataset_information:
raise KeyError(
f"Dataset description JSON file must contain a "
f"entry '{field}'. There is no '{field}' in the JSON file."
)

# Check that the required fields are not None.
if field is None:
raise ValueError(
f"Dataset description JSON file must contain a '{field}' "
f"entry. There is a None value in the JSON file for "
f"'{field}'."
)

# Check that the train data folder exists and is not empty.
if field == "train-data":
if not os.path.exists(self.dataset_information[field]):
raise FileNotFoundError(
"In the 'train-data' entry, the directory does not "
"exist. No such file or directory: "
f"{self.dataset_information[field]}"
)

if not os.listdir(self.dataset_information[field]):
raise FileNotFoundError(
"In the 'train-data' entry, the directory is empty: "
f"{self.dataset_information[field]}"
)

# Check that the mask entry is a list and not empty.
if field == "mask":
if not isinstance(self.dataset_information[field], list):
raise TypeError(
"The 'mask' entry must be a list of mask names in the "
"dataset description JSON file. Found the following "
f"entry instead: {self.dataset_information[field]}."
)

if not self.dataset_information[field]:
raise ValueError(
"The 'mask' entry is empty. Please provide a list of "
"mask names in the dataset description JSON file."
)

# Check that the images entry is a dictionary and not empty.
if field == "images":
if not isinstance(self.dataset_information[field], dict):
raise TypeError(
"The 'images' entry must be a dictionary of the format "
"{'image_type': [list of image names]} in the dataset "
"description JSON file. Found the following entry "
f"instead: {self.dataset_information[field]}."
)

if not self.dataset_information[field]:
raise ValueError(
"The 'images' entry is empty. Please provide a "
"dictionary of the format "
f"{'image_type': [list of image names]} in the dataset "
"description JSON file."
)

# Check that the labels entry is a list and not empty. Also check
# that zero is an entry in the labels list.
if field == "labels":
if not isinstance(self.dataset_information[field], list):
raise TypeError(
"The 'labels' entry must be a list of labels in the "
"dataset. This list must contain zero as a label. "
"Found the following entry instead: "
f"{self.dataset_information[field]}."
)

if not self.dataset_information[field]:
raise ValueError(
"The 'labels' entry must be a list of labels in the "
"dataset. This list must contain zero as a label. The "
"list is empty."
)

if 0 not in self.dataset_information[field]:
raise ValueError(
"The 'labels' entry must be a list of labels in the "
"dataset. This list must contain zero as a label. No "
"zero label found in the list."
)

# Check that the final classes entry is a dictionary and not empty.
if field == "final_classes":
if not isinstance(self.dataset_information[field], dict):
raise TypeError(
"The 'final_classes' entry must be a dictionary of the "
"format {class_name: [list of labels]}. Found the "
"following entry instead: "
f"{self.dataset_information[field]}."
)

if not self.dataset_information[field]:
raise ValueError(
"The 'final_classes' entry must be a dictionary of the "
"format {class_name: [list of labels]}. The dictionary "
"is empty."
)

def check_crop_fg(self):
"""Check if cropping to foreground reduces image volume by 20%."""
progress = utils.get_progress_bar("Checking FG vol. reduction")
Expand Down Expand Up @@ -105,7 +223,7 @@ def check_crop_fg(self):
)
crop_to_fg = (
np.mean(vol_reduction) >=
analyzer_constants.AnalyzeConstants.MIN_AVERAGE_VOLUME_REDUCTION_FRACTION
constants.AnalyzeConstants.MIN_AVERAGE_VOLUME_REDUCTION_FRACTION
)
return crop_to_fg, cropped_dims

Expand All @@ -129,7 +247,7 @@ def check_nz_ratio(self):

use_nz_mask = (
(1. - np.mean(nz_ratio)) >=
analyzer_constants.AnalyzeConstants.MIN_SPARSITY_FRACTION
constants.AnalyzeConstants.MIN_SPARSITY_FRACTION
)
return use_nz_mask

Expand All @@ -150,7 +268,7 @@ def get_target_spacing(self):
mask = ants.image_read(patient["mask"])
mask = ants.reorient_image2(mask, "RAI")
mask.set_direction(
analyzer_constants.AnalyzeConstants.RAI_ANTS_DIRECTION
constants.AnalyzeConstants.RAI_ANTS_DIRECTION
)

# Get voxel spacing.
Expand All @@ -162,13 +280,13 @@ def get_target_spacing(self):
# If anisotropic, adjust the coarsest resolution to bring ratio down.
if (
np.max(target_spacing) / np.min(target_spacing) >
analyzer_constants.AnalyzeConstants.MAX_DIVIDED_BY_MIN_SPACING_THRESHOLD
constants.AnalyzeConstants.MAX_DIVIDED_BY_MIN_SPACING_THRESHOLD
):
low_res_axis = np.argmax(target_spacing)
target_spacing[low_res_axis] = (
np.percentile(
original_spacings[:, low_res_axis],
analyzer_constants.AnalyzeConstants.ANISOTROPIC_LOW_RESOLUTION_AXIS_PERCENTILE
constants.AnalyzeConstants.ANISOTROPIC_LOW_RESOLUTION_AXIS_PERCENTILE
)
)

Expand Down Expand Up @@ -214,13 +332,13 @@ def check_resampled_dims(self, cropped_dims):
# print to console.
if (
image_memory_size >
analyzer_constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE
constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE
):
print_patient_id = patient["id"]
messages += (
f"[Warning] In {print_patient_id}: Resampled example "
"is larger than the recommended memory size of "
f"{analyzer_constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE/1e9} "
f"{constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE/1e9} "
"GB. Consider coarsening or removing this example.\n"
)

Expand Down Expand Up @@ -253,18 +371,18 @@ def get_ct_normalization_parameters(self):
# You don"t need to use all of the voxels for this.
fg_intensities += (
image[mask != 0]
).tolist()[::analyzer_constants.AnalyzeConstants.CT_GATHER_EVERY_ITH_VOXEL_VALUE] # type: ignore
).tolist()[::constants.AnalyzeConstants.CT_GATHER_EVERY_ITH_VOXEL_VALUE] # type: ignore

global_z_score_mean = np.mean(fg_intensities)
global_z_score_std = np.std(fg_intensities)
global_window_range = [
np.percentile(
fg_intensities,
analyzer_constants.AnalyzeConstants.CT_GLOBAL_CLIP_MIN_PERCENTILE
constants.AnalyzeConstants.CT_GLOBAL_CLIP_MIN_PERCENTILE
),
np.percentile(
fg_intensities,
analyzer_constants.AnalyzeConstants.CT_GLOBAL_CLIP_MAX_PERCENTILE
constants.AnalyzeConstants.CT_GLOBAL_CLIP_MAX_PERCENTILE
),
]

Expand Down Expand Up @@ -364,6 +482,7 @@ def validate_dataset(self):
If any of these checks fail, the patient is excluded from training.
"""
progress = utils.get_progress_bar("Verifying dataset")
dataset_labels_set = set(self.dataset_information["labels"])

bad_data = []
messages = ""
Expand All @@ -373,24 +492,37 @@ def validate_dataset(self):
patient = self.paths_dataframe.iloc[i].to_dict()

# Get list of images, mask, labels in mask, and the header.
image_list = list(patient.values())[2:len(patient)]
mask = ants.image_read(patient["mask"])
mask_labels = set(mask.unique().astype(int))
mask_header = ants.image_header_info(patient["mask"])
try:
image_list = list(patient.values())[2:len(patient)]
mask = ants.image_read(patient["mask"])
mask_labels = set(mask.unique().astype(int))
mask_header = ants.image_header_info(patient["mask"])
image_header = ants.image_header_info(image_list[0])
except RuntimeError as e:
messages += f"In {patient['id']}: {e}\n"
bad_data.append(i)
continue

# Check if labels are correct.
if not mask_labels.issubset(
set(self.dataset_information["labels"])
):
if not mask_labels.issubset(dataset_labels_set):
messages += (
f"In {patient['id']}: Labels in mask do not match those"
f"In {patient['id']}: Labels in mask do not match those"
f" specified in {self.mist_arguments.data}\n"
)
bad_data.append(i)
continue

# Check that the image and mask headers match and that the
# images are 3D.
# Check that the mask is 3D.
if not utils.is_image_3d(mask_header):
messages += (
f"In {patient['id']}: Got 4D mask, make sure all"
"images are 3D\n"
)
bad_data.append(i)
break

# Check that the mask and image headers match and that each
# images is 3D.
for image_path in image_list:
image_header = ants.image_header_info(image_path)
if not utils.compare_headers(
Expand All @@ -411,15 +543,8 @@ def validate_dataset(self):
bad_data.append(i)
break

if not utils.is_image_3d(mask_header):
messages += (
f"In {patient['id']}: Got 4D mask, make sure all"
"images are 3D\n"
)
bad_data.append(i)
break

# Check that all images have the same header information.
# Check that all images have the same header information as
# the first image.
if len(image_list) > 1:
anchor_image = image_list[0]
anchor_header = ants.image_header_info(anchor_image)
Expand Down
Loading

0 comments on commit 93ee56b

Please sign in to comment.