Skip to content

Commit

Permalink
Merge pull request #43 from aecelaya/main
Browse files Browse the repository at this point in the history
Bug fixes for VAE regularization and pretrained models.
  • Loading branch information
aecelaya authored Oct 1, 2024
2 parents aeb260a + 35cf50d commit 26aa9bb
Show file tree
Hide file tree
Showing 9 changed files with 802 additions and 385 deletions.
2 changes: 1 addition & 1 deletion mist/models/get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def load_model_from_config(weights_path, model_config_path):
model = get_model(**model_config)

# Trick for loading DDP model
state_dict = torch.load(weights_path)
state_dict = torch.load(weights_path, weights_only=True)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
# remove 'module.' of DataParallel/DistributedDataParallel
Expand Down
4 changes: 2 additions & 2 deletions mist/models/mgnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def __init__(self,
# VAE regularization
if self.vae_reg:
self.normal_dist = torch.distributions.Normal(0, 1)
self.normal_dist.loc = self.normal_dist.loc # .cuda()
self.normal_dist.scale = self.normal_dist.scale # .cuda()
self.normal_dist.loc = self.normal_dist.loc.cuda()
self.normal_dist.scale = self.normal_dist.scale.cuda()

self.mu = nn.Linear(self.out_channels * (len(self.spikes) + 1), self.latent_dim)
self.sigma = nn.Linear(self.out_channels * (len(self.spikes) + 1), self.latent_dim)
Expand Down
2 changes: 1 addition & 1 deletion mist/models/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def get_upsamples(self):

def get_upsamples_vae(self):
inp, out = self.filters[1:][::-1], self.filters[:-1][::-1]
inp[0] = 1
inp[0] = self.in_channels
strides, kernel_size = self.strides[1:][::-1], self.kernel_size[1:][::-1]
upsample_kernel_size = self.upsample_kernel_size[::-1]
return self.get_module_list(
Expand Down
4 changes: 2 additions & 2 deletions mist/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(self,
# VAE Regularization
if self.vae_reg:
self.normal_dist = torch.distributions.Normal(0, 1)
self.normal_dist.loc = self.normal_dist.loc # .cuda()
self.normal_dist.scale = self.normal_dist.scale # .cuda()
self.normal_dist.loc = self.normal_dist.loc.cuda()
self.normal_dist.scale = self.normal_dist.scale.cuda()

self.global_maxpool = GlobalMaxPooling3D()
self.mu = nn.Linear(self.channels[0][0], self.latent_dim)
Expand Down
13 changes: 13 additions & 0 deletions mist/runtime/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Exceptions for MIST."""


class InsufficientValidationSetError(Exception):
"""Raised if validation set size is smaller than the number of GPUs."""

def __init__(self, val_size: int, world_size: int) -> None:
self.message = (
f"Validation set size of {val_size} is too small for {world_size} "
"GPUs. Please increase the validation set size or reduce the "
"number of GPUs."
)
super().__init__(self.message)
File renamed without changes.
Loading

0 comments on commit 26aa9bb

Please sign in to comment.