diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0dc35036..57be3aa9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Fixed
+
+- (Optimization) Mark model input sizes as dynamic for `torch.compile()` to avoid recompilation during evals or when training with variable sequence lengths / batch sizes. This doesn't seem to hurt throughput.
+
 ## [v1.6.3](https://github.com/allenai/OLMo-core/releases/tag/v1.6.3) - 2024-11-15
 
 ### Added
diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py
index f0d32441..82309025 100644
--- a/src/olmo_core/train/trainer.py
+++ b/src/olmo_core/train/trainer.py
@@ -48,7 +48,7 @@
     fused_cross_entropy_loss,
 )
 from ..optim import SkipStepOptimizer
-from ..utils import cuda_sync_debug_mode, move_to_device
+from ..utils import cuda_sync_debug_mode, mark_dynamic, move_to_device
 from .callbacks import (
     Callback,
     CheckpointerCallback,
@@ -880,6 +880,15 @@ def model_forward(self, micro_batch: Dict[str, Any]) -> torch.Tensor:
         Run a forward pass on a micro-batch, returning the logits.
         """
         with self._model_forward_context():
+            # NOTE: Input sizes might be dynamic, e.g. when training with variable sequence lengths
+            # or during an eval loop, so we mark them as dynamic for torch.compile up-front to avoid
+            # recompiling later.
+            # In theory this could harm performance a bit when input sizes are actually static,
+            # but so far I haven't noticed any dip in throughput with the models I've tested.
+            mark_dynamic(micro_batch["input_ids"], (0, 1))
+            if "doc_lens" in micro_batch:
+                mark_dynamic(micro_batch["doc_lens"], (0, 1))
+
             # shape: (batch_size, seq_len, vocab_size)
             logits = self.model(
                 input_ids=micro_batch["input_ids"],
diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py
index 90cdfd64..64ba1193 100644
--- a/src/olmo_core/utils.py
+++ b/src/olmo_core/utils.py
@@ -12,7 +12,7 @@
 from itertools import cycle, islice
 from queue import Queue
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
 
 import rich
 import torch
@@ -111,6 +111,13 @@ def move_to_device(o: T, device: torch.device, non_blocking: Optional[bool] = No
     return o
 
 
+def mark_dynamic(x: torch.Tensor, dim: Union[int, Sequence[int]]):
+    """
+    Mark a tensor as having dynamic sizes for ``torch.compile()``.
+    """
+    torch._dynamo.mark_dynamic(x, dim)
+
+
 def get_default_device() -> torch.device:
     """
     Get the default device.
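
For context (not part of the patch): the sketch below, using a toy torch.nn.Linear model and made-up shapes, illustrates what the new mark_dynamic helper is for. Without it, torch.compile() specializes on the first input shape it sees and typically recompiles whenever the batch size or sequence length changes, e.g. across eval batches; marking those dims as dynamic up-front lets dynamo compile with symbolic sizes instead.

import torch

from olmo_core.utils import mark_dynamic  # thin wrapper around torch._dynamo.mark_dynamic

model = torch.compile(torch.nn.Linear(16, 16))

for batch_size, seq_len in [(4, 128), (2, 256), (8, 512)]:
    x = torch.randn(batch_size, seq_len, 16)
    # Mark the batch and sequence dims (0 and 1) as dynamic before calling the
    # compiled model, so dynamo treats them as symbolic sizes rather than
    # specializing and recompiling for each new (batch_size, seq_len) pair.
    mark_dynamic(x, (0, 1))
    model(x)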