From fce4608afcc15a6cf66768cd7a333095aab36e6e Mon Sep 17 00:00:00 2001 From: aecelaya Date: Tue, 12 Nov 2024 22:04:33 -0600 Subject: [PATCH] Add MedNeXt models to get_model and arguments. --- mist/models/get_model.py | 59 +++++++++++++++++++++++----------------- mist/runtime/args.py | 8 ++++-- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/mist/models/get_model.py b/mist/models/get_model.py index ab67a68..0b4779e 100755 --- a/mist/models/get_model.py +++ b/mist/models/get_model.py @@ -1,3 +1,4 @@ +"""Module for creating new models and loading pretrained models.""" import os import json import torch @@ -12,21 +13,12 @@ from mist.models.nnunet import NNUnet from mist.models.attn_unet import MONAIAttnUNet from mist.models.swin_unetr import MONAISwinUNETR - -""" -Available models: - - nnUNet - - U-Net - - FMG-Net - - W-Net - - Attention UNet - - Swin UNETR -""" +from mist.models.mednext_v1 import create_mednext_v1 def get_model(**kwargs): if kwargs["model_name"] == "nnunet": - model = NNUnet( + return NNUnet( kwargs["n_channels"], kwargs["n_classes"], kwargs["pocket"], @@ -37,8 +29,28 @@ def get_model(**kwargs): kwargs["target_spacing"], kwargs["use_res_block"] ) - elif kwargs["model_name"] == "unet": - model = UNet( + if kwargs["model_name"] == "mednext-v1-small": + return create_mednext_v1.create_mednext_v1_small( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-base": + return create_mednext_v1.create_mednext_v1_base( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-medium": + return create_mednext_v1.create_mednext_v1_medium( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "mednext-v1-large": + return create_mednext_v1.create_mednext_v1_large( + kwargs["n_channels"], + kwargs["n_classes"], + ) + if kwargs["model_name"] == "unet": + return UNet( kwargs["n_channels"], kwargs["n_classes"], kwargs["patch_size"], @@ -48,8 +60,8 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "fmgnet": - model = MGNet( + if kwargs["model_name"] == "fmgnet": + return MGNet( "fmgnet", kwargs["n_channels"], kwargs["n_classes"], @@ -59,8 +71,8 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "wnet": - model = MGNet( + if kwargs["model_name"] == "wnet": + return MGNet( "wnet", kwargs["n_channels"], kwargs["n_classes"], @@ -70,23 +82,20 @@ def get_model(**kwargs): kwargs["deep_supervision_heads"], kwargs["vae_reg"] ) - elif kwargs["model_name"] == "attn_unet": - model = MONAIAttnUNet( + if kwargs["model_name"] == "attn-unet": + return MONAIAttnUNet( kwargs["n_classes"], kwargs["n_channels"], kwargs["pocket"], kwargs["patch_size"] ) - elif kwargs["model_name"] == "unetr": - model = MONAISwinUNETR( + if kwargs["model_name"] == "swin-unetr": + return MONAISwinUNETR( kwargs["n_classes"], kwargs["n_channels"], kwargs["patch_size"] ) - else: - raise ValueError("Invalid model name") - - return model + raise ValueError("Invalid model name") def load_model_from_config(weights_path, model_config_path): diff --git a/mist/runtime/args.py b/mist/runtime/args.py index f9494df..222a40f 100755 --- a/mist/runtime/args.py +++ b/mist/runtime/args.py @@ -281,11 +281,15 @@ def get_main_args(): default="nnunet", choices=[ "nnunet", + "mednext-v1-small", + "mednext-v1-base", + "mednext-v1-medium", + "mednext-v1-large", "unet", "fmgnet", "wnet", - "attn_unet", - "unetr", + "attn-unet", + "swin-unetr", "pretrained" ], help="Pick which network architecture to use"