Skip to content

Commit

Permalink
Merge pull request #65 from aecelaya/main
Browse files Browse the repository at this point in the history
Add deep supervision and pocket option to all MedNeXt models.
  • Loading branch information
aecelaya authored Nov 15, 2024
2 parents 66aa661 + 12fdd41 commit 08ed855
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 43 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ Please see our Read the Docs page [**here**](https://mist-medical.readthedocs.io
## What's New
* November 2024 - MedNeXt models (small, base, medium, and large) added to MIST.
These models can be called with ```--model mednext-v1-<small, base, medium, large>```.
We're still working on getting deep supervision to work with MedNeXt.
* October 2024 - MIST takes 3rd place in BraTS 2024 adult glioma challenge @ MICCAI 2024!
* August 2024 - Added clDice as an available loss function.
* April 2024 - The Read the Docs page is up!
Expand Down
8 changes: 8 additions & 0 deletions mist/models/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,29 @@ def get_model(**kwargs):
return create_mednext_v1.create_mednext_v1_small(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["deep_supervision"],
kwargs["pocket"],
)
if kwargs["model_name"] == "mednext-v1-base":
return create_mednext_v1.create_mednext_v1_base(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["deep_supervision"],
kwargs["pocket"],
)
if kwargs["model_name"] == "mednext-v1-medium":
return create_mednext_v1.create_mednext_v1_medium(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["deep_supervision"],
kwargs["pocket"],
)
if kwargs["model_name"] == "mednext-v1-large":
return create_mednext_v1.create_mednext_v1_large(
kwargs["n_channels"],
kwargs["n_classes"],
kwargs["deep_supervision"],
kwargs["pocket"],
)
if kwargs["model_name"] == "unet":
return UNet(
Expand Down
25 changes: 16 additions & 9 deletions mist/models/mednext_v1/create_mednext_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
def create_mednext_v1_small(
num_input_channels: int,
num_classes: int,
deep_supervision: bool=False,
pocket: bool=False,
kernel_size: int=3,
ds: bool=False,
) -> MedNeXt:
"""Creates the small-sized version of the MedNeXt V1 model.
Expand All @@ -25,7 +26,8 @@ def create_mednext_v1_small(
n_classes=num_classes,
exp_r=2,
kernel_size=kernel_size,
deep_supervision=ds,
deep_supervision=deep_supervision,
pocket=pocket,
do_res=True,
do_res_up_down=True,
block_counts=[2,2,2,2,2,2,2,2,2],
Expand All @@ -35,8 +37,9 @@ def create_mednext_v1_small(
def create_mednext_v1_base(
num_input_channels: int,
num_classes: int,
deep_supervision: bool=False,
pocket: bool=False,
kernel_size: int=3,
ds: bool=False,
) -> MedNeXt:
"""Creates the baseline version of the MedNeXt V1 model.
Expand All @@ -55,7 +58,8 @@ def create_mednext_v1_base(
n_classes = num_classes,
exp_r=[2,3,4,4,4,4,4,3,2],
kernel_size=kernel_size,
deep_supervision=ds,
deep_supervision=deep_supervision,
pocket=pocket,
do_res=True,
do_res_up_down = True,
block_counts = [2,2,2,2,2,2,2,2,2],
Expand All @@ -65,8 +69,9 @@ def create_mednext_v1_base(
def create_mednext_v1_medium(
num_input_channels: int,
num_classes: int,
deep_supervision: bool=False,
pocket: bool=False,
kernel_size: int=3,
ds: bool=False,
) -> MedNeXt:
"""Creates the medium-sized version of the MedNeXt V1 model.
Expand All @@ -85,7 +90,8 @@ def create_mednext_v1_medium(
n_classes=num_classes,
exp_r=[2,3,4,4,4,4,4,3,2],
kernel_size=kernel_size,
deep_supervision=ds,
deep_supervision=deep_supervision,
pocket=pocket,
do_res=True,
do_res_up_down = True,
block_counts = [3,4,4,4,4,4,4,4,3],
Expand All @@ -95,8 +101,9 @@ def create_mednext_v1_medium(
def create_mednext_v1_large(
num_input_channels: int,
num_classes: int,
deep_supervision: bool=False,
pocket: bool=False,
kernel_size: int=3,
ds: bool=False,
) -> MedNeXt:
"""Creates the large-sized version of the MedNeXt V1 model.
Expand All @@ -115,9 +122,9 @@ def create_mednext_v1_large(
n_classes=num_classes,
exp_r=[3,4,8,8,8,8,8,4,3],
kernel_size=kernel_size,
deep_supervision=ds,
deep_supervision=deep_supervision,
pocket=pocket,
do_res=True,
do_res_up_down = True,
block_counts = [3,4,8,8,8,8,8,4,3],
)

73 changes: 40 additions & 33 deletions mist/models/mednext_v1/mednext_v1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""MIST-compatible MedNeXt model."""
from typing import List, Union, Optional
import torch.nn as nn
import torch.nn.functional as F

from mist.models.mednext_v1 import blocks

Expand All @@ -15,6 +16,7 @@ def __init__(self,
enc_kernel_size: Optional[int]=None,
dec_kernel_size: Optional[int]=None,
deep_supervision: bool=False,
pocket: bool=False,
do_res: bool=False,
do_res_up_down: bool=False,
block_counts: list=[2,2,2,2,2,2,2,2,2],
Expand Down Expand Up @@ -57,7 +59,7 @@ def __init__(self,

self.down_0 = blocks.MedNeXtDownBlock(
in_channels=n_channels,
out_channels=2*n_channels,
out_channels=n_channels if pocket else 2*n_channels,
exp_r=exp_r[1],
kernel_size=enc_kernel_size,
do_res=do_res_up_down,
Expand All @@ -67,8 +69,8 @@ def __init__(self,

self.enc_block_1 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*2,
out_channels=n_channels*2,
in_channels=n_channels if pocket else n_channels*2,
out_channels=n_channels if pocket else n_channels*2,
exp_r=exp_r[1],
kernel_size=enc_kernel_size,
do_res=do_res,
Expand All @@ -80,8 +82,8 @@ def __init__(self,
)

self.down_1 = blocks.MedNeXtDownBlock(
in_channels=2*n_channels,
out_channels=4*n_channels,
in_channels=n_channels if pocket else 2*n_channels,
out_channels=n_channels if pocket else 4*n_channels,
exp_r=exp_r[2],
kernel_size=enc_kernel_size,
do_res=do_res_up_down,
Expand All @@ -92,8 +94,8 @@ def __init__(self,

self.enc_block_2 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*4,
out_channels=n_channels*4,
in_channels=n_channels if pocket else n_channels*4,
out_channels=n_channels if pocket else n_channels*4,
exp_r=exp_r[2],
kernel_size=enc_kernel_size,
do_res=do_res,
Expand All @@ -105,8 +107,8 @@ def __init__(self,
)

self.down_2 = blocks.MedNeXtDownBlock(
in_channels=4*n_channels,
out_channels=8*n_channels,
in_channels=n_channels if pocket else 4*n_channels,
out_channels=n_channels if pocket else 8*n_channels,
exp_r=exp_r[3],
kernel_size=enc_kernel_size,
do_res=do_res_up_down,
Expand All @@ -117,8 +119,8 @@ def __init__(self,

self.enc_block_3 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*8,
out_channels=n_channels*8,
in_channels=n_channels if pocket else n_channels*8,
out_channels=n_channels if pocket else n_channels*8,
exp_r=exp_r[3],
kernel_size=enc_kernel_size,
do_res=do_res,
Expand All @@ -130,8 +132,8 @@ def __init__(self,
)

self.down_3 = blocks.MedNeXtDownBlock(
in_channels=8*n_channels,
out_channels=16*n_channels,
in_channels=n_channels if pocket else 8*n_channels,
out_channels=n_channels if pocket else 16*n_channels,
exp_r=exp_r[4],
kernel_size=enc_kernel_size,
do_res=do_res_up_down,
Expand All @@ -142,8 +144,8 @@ def __init__(self,

self.bottleneck = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*16,
out_channels=n_channels*16,
in_channels=n_channels if pocket else n_channels*16,
out_channels=n_channels if pocket else n_channels*16,
exp_r=exp_r[4],
kernel_size=dec_kernel_size,
do_res=do_res,
Expand All @@ -155,8 +157,8 @@ def __init__(self,
)

self.up_3 = blocks.MedNeXtUpBlock(
in_channels=16*n_channels,
out_channels=8*n_channels,
in_channels=n_channels if pocket else 16*n_channels,
out_channels=n_channels if pocket else 8*n_channels,
exp_r=exp_r[5],
kernel_size=dec_kernel_size,
do_res=do_res_up_down,
Expand All @@ -167,8 +169,8 @@ def __init__(self,

self.dec_block_3 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*8,
out_channels=n_channels*8,
in_channels=n_channels if pocket else n_channels*8,
out_channels=n_channels if pocket else n_channels*8,
exp_r=exp_r[5],
kernel_size=dec_kernel_size,
do_res=do_res,
Expand All @@ -180,8 +182,8 @@ def __init__(self,
)

self.up_2 = blocks.MedNeXtUpBlock(
in_channels=8*n_channels,
out_channels=4*n_channels,
in_channels=n_channels if pocket else 8*n_channels,
out_channels=n_channels if pocket else 4*n_channels,
exp_r=exp_r[6],
kernel_size=dec_kernel_size,
do_res=do_res_up_down,
Expand All @@ -192,8 +194,8 @@ def __init__(self,

self.dec_block_2 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*4,
out_channels=n_channels*4,
in_channels=n_channels if pocket else n_channels*4,
out_channels=n_channels if pocket else n_channels*4,
exp_r=exp_r[6],
kernel_size=dec_kernel_size,
do_res=do_res,
Expand All @@ -205,8 +207,8 @@ def __init__(self,
)

self.up_1 = blocks.MedNeXtUpBlock(
in_channels=4*n_channels,
out_channels=2*n_channels,
in_channels=n_channels if pocket else 4*n_channels,
out_channels=n_channels if pocket else 2*n_channels,
exp_r=exp_r[7],
kernel_size=dec_kernel_size,
do_res=do_res_up_down,
Expand All @@ -217,8 +219,8 @@ def __init__(self,

self.dec_block_1 = nn.Sequential(*[
blocks.MedNeXtBlock(
in_channels=n_channels*2,
out_channels=n_channels*2,
in_channels=n_channels if pocket else n_channels*2,
out_channels=n_channels if pocket else n_channels*2,
exp_r=exp_r[7],
kernel_size=dec_kernel_size,
do_res=do_res,
Expand All @@ -230,7 +232,7 @@ def __init__(self,
)

self.up_0 = blocks.MedNeXtUpBlock(
in_channels=2*n_channels,
in_channels=n_channels if pocket else 2*n_channels,
out_channels=n_channels,
exp_r=exp_r[8],
kernel_size=dec_kernel_size,
Expand All @@ -257,10 +259,10 @@ def __init__(self,
self.out_0 = blocks.OutBlock(in_channels=n_channels, n_classes=n_classes, dim=dim)

if deep_supervision:
self.out_1 = blocks.OutBlock(in_channels=n_channels*2, n_classes=n_classes, dim=dim)
self.out_2 = blocks.OutBlock(in_channels=n_channels*4, n_classes=n_classes, dim=dim)
self.out_3 = blocks.OutBlock(in_channels=n_channels*8, n_classes=n_classes, dim=dim)
self.out_4 = blocks.OutBlock(in_channels=n_channels*16, n_classes=n_classes, dim=dim)
self.out_1 = blocks.OutBlock(in_channels=n_channels if pocket else n_channels*2, n_classes=n_classes, dim=dim)
self.out_2 = blocks.OutBlock(in_channels=n_channels if pocket else n_channels*4, n_classes=n_classes, dim=dim)
self.out_3 = blocks.OutBlock(in_channels=n_channels if pocket else n_channels*8, n_classes=n_classes, dim=dim)
self.out_4 = blocks.OutBlock(in_channels=n_channels if pocket else n_channels*16, n_classes=n_classes, dim=dim)

self.block_counts = block_counts

Expand Down Expand Up @@ -315,7 +317,12 @@ def forward(self, x):
output["prediction"] = x

if self.do_ds:
output["deep_supervision"] = [x_ds_1, x_ds_2, x_ds_3, x_ds_4]
patch_size = x.shape[2:]
x_ds_1 = F.interpolate(x_ds_1, size=patch_size)
x_ds_2 = F.interpolate(x_ds_2, size=patch_size)
x_ds_3 = F.interpolate(x_ds_3, size=patch_size)
x_ds_4 = F.interpolate(x_ds_4, size=patch_size)
output["deep_supervision"] = (x_ds_1, x_ds_2, x_ds_3, x_ds_4)
else:
output = x

Expand Down

0 comments on commit 08ed855

Please sign in to comment.