From d42f2020466067c77ccbb325411f291caf8fb696 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Nov 2024 19:11:04 -0500
Subject: [PATCH] Fsdp grad accum monkeypatch (#2064)

---
 requirements.txt                           |  2 +-
 .../monkeypatch/trainer_fsdp_grad_accum.py | 83 +++++++++++++++++++
 src/axolotl/utils/trainer.py               |  8 ++
 tests/e2e/patched/test_trainer_fsdp.py     | 15 ++++
 4 files changed, 107 insertions(+), 1 deletion(-)
 create mode 100644 src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
 create mode 100644 tests/e2e/patched/test_trainer_fsdp.py

diff --git a/requirements.txt b/requirements.txt
index f352fecda9..0997930f4c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.1
+transformers==4.46.2
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.1.0
diff --git a/src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py b/src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
new file mode 100644
index 0000000000..6819fde111
--- /dev/null
+++ b/src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
@@ -0,0 +1,83 @@
+"""
+fix for FSDP gradient accumulation
+see https://github.com/huggingface/transformers/pull/34645
+"""
+import inspect
+
+from accelerate.logging import get_logger
+from transformers.trainer import Trainer
+
+from axolotl.monkeypatch.unsloth_ import detab_code
+
+LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
+
+ORIGINAL_CONTEXT_CODE = """
+                context = (
+                    functools.partial(self.accelerator.no_sync, model=model)
+                    if i == len(batch_samples) - 1
+                    else contextlib.nullcontext
+                )
+"""
+
+PATCHED_CONTEXT_CODE = """
+                context = (
+                    functools.partial(self.accelerator.no_sync, model=model)
+                    if i != len(batch_samples) - 1
+                    else contextlib.nullcontext
+                )
+"""
+
+
+def get_training_loop_code() -> str:
+    training_loop = inspect.getsource(
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    return training_loop
+
+
+def check_training_loop_is_patchable() -> bool:
+    train_loop = get_training_loop_code()
+    train_loop, _ = detab_code(train_loop)
+    return ORIGINAL_CONTEXT_CODE in train_loop
+
+
+def patch_training_loop_for_fsdp_grad_accum():
+    """
+    monkeypatch for fixing the training loop for FSDP gradient accumulation
+    """
+
+    train_loop = get_training_loop_code()
+    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
+        train_loop
+    )
+    train_loop, _ = detab_code(train_loop)
+    assert (
+        ORIGINAL_CONTEXT_CODE in train_loop
+    ), "Original _inner_training_loop code not found"
+
+    train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
+    train_loop = train_loop.replace(
+        "def _inner_training_loop(",
+        "def _fixed_inner_training_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in train_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(train_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _inner_training_loop", main_process_only=True)
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2d3a6944f7..5c9bfd6635 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,6 +16,9 @@
 from transformers.utils import is_torch_bf16_gpu_available
 
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
+    patch_training_loop_for_fsdp_grad_accum,
+)
 from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -493,6 +496,11 @@ def prepare_opinionated_env(cfg):
 def setup_trainer(
     cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
+    if cfg.fsdp:
+        try:
+            patch_training_loop_for_fsdp_grad_accum()
+        except AssertionError:
+            pass
     if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
         trainer_builder.model_ref = model[1]
diff --git a/tests/e2e/patched/test_trainer_fsdp.py b/tests/e2e/patched/test_trainer_fsdp.py
new file mode 100644
index 0000000000..1095cff3c0
--- /dev/null
+++ b/tests/e2e/patched/test_trainer_fsdp.py
@@ -0,0 +1,15 @@
+"""Test module for checking that the FSDP gradient accumulation monkeypatch for Hugging Face Transformers still applies as expected."""
+import unittest
+
+from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
+
+
+class TestTrainerFSDPIntegration(unittest.TestCase):
+    """FSDP gradient accumulation monkeypatch integration tests."""
+
+    def test_train_loop_patchable(self):
+        # ensures the current version of transformers has training loop code that matches our patching code
+        self.assertTrue(
+            check_training_loop_is_patchable(),
+            "HF transformers _inner_training_loop has changed and isn't patchable",
+        )
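
Note: below is a minimal usage sketch, not part of the diff above, assuming axolotl with this patch and the pinned transformers==4.46.2 are installed. It only uses the two helpers the new module defines and mirrors the guarded call added to setup_trainer(); the print message is illustrative.

    # Illustrative usage only; mirrors the cfg.fsdp guard added to setup_trainer().
    from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
        check_training_loop_is_patchable,
        patch_training_loop_for_fsdp_grad_accum,
    )

    if check_training_loop_is_patchable():
        # Rewrites Trainer._inner_training_loop so accelerator.no_sync() wraps every
        # micro-batch except the last one in each gradient accumulation window.
        patch_training_loop_for_fsdp_grad_accum()
    else:
        # ORIGINAL_CONTEXT_CODE no longer matches the installed transformers source,
        # e.g. because the upstream fix (huggingface/transformers#34645) has landed.
        print("transformers _inner_training_loop is not patchable; skipping")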