From 1e2093176515ddfd7a7dc5f77b2bb4d6a1bc3445 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 15 Dec 2023 11:08:27 +0100 Subject: [PATCH] [`FA-2`] Fix fa-2 issue when passing `config` to `from_pretrained` (#28043) * fix fa-2 issue * fix test * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * clenaer fix * up * add more robust tests * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * fixup * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * pop * add test --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_utils.py | 12 ++++++++++++ tests/test_modeling_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3247c323685815..7e5d3e54e619e8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2955,6 +2955,18 @@ def from_pretrained( **kwargs, ) else: + # In case one passes a config to `from_pretrained` + "attn_implementation" + # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs + # Please see: https://github.com/huggingface/transformers/issues/28038 + + # Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory + # we pop attn_implementation from the kwargs but this handles the case where users + # passes manually the config to `from_pretrained`. + config = copy.deepcopy(config) + + kwarg_attn_imp = kwargs.pop("attn_implementation", None) + if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + config._attn_implementation = kwarg_attn_imp model_kwargs = kwargs quantizer = None diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index ddfaad5214dc50..a8a483b4017c84 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self): self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception)) + def test_error_no_flash_available_with_config(self): + with self.assertRaises(ValueError) as cm: + config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel") + + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2" + ) + + self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception)) + def test_error_wrong_attn_implementation(self): with self.assertRaises(ValueError) as cm: _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo") @@ -1840,6 +1850,21 @@ def test_not_available_flash(self): self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_not_available_flash_with_config(self): + if is_flash_attn_2_available(): + self.skipTest("Please uninstall flash-attn package to run test_not_available_flash") + + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", + config=config, + attn_implementation="flash_attention_2", + ) + + self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_not_available_sdpa(self): if is_torch_sdpa_available(): self.skipTest("This test requires torch<=2.0")