diff --git a/llmfoundry/optim/no_op.py b/llmfoundry/optim/no_op.py
index f435917b36..416363c261 100644
--- a/llmfoundry/optim/no_op.py
+++ b/llmfoundry/optim/no_op.py
@@ -1,13 +1,25 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Iterable, Optional
+
 import torch
-from typing import Iterable, Any, Optional, Callable
+
 
 class NoOp(torch.optim.Optimizer):
+    """Optimizer that performs no optimization steps."""
+
     def __init__(
         self,
         params: Iterable[torch.Tensor],
     ):
+        """Initialize NoOp optimizer.
+
+        Args:
+            params (Iterable[torch.Tensor]): Model parameters for the optimizer.
+        """
         # LR schedulers expect param groups to have LR. Unused.
-        defaults = {"lr": 0.0}
+        defaults = {'lr': 0.0}
         super().__init__(params, defaults)
 
     def __setstate__(self, state: dict[str, dict[Any, Any]]) -> None:
@@ -29,4 +41,4 @@ def step(self, closure: Optional[Callable] = None):
             with torch.enable_grad():
                 loss = closure()
 
-        return loss
\ No newline at end of file
+        return loss
diff --git a/tests/optim/test_no_op.py b/tests/optim/test_no_op.py
index eb1c2fb704..27766d6eaf 100644
--- a/tests/optim/test_no_op.py
+++ b/tests/optim/test_no_op.py
@@ -1,13 +1,15 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-import torch
+import copy
 from typing import Callable
+
+import torch
+from composer.trainer import Trainer
 from torch.utils.data import DataLoader
+
 from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
-from composer.trainer import Trainer
 from llmfoundry.utils.builders import build_optimizer
-import copy
 
 
 def test_no_op_does_nothing(
@@ -20,7 +22,7 @@ def test_no_op_does_nothing(
         loss_fn='torch_crossentropy',
         attn_config={
             'attn_impl': 'torch',
-        }
+        },
     )
 
     # Build NoOp optimizer
@@ -38,6 +40,9 @@ def test_no_op_does_nothing(
     trainer.fit()
 
     # Check that the model has not changed
-    for ((orig_name, orig_param), (new_name, new_param)) in zip(orig_model.named_parameters(), model.named_parameters()):
+    for (
+        (orig_name, orig_param),
+        (new_name, new_param),
+    ) in zip(orig_model.named_parameters(), model.named_parameters()):
         print(f'Checking {orig_name} and {new_name}')
         assert torch.equal(orig_param, new_param)
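
For trying the new optimizer outside of Composer's `Trainer`, here is a minimal sketch of how `NoOp` behaves. It is not part of the diff: it assumes `llmfoundry` is installed so that `llmfoundry.optim.no_op.NoOp` is importable, and it substitutes a plain `torch.nn.Linear` for the `ComposerMPTCausalLM` used in the real test.

```python
# Minimal sketch (not part of the diff): exercise NoOp directly, without Composer.
import copy

import torch

from llmfoundry.optim.no_op import NoOp  # assumes llmfoundry is installed

# A tiny model stands in for the ComposerMPTCausalLM used in test_no_op_does_nothing.
model = torch.nn.Linear(4, 2)
original = copy.deepcopy(model)

optimizer = NoOp(model.parameters())

# Run a few "training" steps: gradients are computed, but NoOp.step() ignores them.
for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()

# Parameters should be bit-for-bit identical to the originals.
for (name, p_orig), (_, p_new) in zip(
    original.named_parameters(),
    model.named_parameters(),
):
    assert torch.equal(p_orig, p_new), f'{name} changed unexpectedly'
```

The same parameter-equality check as in `test_no_op_does_nothing` holds because `NoOp.step()` only evaluates the optional closure and never applies an update to the parameters.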