Fsdp grad accum monkeypatch (#2064)
winglian authored and bursteratom committed Nov 18, 2024
1 parent b8a63b8 commit 8dc9e72
Showing 4 changed files with 107 additions and 1 deletion.
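This commit pins transformers to 4.46.2, adds a monkeypatch that flips the no_sync condition in Trainer._inner_training_loop so gradients are synchronized only on the last micro-batch of an accumulation window, applies the patch from setup_trainer when FSDP is enabled, and adds an e2e test that checks the upstream code is still patchable.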
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
-transformers==4.46.1
+transformers==4.46.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
83 changes: 83 additions & 0 deletions src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
@@ -0,0 +1,83 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/34645
"""
import inspect

from accelerate.logging import get_logger
from transformers.trainer import Trainer

from axolotl.monkeypatch.unsloth_ import detab_code

LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")

ORIGINAL_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
"""

PATCHED_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
"""


def get_training_loop_code() -> str:
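    """Return the source code of transformers' Trainer._inner_training_loop."""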
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop


def check_training_loop_is_patchable() -> bool:
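    """Check whether the original no_sync context snippet is present in the current Trainer source."""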
train_loop = get_training_loop_code()
train_loop, _ = detab_code(train_loop)
return ORIGINAL_CONTEXT_CODE in train_loop


def patch_training_loop_for_fsdp_grad_accum():
"""
monkeypatch for fixing the training loop for FSDP gradient accumulation
"""

train_loop = get_training_loop_code()
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
train_loop
)
train_loop, _ = detab_code(train_loop)
assert (
ORIGINAL_CONTEXT_CODE in train_loop
), "Original _inner_training_loop code not found"

train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
train_loop = train_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)

# load imports necessary
import transformers.trainer

items_to_import = []
for item in dir(transformers.trainer):
if item in train_loop:
items_to_import.append(item)

exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop", main_process_only=True)
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
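For context, accelerate's no_sync context manager suppresses gradient synchronization, so during gradient accumulation it should wrap every micro-batch except the last one of the window; the upstream code had the condition inverted. A minimal sketch of the behaviour the patched condition restores, using a hypothetical accumulation_context helper (not part of this patch) against the same accelerate API:

import contextlib
import functools

def accumulation_context(accelerator, model, i, num_batches):
    # skip gradient sync on every micro-batch except the last of the window,
    # so gradients are all-reduced exactly once per optimizer step
    return (
        functools.partial(accelerator.no_sync, model=model)
        if i != num_batches - 1
        else contextlib.nullcontext
    )

# sketch of the surrounding loop:
# for i, batch in enumerate(batch_samples):
#     with accumulation_context(accelerator, model, i, len(batch_samples))():
#         loss = model(**batch).loss
#         accelerator.backward(loss)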
8 changes: 8 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -16,6 +16,9 @@
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
+    patch_training_loop_for_fsdp_grad_accum,
+)
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -493,6 +496,11 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
+    if cfg.fsdp:
+        try:
+            patch_training_loop_for_fsdp_grad_accum()
+        except AssertionError:
+            pass
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
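In setup_trainer the patch is applied only when cfg.fsdp is set, and the AssertionError raised when the upstream source no longer matches is deliberately swallowed so training proceeds unpatched. A rough sketch of calling the same helpers directly, assuming the module path above and that this runs before Trainer.train():

from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
    check_training_loop_is_patchable,
    patch_training_loop_for_fsdp_grad_accum,
)

if check_training_loop_is_patchable():
    # rewrites Trainer._inner_training_loop in place
    patch_training_loop_for_fsdp_grad_accum()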
15 changes: 15 additions & 0 deletions tests/e2e/patched/test_trainer_fsdp.py
@@ -0,0 +1,15 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest

from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable


class TestTrainerFSDPIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

def test_train_loop_patchable(self):
        # ensures the current version of transformers has training-loop code that matches our patching code
self.assertTrue(
check_training_loop_is_patchable(),
"HF transformers _inner_training_loop has changed and isn't patchable",
)
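Because the test only asserts patchability, a transformers upgrade that rewrites _inner_training_loop surfaces as an explicit test failure rather than a silently unpatched training loop.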
