diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 54ee19536..fffddac81 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -56,6 +56,7 @@
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
+    GCCallback,
     GPUStatsCallback,
     LossWatchDogCallback,
     SaveAxolotlConfigtoWandBCallback,
@@ -1452,6 +1453,8 @@ def get_callbacks(self):
         if self.cfg.loss_watchdog_threshold is not None:
             callbacks.append(LossWatchDogCallback(self.cfg))
 
+        if self.cfg.gc_steps:
+            callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
         callbacks.append(SaveModelCallback())
 
         return callbacks
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 641c9b162..f1b459b6b 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import gc
 import logging
 import math
 import os
@@ -842,3 +843,17 @@ def on_train_end(  # pylint: disable=unused-argument
     ):
         control.should_save = True
         return control
+
+
+class GCCallback(TrainerCallback):
+    """Callback to garbage collect and release cached CUDA memory every `gc_steps` steps."""
+
+    def __init__(self, gc_steps=None):
+        self.gc_steps = gc_steps
+
+    def on_step_end(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        # Guard against gc_steps being None/0 (the default) to avoid TypeError /
+        # ZeroDivisionError; collect Python garbage first so tensors it frees can
+        # actually be returned to the allocator by empty_cache().
+        if self.gc_steps and state.global_step % self.gc_steps == 0:
+            gc.collect()
+            torch.cuda.empty_cache()
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 5ddf04811..c704be800 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -666,6 +666,8 @@ class Config:
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
+    gc_steps: Optional[int] = None
+
     bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases