Override checkpointing hooks to exclude backbone while saving checkpoints #724

Merged · 3 commits · Dec 9, 2024
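This pull request adds a save_head_only flag to the head module and a matching save_decoder_only flag to the semantic segmentation module, both defaulting to True. The modules override the on_save_checkpoint / on_load_checkpoint hooks so that checkpoints store only the trainable head or decoder; on load, the weights of the (typically frozen) backbone or encoder are re-injected from the in-memory module. A new submodule_state_dict helper extracts the matching subset of the state dict.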
20 changes: 19 additions & 1 deletion src/eva/core/models/modules/head.py
@@ -12,7 +12,7 @@
from eva.core.metrics import structs as metrics_lib
from eva.core.models.modules import module
from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE
from eva.core.models.modules.utils import batch_postprocess, grad
from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
from eva.core.utils import parser


@@ -32,6 +32,7 @@ def __init__(
lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
metrics: metrics_lib.MetricsSchema | None = None,
postprocess: batch_postprocess.BatchPostProcess | None = None,
save_head_only: bool = True,
) -> None:
"""Initializes the neural net head module.

@@ -48,6 +49,8 @@ def __init__(
postprocess: A list of helper functions to apply after the
loss and before the metrics calculation to the model
predictions and targets.
save_head_only: Whether to save only the head during checkpointing. If False,
will also save the backbone (not recommended when frozen).
"""
super().__init__(metrics=metrics, postprocess=postprocess)

@@ -56,6 +59,7 @@ def __init__(
self.backbone = backbone
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.save_head_only = save_head_only

@override
def configure_model(self) -> Any:
@@ -72,6 +76,20 @@ def configure_optimizers(self) -> Any:
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

@override
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_head_only:
checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "head")
super().on_save_checkpoint(checkpoint)

@override
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_head_only and self.backbone is not None:
checkpoint["state_dict"].update(
{f"backbone.{k}": v for k, v in self.backbone.state_dict().items()}
)
super().on_load_checkpoint(checkpoint)

@override
def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
features = tensor if self.backbone is None else self.backbone(tensor)
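The two hooks above pair up: saving filters the state dict down to the head, and loading re-injects the live backbone weights so that strict loading still finds every key. A minimal, self-contained sketch of that round trip (toy modules standing in for the class above, not eva's actual module):

from torch import nn

class ToyModule(nn.Module):
    """Stand-in for the head module above: a backbone plus a trainable head."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2)

model = ToyModule()
checkpoint = {"state_dict": model.state_dict()}

# What on_save_checkpoint does when save_head_only=True: keep only "head.*".
checkpoint["state_dict"] = {
    k: v for k, v in checkpoint["state_dict"].items() if k.startswith("head.")
}
assert sorted(checkpoint["state_dict"]) == ["head.bias", "head.weight"]

# What on_load_checkpoint does: merge the in-memory backbone weights back in,
# so the filtered checkpoint loads without missing keys.
checkpoint["state_dict"].update(
    {f"backbone.{k}": v for k, v in model.backbone.state_dict().items()}
)
model.load_state_dict(checkpoint["state_dict"])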
3 changes: 2 additions & 1 deletion src/eva/core/models/modules/utils/__init__.py
@@ -2,5 +2,6 @@

from eva.core.models.modules.utils import grad
from eva.core.models.modules.utils.batch_postprocess import BatchPostProcess
from eva.core.models.modules.utils.checkpoint import submodule_state_dict

__all__ = ["grad", "BatchPostProcess"]
__all__ = ["grad", "BatchPostProcess", "submodule_state_dict"]
21 changes: 21 additions & 0 deletions src/eva/core/models/modules/utils/checkpoint.py
@@ -0,0 +1,21 @@
"""Checkpointing related utilities and helper functions."""

from typing import Any, Dict


def submodule_state_dict(state_dict: Dict[str, Any], submodule_key: str) -> Dict[str, Any]:
"""Get the state dict of a specific submodule.

Args:
state_dict: The state dict to extract the submodule from.
submodule_key: The key of the submodule to extract.

Returns:
The subset of the state dict corresponding to the specified submodule.
"""
submodule_key = submodule_key if submodule_key.endswith(".") else submodule_key + "."
return {
module: weights
for module, weights in state_dict.items()
if module.startswith(submodule_key)
}
20 changes: 19 additions & 1 deletion src/eva/vision/models/modules/semantic_segmentation.py
@@ -12,7 +12,7 @@
from eva.core.metrics import structs as metrics_lib
from eva.core.models.modules import module
from eva.core.models.modules.typings import INPUT_BATCH, INPUT_TENSOR_BATCH
from eva.core.models.modules.utils import batch_postprocess, grad
from eva.core.models.modules.utils import batch_postprocess, grad, submodule_state_dict
from eva.core.utils import parser
from eva.vision.models.networks import decoders
from eva.vision.models.networks.decoders.segmentation.typings import DecoderInputs
@@ -31,6 +31,7 @@ def __init__(
lr_scheduler: LRSchedulerCallable = lr_scheduler.ConstantLR,
metrics: metrics_lib.MetricsSchema | None = None,
postprocess: batch_postprocess.BatchPostProcess | None = None,
save_decoder_only: bool = True,
) -> None:
"""Initializes the neural net head module.

@@ -49,6 +50,8 @@ def __init__(
postprocess: A list of helper functions to apply after the
loss and before the metrics calculation to the model
predictions and targets.
save_decoder_only: Whether to save only the decoder during checkpointing. If False,
will also save the encoder (not recommended when frozen).
"""
super().__init__(metrics=metrics, postprocess=postprocess)

@@ -58,6 +61,7 @@
self.lr_multiplier_encoder = lr_multiplier_encoder
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.save_decoder_only = save_decoder_only

@override
def configure_model(self) -> None:
@@ -83,6 +87,20 @@ def configure_optimizers(self) -> Any:
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

@override
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_decoder_only:
checkpoint["state_dict"] = submodule_state_dict(checkpoint["state_dict"], "decoder")
super().on_save_checkpoint(checkpoint)

@override
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if self.save_decoder_only and self.encoder is not None:
checkpoint["state_dict"].update(
{f"encoder.{k}": v for k, v in self.encoder.state_dict().items()} # type: ignore
)
super().on_load_checkpoint(checkpoint)

@override
def forward(
self,
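With save_decoder_only left at its default of True, checkpoints written during segmentation runs should contain only decoder weights. A hedged way to verify this on disk (the path below is hypothetical; Lightning-style checkpoints are plain dicts with a "state_dict" entry):

import torch

ckpt = torch.load("last.ckpt", map_location="cpu", weights_only=False)
state_dict = ckpt["state_dict"]

assert all(k.startswith("decoder.") for k in state_dict)
print(f"{len(state_dict)} decoder tensors saved; encoder weights excluded")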