Mark model input as dynamically sized (#105)
I tested this out with the 7B and it does not hurt throughput.
epwalsh authored Nov 17, 2024
1 parent 776e235 commit 57b38ad
Showing 3 changed files with 22 additions and 2 deletions.
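For context: `torch.compile()` guards on the exact input shapes it traces, so a new batch size or sequence length can trigger a recompile the first time it shows up. Marking the variable dimensions as dynamic before the first call asks Dynamo for one shape-polymorphic graph instead. A minimal sketch, not part of this commit, with a made-up function `f`:

```python
import torch

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

# Mark batch (dim 0) and sequence (dim 1) as dynamic before the first call,
# so the initial compile already uses symbolic sizes for those dims.
x = torch.randn(4, 128)
torch._dynamo.mark_dynamic(x, 0)
torch._dynamo.mark_dynamic(x, 1)

f(x)                    # compiles once, polymorphic in dims 0 and 1
f(torch.randn(2, 256))  # different shape, same graph -- no recompile
```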
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## Unreleased
 
+### Fixed
+
+- (Optimization) Mark model input sizes as dynamic for `torch.compile()` to avoid recompilation during evals or variable-sequence/batch-size training. This doesn't seem to hurt throughput.
+
 ## [v1.6.3](https://github.com/allenai/OLMo-core/releases/tag/v1.6.3) - 2024-11-15
 
 ### Added
11 changes: 10 additions & 1 deletion src/olmo_core/train/trainer.py
@@ -48,7 +48,7 @@
     fused_cross_entropy_loss,
 )
 from ..optim import SkipStepOptimizer
-from ..utils import cuda_sync_debug_mode, move_to_device
+from ..utils import cuda_sync_debug_mode, mark_dynamic, move_to_device
 from .callbacks import (
     Callback,
     CheckpointerCallback,
@@ -880,6 +880,15 @@ def model_forward(self, micro_batch: Dict[str, Any]) -> torch.Tensor:
         Run a forward pass on a micro-batch, returning the logits.
         """
         with self._model_forward_context():
+            # NOTE: Input sizes might be dynamic, e.g. when training with variable sequence lengths
+            # or during an eval loop, so we mark them as dynamic for torch.compile up-front to avoid
+            # recompiling later.
+            # In theory this could harm performance a bit when input sizes are actually static
+            # but so far I haven't noticed any dip in throughput with the models I've tested.
+            mark_dynamic(micro_batch["input_ids"], (0, 1))
+            if "doc_lens" in micro_batch:
+                mark_dynamic(micro_batch["doc_lens"], (0, 1))
+
             # shape: (batch_size, seq_len, vocab_size)
             logits = self.model(
                 input_ids=micro_batch["input_ids"],
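Not part of the diff, but one way to sanity-check this pattern is to count Dynamo's compiled graphs across variable-length batches. `forward` below is a stand-in for the model call, and the counter key is PyTorch-internal and may change between releases:

```python
import torch
from torch._dynamo.utils import counters

torch._dynamo.reset()

@torch.compile
def forward(input_ids: torch.Tensor) -> torch.Tensor:
    return input_ids.float().mean(dim=-1)

for seq_len in (128, 256, 512):
    batch = torch.randint(0, 50_000, (4, seq_len))
    # Same pattern as the trainer: mark each micro-batch before the forward pass.
    torch._dynamo.mark_dynamic(batch, 1)
    forward(batch)

# With dim 1 dynamic, all three sequence lengths should share one graph.
print(counters["stats"]["unique_graphs"])  # expected: 1
```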
9 changes: 8 additions & 1 deletion src/olmo_core/utils.py
@@ -12,7 +12,7 @@
 from itertools import cycle, islice
 from queue import Queue
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
 
 import rich
 import torch
@@ -111,6 +111,13 @@ def move_to_device(o: T, device: torch.device, non_blocking: Optional[bool] = None
     return o
 
 
+def mark_dynamic(x: torch.Tensor, dim: Union[int, Sequence[int]]):
+    """
+    Mark a tensor as having dynamic sizes for ``torch.compile()``.
+    """
+    torch._dynamo.mark_dynamic(x, dim)
+
+
 def get_default_device() -> torch.device:
     """
     Get the default device.
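The new helper is a thin wrapper over `torch._dynamo.mark_dynamic`. A hedged usage sketch, with a toy embedding standing in for the real compiled model and made-up sizes:

```python
import torch
from olmo_core.utils import mark_dynamic

# Toy stand-in for the compiled model; vocab and hidden sizes are made up.
model = torch.compile(torch.nn.Embedding(50_257, 64))

micro_batch = {"input_ids": torch.randint(0, 50_257, (2, 1024))}
mark_dynamic(micro_batch["input_ids"], (0, 1))  # batch and sequence dims may vary
hidden = model(micro_batch["input_ids"])        # later shapes reuse this graph
```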