Skip to content

Commit

Permalink
Add MedNeXt models to get_model and arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
aecelaya committed Nov 13, 2024
1 parent 5263a01 commit fce4608
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 27 deletions.
59 changes: 34 additions & 25 deletions mist/models/get_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Module for creating new models and loading pretrained models."""
import os
import json
import torch
Expand All @@ -12,21 +13,12 @@
from mist.models.nnunet import NNUnet
from mist.models.attn_unet import MONAIAttnUNet
from mist.models.swin_unetr import MONAISwinUNETR

"""
Available models:
- nnUNet
- U-Net
- FMG-Net
- W-Net
- Attention UNet
- Swin UNETR
"""
from mist.models.mednext_v1 import create_mednext_v1


def get_model(**kwargs):
if kwargs["model_name"] == "nnunet":
model = NNUnet(
return NNUnet(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["pocket"],
Expand All @@ -37,8 +29,28 @@ def get_model(**kwargs):
kwargs["target_spacing"],
kwargs["use_res_block"]
)
elif kwargs["model_name"] == "unet":
model = UNet(
if kwargs["model_name"] == "mednext-v1-small":
return create_mednext_v1.create_mednext_v1_small(
kwargs["n_channels"],
kwargs["n_classes"],
)
if kwargs["model_name"] == "mednext-v1-base":
return create_mednext_v1.create_mednext_v1_base(
kwargs["n_channels"],
kwargs["n_classes"],
)
if kwargs["model_name"] == "mednext-v1-medium":
return create_mednext_v1.create_mednext_v1_medium(
kwargs["n_channels"],
kwargs["n_classes"],
)
if kwargs["model_name"] == "mednext-v1-large":
return create_mednext_v1.create_mednext_v1_large(
kwargs["n_channels"],
kwargs["n_classes"],
)
if kwargs["model_name"] == "unet":
return UNet(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["patch_size"],
Expand All @@ -48,8 +60,8 @@ def get_model(**kwargs):
kwargs["deep_supervision_heads"],
kwargs["vae_reg"]
)
elif kwargs["model_name"] == "fmgnet":
model = MGNet(
if kwargs["model_name"] == "fmgnet":
return MGNet(
"fmgnet",
kwargs["n_channels"],
kwargs["n_classes"],
Expand All @@ -59,8 +71,8 @@ def get_model(**kwargs):
kwargs["deep_supervision_heads"],
kwargs["vae_reg"]
)
elif kwargs["model_name"] == "wnet":
model = MGNet(
if kwargs["model_name"] == "wnet":
return MGNet(
"wnet",
kwargs["n_channels"],
kwargs["n_classes"],
Expand All @@ -70,23 +82,20 @@ def get_model(**kwargs):
kwargs["deep_supervision_heads"],
kwargs["vae_reg"]
)
elif kwargs["model_name"] == "attn_unet":
model = MONAIAttnUNet(
if kwargs["model_name"] == "attn-unet":
return MONAIAttnUNet(
kwargs["n_classes"],
kwargs["n_channels"],
kwargs["pocket"],
kwargs["patch_size"]
)
elif kwargs["model_name"] == "unetr":
model = MONAISwinUNETR(
if kwargs["model_name"] == "swin-unetr":
return MONAISwinUNETR(
kwargs["n_classes"],
kwargs["n_channels"],
kwargs["patch_size"]
)
else:
raise ValueError("Invalid model name")

return model
raise ValueError("Invalid model name")


def load_model_from_config(weights_path, model_config_path):
Expand Down
8 changes: 6 additions & 2 deletions mist/runtime/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,15 @@ def get_main_args():
default="nnunet",
choices=[
"nnunet",
"mednext-v1-small",
"mednext-v1-base",
"mednext-v1-medium",
"mednext-v1-large",
"unet",
"fmgnet",
"wnet",
"attn_unet",
"unetr",
"attn-unet",
"swin-unetr",
"pretrained"
],
help="Pick which network architecture to use"
Expand Down

0 comments on commit fce4608

Please sign in to comment.