From 9bd6c9482e0d7ab488693d840ad7ed0596ffd99a Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 16 Dec 2024 18:56:50 +0100
Subject: [PATCH] fix default sdpa

---
 src/transformers/modeling_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index c502742429e789..ee576f69cfe423 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1531,10 +1531,10 @@ def _autoset_attn_implementation(
             config = cls._check_and_enable_flex_attn(config, hard_check_only=True)
         elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
-            # config = cls._check_and_enable_sdpa(
-            #     config,
-            #     hard_check_only=False if requested_attn_implementation is None else True,
-            # )
+            config = cls._check_and_enable_sdpa(
+                config,
+                hard_check_only=False if requested_attn_implementation is None else True,
+            )
             if (
                 torch.version.hip is not None
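
Note (not part of the patch): a minimal sketch of what restoring the _check_and_enable_sdpa call changes in practice. The checkpoint name ("gpt2") and the expected printed value are assumptions for illustration only; _attn_implementation is the private config field that _autoset_attn_implementation resolves.

    # Sketch, assuming transformers with this patch, torch >= 2.1.1, and an
    # SDPA-capable checkpoint such as "gpt2" available locally.
    from transformers import AutoModelForCausalLM

    # No attn_implementation passed: requested_attn_implementation is None, so the
    # restored call runs with hard_check_only=False and may set the default to "sdpa"
    # instead of silently falling back to "eager".
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    print(model.config._attn_implementation)  # expected: "sdpa" when SDPA is supported

    # Explicitly requesting "sdpa" takes the hard_check_only=True path, which only
    # validates the request and raises if the model or torch version cannot use SDPA.
    model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")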