reset LlamaForCausalLM too in fixtures for cce patch
winglian committed Dec 9, 2024
1 parent 94f7751 commit 385c906
Showing 1 changed file with 7 additions and 1 deletion.
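The fixture touched here restores monkeypatched transformers internals after every test. The cce patch (presumably Cut Cross Entropy) replaces LlamaForCausalLM.forward, so the commit snapshots and rebinds that method as well; otherwise a patch applied in one test would leak into the next. A minimal sketch of the snapshot-and-restore pattern, using an illustrative fixture name rather than the repository's exact code:

# Sketch only: a function-scoped, autouse fixture that snapshots a method
# before each test and rebinds it afterwards. The fixture name is hypothetical;
# tests/conftest.py folds this into the broader cleanup_monkeypatches fixture.
import pytest


@pytest.fixture(scope="function", autouse=True)
def restore_llama_forward():
    from transformers.models.llama.modeling_llama import LlamaForCausalLM

    original_forward = LlamaForCausalLM.forward  # snapshot before the test runs
    yield  # the test executes here and may monkeypatch forward
    LlamaForCausalLM.forward = original_forward  # undo any patch afterwards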
tests/conftest.py
8 changes: 7 additions & 1 deletion
@@ -120,9 +120,13 @@ def temp_dir():
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
     from transformers import Trainer
-    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
+    from transformers.models.llama.modeling_llama import (
+        LlamaFlashAttention2,
+        LlamaForCausalLM,
+    )
 
     original_fa2_forward = LlamaFlashAttention2.forward
+    original_llama_forward = LlamaForCausalLM.forward
     original_trainer_inner_training_loop = (
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
@@ -131,13 +135,15 @@ def cleanup_monkeypatches():
     yield
     # Reset LlamaFlashAttention2 forward
     LlamaFlashAttention2.forward = original_fa2_forward
+    LlamaForCausalLM.forward = original_llama_forward
     Trainer._inner_training_loop = (  # pylint: disable=protected-access
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
 
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
         ("transformers.models.llama",),
         ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
         ("transformers.trainer",),
         ("transformers", ["Trainer"]),
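Each modules_to_reset entry pairs a module path with an optional list of attribute names; the loop that consumes the list sits below the fold of this diff. A hypothetical sketch of how such entries could be used to discard import-time monkeypatches by reloading the listed modules (an assumption, not the repository's actual loop):

# Assumption: one plausible way to consume (module path, [attribute names])
# tuples. The real implementation in tests/conftest.py is not shown in this diff.
import importlib


def reset_modules(modules_to_reset):
    for entry in modules_to_reset:
        module_name = entry[0]
        attrs = entry[1] if len(entry) > 1 else []
        module = importlib.import_module(module_name)
        reloaded = importlib.reload(module)  # re-executes the module body
        for attr in attrs:
            # sanity check: the named symbols should still be present after reload
            assert hasattr(reloaded, attr), f"{module_name} lost {attr} after reload"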
