From a7493647f0bd398e820e6a51c98ff1bcfc2fdc1b Mon Sep 17 00:00:00 2001 From: aecelaya Date: Mon, 4 Nov 2024 11:08:02 -0600 Subject: [PATCH 1/5] Minor bug fix for training data loader labels. --- mist/runtime/run.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mist/runtime/run.py b/mist/runtime/run.py index e9ec0b9..5caf1c5 100755 --- a/mist/runtime/run.py +++ b/mist/runtime/run.py @@ -326,13 +326,20 @@ def train(self, rank: int, world_size: int) -> None: val_steps = len(val_images) // world_size # Get training data loader. + # The training labels are different from what's specified in the + # dataset description. The preprocessed masks have labels 0,1,...,N. + # We exclude the background label (0) from the training labels and + # pass the rest in as the labels for the training data loader. + training_labels = list(range(len( + self.data_structures["mist_configuration"]["labels"] + )))[1:] train_loader = dali_loader.get_training_dataset( imgs=train_images, lbls=train_labels, dtms=train_dtms, batch_size=self.mist_arguments.batch_size // world_size, oversampling=self.mist_arguments.oversampling, - labels=self.data_structures["mist_configuration"]["labels"][1:], + labels=training_labels, patch_size=( self.data_structures["mist_configuration"]["patch_size"] ), From 4cd229c58edf4cec6cd74b0bff06a0aed40ec2a5 Mon Sep 17 00:00:00 2001 From: aecelaya Date: Mon, 4 Nov 2024 11:08:38 -0600 Subject: [PATCH 2/5] Update version to 0.1.5-beta. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7423d86..436268e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mist-medical" -version = "0.1.4-beta" +version = "0.1.5-beta" requires-python = ">= 3.8" description = "MIST is a simple, fully automated framework for 3D medical imaging segmentation." 
readme = "README.md" From a379c506460232c642bd768e34ba8e98cf91278e Mon Sep 17 00:00:00 2001 From: aecelaya Date: Tue, 12 Nov 2024 17:45:23 -0600 Subject: [PATCH 3/5] Add _check_dataset_information method to ensure that the dataset description file is correct. --- mist/analyze_data/analyze.py | 185 +++++++++++++++++++++++++++++------ 1 file changed, 155 insertions(+), 30 deletions(-) diff --git a/mist/analyze_data/analyze.py b/mist/analyze_data/analyze.py index 36f2878..525c9bd 100755 --- a/mist/analyze_data/analyze.py +++ b/mist/analyze_data/analyze.py @@ -12,7 +12,7 @@ # MIST imports. from mist.runtime import utils -from mist.analyze_data import analyzer_constants +from mist.analyze_data import analyzer_constants as constants # Set up console for rich text. console = rich.console.Console() @@ -34,6 +34,7 @@ def __init__(self, mist_arguments): self.dataset_information = utils.read_json_file( self.mist_arguments.data ) + self._check_dataset_information() self.config = {} self.file_paths = { "configuration": ( @@ -50,6 +51,123 @@ def __init__(self, mist_arguments): mist_arguments.data, "train" ) + def _check_dataset_information(self): + """Check if the dataset description file is in the correct format.""" + required_fields = [ + "task", + "modality", + "train-data", + "mask", + "images", + "labels", + "final_classes", + ] + for field in required_fields: + # Check that the required fields are in the JSON file. + if field not in self.dataset_information: + raise KeyError( + f"Dataset description JSON file must contain a " + f"entry '{field}'. There is no '{field}' in the JSON file." + ) + + # Check that the required fields are not None. + if field is None: + raise ValueError( + f"Dataset description JSON file must contain a '{field}' " + f"entry. There is a None value in the JSON file for " + f"'{field}'." + ) + + # Check that the train data folder exists and is not empty. 
+ if field == "train-data": + if not os.path.exists(self.dataset_information[field]): + raise FileNotFoundError( + "In the 'train-data' entry, the directory does not " + "exist. No such file or directory: " + f"{self.dataset_information[field]}" + ) + + if not os.listdir(self.dataset_information[field]): + raise FileNotFoundError( + "In the 'train-data' entry, the directory is empty: " + f"{self.dataset_information[field]}" + ) + + # Check that the mask entry is a list and not empty. + if field == "mask": + if not isinstance(self.dataset_information[field], list): + raise TypeError( + "The 'mask' entry must be a list of mask names in the " + "dataset description JSON file. Found the following " + f"entry instead: {self.dataset_information[field]}." + ) + + if not self.dataset_information[field]: + raise ValueError( + "The 'mask' entry is empty. Please provide a list of " + "mask names in the dataset description JSON file." + ) + + # Check that the images entry is a dictionary and not empty. + if field == "images": + if not isinstance(self.dataset_information[field], dict): + raise TypeError( + "The 'images' entry must be a dictionary of the format " + "{'image_type': [list of image names]} in the dataset " + "description JSON file. Found the following entry " + f"instead: {self.dataset_information[field]}." + ) + + if not self.dataset_information[field]: + raise ValueError( + "The 'images' entry is empty. Please provide a " + "dictionary of the format " + f"{'image_type': [list of image names]} in the dataset " + "description JSON file." + ) + + # Check that the labels entry is a list and not empty. Also check + # that zero is an entry in the labels list. + if field == "labels": + if not isinstance(self.dataset_information[field], list): + raise TypeError( + "The 'labels' entry must be a list of labels in the " + "dataset. This list must contain zero as a label. " + "Found the following entry instead: " + f"{self.dataset_information[field]}." 
+ ) + + if not self.dataset_information[field]: + raise ValueError( + "The 'labels' entry must be a list of labels in the " + "dataset. This list must contain zero as a label. The " + "list is empty." + ) + + if 0 not in self.dataset_information[field]: + raise ValueError( + "The 'labels' entry must be a list of labels in the " + "dataset. This list must contain zero as a label. No " + "zero label found in the list." + ) + + # Check that the final classes entry is a dictionary and not empty. + if field == "final_classes": + if not isinstance(self.dataset_information[field], dict): + raise TypeError( + "The 'final_classes' entry must be a dictionary of the " + "format {class_name: [list of labels]}. Found the " + "following entry instead: " + f"{self.dataset_information[field]}." + ) + + if not self.dataset_information[field]: + raise ValueError( + "The 'final_classes' entry must be a dictionary of the " + "format {class_name: [list of labels]}. The dictionary " + "is empty." + ) + def check_crop_fg(self): """Check if cropping to foreground reduces image volume by 20%.""" progress = utils.get_progress_bar("Checking FG vol. reduction") @@ -105,7 +223,7 @@ def check_crop_fg(self): ) crop_to_fg = ( np.mean(vol_reduction) >= - analyzer_constants.AnalyzeConstants.MIN_AVERAGE_VOLUME_REDUCTION_FRACTION + constants.AnalyzeConstants.MIN_AVERAGE_VOLUME_REDUCTION_FRACTION ) return crop_to_fg, cropped_dims @@ -129,7 +247,7 @@ def check_nz_ratio(self): use_nz_mask = ( (1. - np.mean(nz_ratio)) >= - analyzer_constants.AnalyzeConstants.MIN_SPARSITY_FRACTION + constants.AnalyzeConstants.MIN_SPARSITY_FRACTION ) return use_nz_mask @@ -150,7 +268,7 @@ def get_target_spacing(self): mask = ants.image_read(patient["mask"]) mask = ants.reorient_image2(mask, "RAI") mask.set_direction( - analyzer_constants.AnalyzeConstants.RAI_ANTS_DIRECTION + constants.AnalyzeConstants.RAI_ANTS_DIRECTION ) # Get voxel spacing. 
@@ -162,13 +280,13 @@ def get_target_spacing(self): # If anisotropic, adjust the coarsest resolution to bring ratio down. if ( np.max(target_spacing) / np.min(target_spacing) > - analyzer_constants.AnalyzeConstants.MAX_DIVIDED_BY_MIN_SPACING_THRESHOLD + constants.AnalyzeConstants.MAX_DIVIDED_BY_MIN_SPACING_THRESHOLD ): low_res_axis = np.argmax(target_spacing) target_spacing[low_res_axis] = ( np.percentile( original_spacings[:, low_res_axis], - analyzer_constants.AnalyzeConstants.ANISOTROPIC_LOW_RESOLUTION_AXIS_PERCENTILE + constants.AnalyzeConstants.ANISOTROPIC_LOW_RESOLUTION_AXIS_PERCENTILE ) ) @@ -214,13 +332,13 @@ def check_resampled_dims(self, cropped_dims): # print to console. if ( image_memory_size > - analyzer_constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE + constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE ): print_patient_id = patient["id"] messages += ( f"[Warning] In {print_patient_id}: Resampled example " "is larger than the recommended memory size of " - f"{analyzer_constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE/1e9} " + f"{constants.AnalyzeConstants.MAX_RECOMMENDED_MEMORY_SIZE/1e9} " "GB. Consider coarsening or removing this example.\n" ) @@ -253,18 +371,18 @@ def get_ct_normalization_parameters(self): # You don"t need to use all of the voxels for this. 
fg_intensities += ( image[mask != 0] - ).tolist()[::analyzer_constants.AnalyzeConstants.CT_GATHER_EVERY_ITH_VOXEL_VALUE] # type: ignore + ).tolist()[::constants.AnalyzeConstants.CT_GATHER_EVERY_ITH_VOXEL_VALUE] # type: ignore global_z_score_mean = np.mean(fg_intensities) global_z_score_std = np.std(fg_intensities) global_window_range = [ np.percentile( fg_intensities, - analyzer_constants.AnalyzeConstants.CT_GLOBAL_CLIP_MIN_PERCENTILE + constants.AnalyzeConstants.CT_GLOBAL_CLIP_MIN_PERCENTILE ), np.percentile( fg_intensities, - analyzer_constants.AnalyzeConstants.CT_GLOBAL_CLIP_MAX_PERCENTILE + constants.AnalyzeConstants.CT_GLOBAL_CLIP_MAX_PERCENTILE ), ] @@ -364,6 +482,7 @@ def validate_dataset(self): If any of these checks fail, the patient is excluded from training. """ progress = utils.get_progress_bar("Verifying dataset") + dataset_labels_set = set(self.dataset_information["labels"]) bad_data = [] messages = "" @@ -373,24 +492,37 @@ def validate_dataset(self): patient = self.paths_dataframe.iloc[i].to_dict() # Get list of images, mask, labels in mask, and the header. - image_list = list(patient.values())[2:len(patient)] - mask = ants.image_read(patient["mask"]) - mask_labels = set(mask.unique().astype(int)) - mask_header = ants.image_header_info(patient["mask"]) + try: + image_list = list(patient.values())[2:len(patient)] + mask = ants.image_read(patient["mask"]) + mask_labels = set(mask.unique().astype(int)) + mask_header = ants.image_header_info(patient["mask"]) + image_header = ants.image_header_info(image_list[0]) + except RuntimeError as e: + messages += f"In {patient['id']}: {e}\n" + bad_data.append(i) + continue # Check if labels are correct. 
- if not mask_labels.issubset( - set(self.dataset_information["labels"]) - ): + if not mask_labels.issubset(dataset_labels_set): messages += ( - f"In {patient['id']}: Labels in mask do not match those" + f"In {patient['id']}: Labels in mask do not match those" f" specified in {self.mist_arguments.data}\n" ) bad_data.append(i) continue - # Check that the image and mask headers match and that the - # images are 3D. + # Check that the mask is 3D. + if not utils.is_image_3d(mask_header): + messages += ( + f"In {patient['id']}: Got 4D mask, make sure all" + "images are 3D\n" + ) + bad_data.append(i) + break + + # Check that the mask and image headers match and that each + # image is 3D. for image_path in image_list: image_header = ants.image_header_info(image_path) if not utils.compare_headers( @@ -411,15 +543,8 @@ def validate_dataset(self): bad_data.append(i) break - if not utils.is_image_3d(mask_header): - messages += ( - f"In {patient['id']}: Got 4D mask, make sure all" - "images are 3D\n" - ) - bad_data.append(i) - break - - # Check that all images have the same header information. + # Check that all images have the same header information as + # the first image. if len(image_list) > 1: anchor_image = image_list[0] anchor_header = ants.image_header_info(anchor_image) From 5263a01fc61b1cb6b583595fa89b215c51fd1617 Mon Sep 17 00:00:00 2001 From: aecelaya Date: Tue, 12 Nov 2024 22:03:52 -0600 Subject: [PATCH 4/5] Add MedNeXt implementation to MIST. 
--- mist/models/mednext_v1/__init__.py | 0 mist/models/mednext_v1/blocks.py | 240 +++++++++++++++ mist/models/mednext_v1/create_mednext_v1.py | 123 ++++++++ mist/models/mednext_v1/mednext_v1.py | 322 ++++++++++++++++++++ 4 files changed, 685 insertions(+) create mode 100644 mist/models/mednext_v1/__init__.py create mode 100644 mist/models/mednext_v1/blocks.py create mode 100644 mist/models/mednext_v1/create_mednext_v1.py create mode 100644 mist/models/mednext_v1/mednext_v1.py diff --git a/mist/models/mednext_v1/__init__.py b/mist/models/mednext_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mist/models/mednext_v1/blocks.py b/mist/models/mednext_v1/blocks.py new file mode 100644 index 0000000..0c17936 --- /dev/null +++ b/mist/models/mednext_v1/blocks.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MedNeXtBlock(nn.Module): + + def __init__(self, + in_channels:int, + out_channels:int, + exp_r:int=4, + kernel_size:int=7, + do_res:int=True, + norm_type:str = 'group', + n_groups:int or None = None, + dim = '3d', + grn = False + ): + + super().__init__() + + self.do_res = do_res + + assert dim in ['2d', '3d'] + self.dim = dim + if self.dim == '2d': + conv = nn.Conv2d + elif self.dim == '3d': + conv = nn.Conv3d + + # First convolution layer with DepthWise Convolutions + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 1, + padding = kernel_size//2, + groups = in_channels if n_groups is None else n_groups, + ) + + # Normalization Layer. GroupNorm is used by default. 
+ if norm_type=='group': + self.norm = nn.GroupNorm( + num_groups=in_channels, + num_channels=in_channels + ) + elif norm_type=='layer': + self.norm = LayerNorm( + normalized_shape=in_channels, + data_format='channels_first' + ) + + # Second convolution (Expansion) layer with Conv3D 1x1x1 + self.conv2 = conv( + in_channels = in_channels, + out_channels = exp_r*in_channels, + kernel_size = 1, + stride = 1, + padding = 0 + ) + + # GeLU activations + self.act = nn.GELU() + + # Third convolution (Compression) layer with Conv3D 1x1x1 + self.conv3 = conv( + in_channels = exp_r*in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 1, + padding = 0 + ) + + self.grn = grn + if grn: + if dim == '3d': + self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1,1), requires_grad=True) + elif dim == '2d': + self.grn_beta = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1,exp_r*in_channels,1,1), requires_grad=True) + + + def forward(self, x, dummy_tensor=None): + + x1 = x + x1 = self.conv1(x1) + x1 = self.act(self.conv2(self.norm(x1))) + if self.grn: + # gamma, beta: learnable affine transform parameters + # X: input of shape (N,C,H,W,D) + if self.dim == '3d': + gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True) + elif self.dim == '2d': + gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True)+1e-6) + x1 = self.grn_gamma * (x1 * nx) + self.grn_beta + x1 + x1 = self.conv3(x1) + if self.do_res: + x1 = x + x1 + return x1 + + +class MedNeXtDownBlock(MedNeXtBlock): + + def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, + do_res=False, norm_type = 'group', dim='3d', grn=False): + + super().__init__(in_channels, out_channels, exp_r, kernel_size, + do_res = False, norm_type = norm_type, dim=dim, + grn=grn) + + if dim == '2d': + conv = 
nn.Conv2d + elif dim == '3d': + conv = nn.Conv3d + self.resample_do_res = do_res + if do_res: + self.res_conv = conv( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 2 + ) + + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 2, + padding = kernel_size//2, + groups = in_channels, + ) + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + + if self.resample_do_res: + res = self.res_conv(x) + x1 = x1 + res + + return x1 + + +class MedNeXtUpBlock(MedNeXtBlock): + + def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, + do_res=False, norm_type = 'group', dim='3d', grn = False): + super().__init__(in_channels, out_channels, exp_r, kernel_size, + do_res=False, norm_type = norm_type, dim=dim, + grn=grn) + + self.resample_do_res = do_res + + self.dim = dim + if dim == '2d': + conv = nn.ConvTranspose2d + elif dim == '3d': + conv = nn.ConvTranspose3d + if do_res: + self.res_conv = conv( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = 1, + stride = 2 + ) + + self.conv1 = conv( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = kernel_size, + stride = 2, + padding = kernel_size//2, + groups = in_channels, + ) + + + def forward(self, x, dummy_tensor=None): + + x1 = super().forward(x) + # Asymmetry but necessary to match shape + + if self.dim == '2d': + x1 = torch.nn.functional.pad(x1, (1,0,1,0)) + elif self.dim == '3d': + x1 = torch.nn.functional.pad(x1, (1,0,1,0,1,0)) + + if self.resample_do_res: + res = self.res_conv(x) + if self.dim == '2d': + res = torch.nn.functional.pad(res, (1,0,1,0)) + elif self.dim == '3d': + res = torch.nn.functional.pad(res, (1,0,1,0,1,0)) + x1 = x1 + res + + return x1 + + +class OutBlock(nn.Module): + + def __init__(self, in_channels, n_classes, dim): + super().__init__() + + if dim == '2d': + conv = nn.ConvTranspose2d + elif dim == '3d': + conv = 
nn.ConvTranspose3d + self.conv_out = conv(in_channels, n_classes, kernel_size=1) + + def forward(self, x, dummy_tensor=None): + return self.conv_out(x) + + +class LayerNorm(nn.Module): + """ LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) # gamma + self.bias = nn.Parameter(torch.zeros(normalized_shape)) # beta + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x, dummy_tensor=False): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] + return x + diff --git a/mist/models/mednext_v1/create_mednext_v1.py b/mist/models/mednext_v1/create_mednext_v1.py new file mode 100644 index 0000000..b9b9f3b --- /dev/null +++ b/mist/models/mednext_v1/create_mednext_v1.py @@ -0,0 +1,123 @@ +"""MIST-compatible MedNeXt V1 model creation functions.""" +from mist.models.mednext_v1.mednext_v1 import MedNeXt + + +def create_mednext_v1_small( + num_input_channels: int, + num_classes: int, + kernel_size: int=3, + ds: bool=False, +) -> MedNeXt: + """Creates the small-sized version of the MedNeXt V1 model. + + Args: + num_input_channels: Number of input channels. + num_classes: Number of output classes. 
+ kernel_size: Kernel size for convolutional layers. + ds: Whether to use deep supervision. + + Returns: + MedNeXt model. + """ + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=2, + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down=True, + block_counts=[2,2,2,2,2,2,2,2,2], + ) + + +def create_mednext_v1_base( + num_input_channels: int, + num_classes: int, + kernel_size: int=3, + ds: bool=False, +) -> MedNeXt: + """Creates the baseline version of the MedNeXt V1 model. + + Args: + num_input_channels: Number of input channels. + num_classes: Number of output classes. + kernel_size: Kernel size for convolutional layers. + ds: Whether to use deep supervision. + + Returns: + MedNeXt model. + """ + return MedNeXt( + in_channels = num_input_channels, + n_channels = 32, + n_classes = num_classes, + exp_r=[2,3,4,4,4,4,4,3,2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [2,2,2,2,2,2,2,2,2], + ) + + +def create_mednext_v1_medium( + num_input_channels: int, + num_classes: int, + kernel_size: int=3, + ds: bool=False, +) -> MedNeXt: + """Creates the medium-sized version of the MedNeXt V1 model. + + Args: + num_input_channels: Number of input channels. + num_classes: Number of output classes. + kernel_size: Kernel size for convolutional layers. + ds: Whether to use deep supervision. + + Returns: + MedNeXt model. + """ + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=[2,3,4,4,4,4,4,3,2], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [3,4,4,4,4,4,4,4,3], + ) + + +def create_mednext_v1_large( + num_input_channels: int, + num_classes: int, + kernel_size: int=3, + ds: bool=False, +) -> MedNeXt: + """Creates the large-sized version of the MedNeXt V1 model. + + Args: + num_input_channels: Number of input channels. 
+ num_classes: Number of output classes. + kernel_size: Kernel size for convolutional layers. + ds: Whether to use deep supervision. + + Returns: + MedNeXt model. + """ + return MedNeXt( + in_channels=num_input_channels, + n_channels=32, + n_classes=num_classes, + exp_r=[3,4,8,8,8,8,8,4,3], + kernel_size=kernel_size, + deep_supervision=ds, + do_res=True, + do_res_up_down = True, + block_counts = [3,4,8,8,8,8,8,4,3], + ) + diff --git a/mist/models/mednext_v1/mednext_v1.py b/mist/models/mednext_v1/mednext_v1.py new file mode 100644 index 0000000..12a56c2 --- /dev/null +++ b/mist/models/mednext_v1/mednext_v1.py @@ -0,0 +1,322 @@ +"""MIST-compatible MedNeXt model.""" +from typing import List, Union, Optional +import torch.nn as nn + +from mist.models.mednext_v1 import blocks + + +class MedNeXt(nn.Module): + def __init__(self, + in_channels: int, + n_channels: int, + n_classes: int, + exp_r: Union[List[int], int]=4, + kernel_size: int=7, + enc_kernel_size: Optional[int]=None, + dec_kernel_size: Optional[int]=None, + deep_supervision: bool=False, + do_res: bool=False, + do_res_up_down: bool=False, + block_counts: list=[2,2,2,2,2,2,2,2,2], + norm_type: str='group', + dim: str='3d', + grn: bool=False + ): + + super().__init__() + + self.do_ds = deep_supervision + assert dim in ['2d', '3d'] + + if kernel_size is not None: + enc_kernel_size = kernel_size + dec_kernel_size = kernel_size + + if dim == '2d': + conv = nn.Conv2d + elif dim == '3d': + conv = nn.Conv3d + + self.stem = conv(in_channels, n_channels, kernel_size=1) + if isinstance(exp_r, int): + exp_r = [exp_r] * len(block_counts) + + self.enc_block_0 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[0], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[0])] + ) + + self.down_0 = blocks.MedNeXtDownBlock( + in_channels=n_channels, + out_channels=2*n_channels, + exp_r=exp_r[1], + 
kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim + ) + + self.enc_block_1 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*2, + out_channels=n_channels*2, + exp_r=exp_r[1], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[1])] + ) + + self.down_1 = blocks.MedNeXtDownBlock( + in_channels=2*n_channels, + out_channels=4*n_channels, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.enc_block_2 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*4, + out_channels=n_channels*4, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[2])] + ) + + self.down_2 = blocks.MedNeXtDownBlock( + in_channels=4*n_channels, + out_channels=8*n_channels, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.enc_block_3 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*8, + out_channels=n_channels*8, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[3])] + ) + + self.down_3 = blocks.MedNeXtDownBlock( + in_channels=8*n_channels, + out_channels=16*n_channels, + exp_r=exp_r[4], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.bottleneck = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*16, + out_channels=n_channels*16, + exp_r=exp_r[4], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[4])] + ) + + self.up_3 = blocks.MedNeXtUpBlock( + in_channels=16*n_channels, + out_channels=8*n_channels, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + 
do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_3 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*8, + out_channels=n_channels*8, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[5])] + ) + + self.up_2 = blocks.MedNeXtUpBlock( + in_channels=8*n_channels, + out_channels=4*n_channels, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_2 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*4, + out_channels=n_channels*4, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[6])] + ) + + self.up_1 = blocks.MedNeXtUpBlock( + in_channels=4*n_channels, + out_channels=2*n_channels, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_1 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels*2, + out_channels=n_channels*2, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[7])] + ) + + self.up_0 = blocks.MedNeXtUpBlock( + in_channels=2*n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + dim=dim, + grn=grn + ) + + self.dec_block_0 = nn.Sequential(*[ + blocks.MedNeXtBlock( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + dim=dim, + grn=grn + ) + for i in range(block_counts[8])] + ) + + self.out_0 = blocks.OutBlock(in_channels=n_channels, n_classes=n_classes, dim=dim) + + if deep_supervision: + self.out_1 = blocks.OutBlock(in_channels=n_channels*2, 
n_classes=n_classes, dim=dim) + self.out_2 = blocks.OutBlock(in_channels=n_channels*4, n_classes=n_classes, dim=dim) + self.out_3 = blocks.OutBlock(in_channels=n_channels*8, n_classes=n_classes, dim=dim) + self.out_4 = blocks.OutBlock(in_channels=n_channels*16, n_classes=n_classes, dim=dim) + + self.block_counts = block_counts + + + def forward(self, x): + x = self.stem(x) + x_res_0 = self.enc_block_0(x) + x = self.down_0(x_res_0) + x_res_1 = self.enc_block_1(x) + x = self.down_1(x_res_1) + x_res_2 = self.enc_block_2(x) + x = self.down_2(x_res_2) + x_res_3 = self.enc_block_3(x) + x = self.down_3(x_res_3) + + x = self.bottleneck(x) + if self.do_ds and self.training: + x_ds_4 = self.out_4(x) + + x_up_3 = self.up_3(x) + dec_x = x_res_3 + x_up_3 + x = self.dec_block_3(dec_x) + + if self.do_ds and self.training: + x_ds_3 = self.out_3(x) + del x_res_3, x_up_3 + + x_up_2 = self.up_2(x) + dec_x = x_res_2 + x_up_2 + x = self.dec_block_2(dec_x) + if self.do_ds and self.training: + x_ds_2 = self.out_2(x) + del x_res_2, x_up_2 + + x_up_1 = self.up_1(x) + dec_x = x_res_1 + x_up_1 + x = self.dec_block_1(dec_x) + if self.do_ds and self.training: + x_ds_1 = self.out_1(x) + del x_res_1, x_up_1 + + x_up_0 = self.up_0(x) + dec_x = x_res_0 + x_up_0 + x = self.dec_block_0(dec_x) + del x_res_0, x_up_0, dec_x + + x = self.out_0(x) + + # Make MedNeXt compatible with MIST. + if self.training: + output = {} + output["prediction"] = x + + if self.do_ds: + output["deep_supervision"] = [x_ds_1, x_ds_2, x_ds_3, x_ds_4] + else: + output = x + + return output From fce4608afcc15a6cf66768cd7a333095aab36e6e Mon Sep 17 00:00:00 2001 From: aecelaya Date: Tue, 12 Nov 2024 22:04:33 -0600 Subject: [PATCH 5/5] Add MedNeXt models to get_model and arguments. 
--- mist/models/get_model.py | 59 +++++++++++++++++++++++----------------- mist/runtime/args.py | 8 ++++-- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/mist/models/get_model.py b/mist/models/get_model.py index ab67a68..0b4779e 100755 --- a/mist/models/get_model.py +++ b/mist/models/get_model.py @@ -1,3 +1,4 @@ +"""Module for creating new models and loading pretrained models.""" import os import json import torch @@ -12,21 +13,12 @@ from mist.models.nnunet import NNUnet from mist.models.attn_unet import MONAIAttnUNet from mist.models.swin_unetr import MONAISwinUNETR - -""" -Available models: - - nnUNet - - U-Net - - FMG-Net - - W-Net - - Attention UNet - - Swin UNETR -""" +from mist.models.mednext_v1 import create_mednext_v1 def get_model(**kwargs): if kwargs["model_name"] == "nnunet": - model = NNUnet( + return NNUnet( kwargs["n_channels"], kwargs["n_classes"], kwargs["pocket"], @@ -37,8 +29,28 @@ def get_model(**kwargs): kwargs["target_spacing"], kwargs["use_res_block"] ) - elif kwargs["model_name"] == "unet": - model = UNet( + if kwargs["model_name"] == "mednext-v1-small": + return create_mednext_v1.create_mednext_v1_small( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-base": + return create_mednext_v1.create_mednext_v1_base( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-medium": + return create_mednext_v1.create_mednext_v1_medium( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-large": + return create_mednext_v1.create_mednext_v1_large( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "unet": + return UNet( kwargs["n_channels"], kwargs["n_classes"], kwargs["patch_size"], @@ -48,8 +60,8 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "fmgnet": - model = MGNet( + if kwargs["model_name"] == "fmgnet": + return MGNet( 
"fmgnet", kwargs["n_channels"], kwargs["n_classes"], @@ -59,8 +71,8 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "wnet": - model = MGNet( + if kwargs["model_name"] == "wnet": + return MGNet( "wnet", kwargs["n_channels"], kwargs["n_classes"], @@ -70,23 +82,20 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "attn_unet": - model = MONAIAttnUNet( + if kwargs["model_name"] == "attn-unet": + return MONAIAttnUNet( kwargs["n_classes"], kwargs["n_channels"], kwargs["pocket"], kwargs["patch_size"] ) - elif kwargs["model_name"] == "unetr": - model = MONAISwinUNETR( + if kwargs["model_name"] == "swin-unetr": + return MONAISwinUNETR( kwargs["n_classes"], kwargs["n_channels"], kwargs["patch_size"] ) - else: - raise ValueError("Invalid model name") - - return model + raise ValueError("Invalid model name") def load_model_from_config(weights_path, model_config_path): diff --git a/mist/runtime/args.py b/mist/runtime/args.py index f9494df..222a40f 100755 --- a/mist/runtime/args.py +++ b/mist/runtime/args.py @@ -281,11 +281,15 @@ def get_main_args(): default="nnunet", choices=[ "nnunet", + "mednext-v1-small", + "mednext-v1-base", + "mednext-v1-medium", + "mednext-v1-large", "unet", "fmgnet", "wnet", - "attn_unet", - "unetr", + "attn-unet", + "swin-unetr", "pretrained" ], help="Pick which network architecture to use"