Commit
yo
snarayan21 committed Sep 30, 2024
1 parent 311f92b commit 81e98b2
Showing 2 changed files with 25 additions and 8 deletions.
18 changes: 15 additions & 3 deletions llmfoundry/optim/no_op.py
@@ -1,13 +1,25 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Any, Callable, Iterable, Optional
+
 import torch
-from typing import Iterable, Any, Optional, Callable
 
 
 class NoOp(torch.optim.Optimizer):
+    """Optimizer that performs no optimization steps."""
 
     def __init__(
         self,
         params: Iterable[torch.Tensor],
     ):
+        """Initialize NoOp optimizer.
+
+        Args:
+            params (Iterable[torch.Tensor]): Model parameters for the optimizer.
+        """
+        # LR schedulers expect param groups to have LR. Unused.
-        defaults = {"lr": 0.0}
+        defaults = {'lr': 0.0}
         super().__init__(params, defaults)
 
     def __setstate__(self, state: dict[str, dict[Any, Any]]) -> None:
@@ -29,4 +41,4 @@ def step(self, closure: Optional[Callable] = None):
             with torch.enable_grad():
                 loss = closure()
 
-        return loss
+        return loss
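
For context, a minimal usage sketch of the NoOp optimizer added above (not part of this commit; it assumes the class is importable as llmfoundry.optim.no_op.NoOp, per the file path in this diff). It illustrates why the {'lr': 0.0} default matters: an LR scheduler can attach to the param groups, while step() leaves every parameter untouched.

# Hedged sketch, not part of this commit: exercising NoOp directly.
# The import path below is assumed from the file path shown in this diff.
import copy

import torch

from llmfoundry.optim.no_op import NoOp

model = torch.nn.Linear(4, 2)
before = copy.deepcopy(model.state_dict())

optimizer = NoOp(model.parameters())

# The {'lr': 0.0} default means an LR scheduler can attach without errors.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
scheduler.step()

# No parameters were updated by the no-op step.
for name, param in model.named_parameters():
    assert torch.equal(param, before[name])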
15 changes: 10 additions & 5 deletions tests/optim/test_no_op.py
@@ -1,13 +1,15 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-import torch
+import copy
+from typing import Callable
+
+import torch
+from composer.trainer import Trainer
 from torch.utils.data import DataLoader
 
 from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
-from composer.trainer import Trainer
 from llmfoundry.utils.builders import build_optimizer
-import copy
 
 
 def test_no_op_does_nothing(
@@ -20,7 +22,7 @@ def test_no_op_does_nothing(
         loss_fn='torch_crossentropy',
         attn_config={
             'attn_impl': 'torch',
-        }
+        },
     )
 
     # Build NoOp optimizer
@@ -38,6 +40,9 @@
     trainer.fit()
 
     # Check that the model has not changed
-    for ((orig_name, orig_param), (new_name, new_param)) in zip(orig_model.named_parameters(), model.named_parameters()):
+    for (
+        (orig_name, orig_param),
+        (new_name, new_param),
+    ) in zip(orig_model.named_parameters(), model.named_parameters()):
         print(f'Checking {orig_name} and {new_name}')
         assert torch.equal(orig_param, new_param)
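
The check this test performs can also be reproduced without a Composer Trainer. Below is a minimal sketch under the same assumption about the import path, swapping the ComposerMPTCausalLM fixture for a plain torch module and exercising the closure form of step() from the first file.

# Hedged sketch, not part of this commit: the same "nothing changed" check,
# using a closure with NoOp.step() instead of a full Trainer.fit() run.
import copy

import torch

from llmfoundry.optim.no_op import NoOp  # assumed import path from this diff

model = torch.nn.Linear(8, 8)
orig_model = copy.deepcopy(model)

optimizer = NoOp(model.parameters())

def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()
    return loss

returned_loss = optimizer.step(closure)  # closure runs, but no update is applied
assert returned_loss is not None

for (
    (orig_name, orig_param),
    (new_name, new_param),
) in zip(orig_model.named_parameters(), model.named_parameters()):
    assert orig_name == new_name
    assert torch.equal(orig_param, new_param)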
