Skip to content

Commit

Permalink
Override checkpointing hooks to exclude backbone while saving checkpo…
Browse files Browse the repository at this point in the history
…ints (#724)
  • Loading branch information
nkaenzig authored Dec 9, 2024
1 parent d0f5a03 commit 5c1b376
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 3 deletions.
20 changes: 19 additions & 1 deletion src/eva/core/models/modules/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from eva.core.metrics import structs as metrics_lib
from eva.core.models.modules import module
from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
from eva.core.models.modules.utils import batch_postprocess, grad
from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
from eva.core.utils import parser


Expand All @@ -32,6 +32,7 @@ def __init__(
lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
metrics: metrics_lib.MetricsSchema | None = None,
postprocess: batch_postprocess.BatchPostProcess | None = None,
save_head_only: bool = True,
) -> None:
"""Initializes the neural net head module.
Expand All @@ -48,6 +49,8 @@ def __init__(
postprocess: A list of helper functions to apply after the
loss and before the metrics calculation to the model
predictions and targets.
save_head_only: Whether to save only the head during checkpointing. If False,
will also save the backbone (not recommended when frozen).
"""
super().__init__(metrics=metrics, postprocess=postprocess)

Expand All @@ -56,6 +59,7 @@ def __init__(
self.backbone = backbone
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.save_head_only = save_head_only

@override
def configure_model(self) -> Any:
Expand All @@ -72,6 +76,20 @@ def configure_optimizers(self) -> Any:
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

@override
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_head_only:
checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "head")
super().on_save_checkpoint(checkpoint)

@override
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_head_only and self.backbone is not None:
checkpoint["state_dict"].update(
{f"backbone.{k}": v for k, v in self.backbone.state_dict().items()}
)
super().on_load_checkpoint(checkpoint)

@override
def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
features = tensor if self.backbone is None else self.backbone(tensor)
Expand Down
3 changes: 2 additions & 1 deletion src/eva/core/models/modules/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from eva.core.models.modules.utils import grad
from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess
from eva.core.models.modules.utils.checkpoint import submodule_state_dict

__all__ = ["grad", "BatchPostProcess"]
__all__ = ["grad", "BatchPostProcess", "submodule_state_dict"]
21 changes: 21 additions & 0 deletions src/eva/core/models/modules/utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Checkpointing related utilities and helper functions."""

from typing import Any, Dict


def submodule_state_dict(state_dict: Dict[str, Any], submodule_key: str) -> Dict[str, Any]:
"""Get the state dict of a specific submodule.
Args:
state_dict: The state dict to extract the submodule from.
submodule_key: The key of the submodule to extract.
Returns:
The subset of the state dict corresponding to the specified submodule.
"""
submodule_key = submodule_key if submodule_key.endswith(".") else submodule_key + "."
return {
module: weights
for module, weights in state_dict.items()
if module.startswith(submodule_key)
}
20 changes: 19 additions & 1 deletion src/eva/vision/models/modules/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from eva.core.metrics import structs as metrics_lib
from eva.core.models.modules import module
from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
from eva.core.models.modules.utils import batch_postprocess, grad
from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
from eva.core.utils import parser
from eva.vision.models.networks import decoders
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
Expand All @@ -31,6 +31,7 @@ def __init__(
lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
metrics: metrics_lib.MetricsSchema | None = None,
postprocess: batch_postprocess.BatchPostProcess | None = None,
save_decoder_only: bool = True,
) -> None:
"""Initializes the neural net head module.
Expand All @@ -49,6 +50,8 @@ def __init__(
postprocess: A list of helper functions to apply after the
loss and before the metrics calculation to the model
predictions and targets.
save_decoder_only: Whether to save only the decoder during checkpointing. If False,
will also save the encoder (not recommended when frozen).
"""
super().__init__(metrics=metrics, postprocess=postprocess)

Expand All @@ -58,6 +61,7 @@ def __init__(
self.lr_multiplier_encoder = lr_multiplier_encoder
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.save_decoder_only = save_decoder_only

@override
def configure_model(self) -> None:
Expand All @@ -83,6 +87,20 @@ def configure_optimizers(self) -> Any:
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

@override
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_decoder_only:
checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "decoder")
super().on_save_checkpoint(checkpoint)

@override
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_decoder_only and self.encoder is not None:
checkpoint["state_dict"].update(
{f"encoder.{k}": v for k, v in self.encoder.state_dict().items()} # type: ignore
)
super().on_load_checkpoint(checkpoint)

@override
def forward(
self,
Expand Down

0 comments on commit 5c1b376

Please sign in to comment.