Skip to content

Commit

Permalink
GC every n steps (#2209)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Dec 21, 2024
1 parent 307cf7c commit 2312caa
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GCCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
Expand Down Expand Up @@ -1452,6 +1453,8 @@ def get_callbacks(self):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))

if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())

return callbacks
Expand Down
15 changes: 15 additions & 0 deletions src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import gc
import logging
import math
import os
Expand Down Expand Up @@ -842,3 +843,17 @@ def on_train_end( # pylint: disable=unused-argument
):
control.should_save = True
return control


class GCCallback(TrainerCallback):
    """Callback that garbage-collects Python objects and releases cached
    CUDA memory every ``gc_steps`` training steps.

    Periodic manual collection keeps long training runs from accumulating
    cyclic garbage that pins GPU memory between automatic GC passes.
    """

    def __init__(self, gc_steps=None):
        # Interval, in optimizer steps, between collections.
        # None or 0 disables collection entirely (on_step_end becomes a no-op).
        self.gc_steps = gc_steps

    def on_step_end(
        self, args, state, control, **kwargs  # pylint: disable=unused-argument
    ):
        # Guard the falsy case so the modulo below cannot raise TypeError
        # (gc_steps=None, the constructor default) or ZeroDivisionError
        # (gc_steps=0) on the first step.
        if self.gc_steps and state.global_step % self.gc_steps == 0:
            # Collect Python garbage FIRST: dropping dead tensor references
            # returns their blocks to the CUDA caching allocator, which
            # empty_cache() can then hand back to the driver. The reverse
            # order leaves cyclically-referenced tensor memory untouched.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
2 changes: 2 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,8 @@ class Config:
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None

gc_steps: Optional[int] = None

bf16: Optional[Union[Literal["auto"], bool]] = "auto"
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
Expand Down

0 comments on commit 2312caa

Please sign in to comment.