From 6e5aac8613746d98566e24c8708ac2ef425f0a94 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 19 Dec 2024 16:57:05 +0100
Subject: [PATCH] correctly set the _attn_implementation when adding other
 functions to it

---
 src/transformers/modeling_utils.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9dcd6d758ecbe7..2fb9e3c6ebe89d 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1473,12 +1473,7 @@ def _autoset_attn_implementation(
                     ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
                 )
 
-        if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-            "eager",
-            "sdpa",
-            "flash_attention_2",
-            "flex_attention",
-        ]:
+        if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in ["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys()):
             message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
             if cls._supports_flash_attn_2:
                 message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1535,8 @@ def _autoset_attn_implementation(
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
+        elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.keys():
+            config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
         else:
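
For context, a minimal sketch of what this patch enables: once a callable is registered under a new key
in ALL_ATTENTION_FUNCTIONS, that key now passes the validation above and is written back to
config._attn_implementation, so the model dispatches to it by name. The function name
eager_with_logging, its exact signature, and the checkpoint below are illustrative assumptions; the body
mirrors the built-in eager attention but omits grouped-query key/value repetition.

    # Hypothetical sketch, not part of the patch: register a custom attention
    # callable and request it by name. Assumes ALL_ATTENTION_FUNCTIONS is the
    # plain dict registry in transformers.modeling_utils and that the signature
    # below matches the library's eager attention callable; the checkpoint is
    # illustrative.
    import torch
    from transformers import AutoModelForCausalLM
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


    def eager_with_logging(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
        # Plain eager scaled-dot-product attention plus a debug print.
        # Grouped-query key/value repetition is omitted for brevity.
        if scaling is None:
            scaling = query.size(-1) ** -0.5
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
        print(f"custom attention ran inside {module.__class__.__name__}")
        return attn_output, attn_weights


    # Register the callable under a new key, then request it by name. Before this
    # patch, _autoset_attn_implementation rejected any name outside the hard-coded
    # list and never wrote a registered name back to config._attn_implementation.
    ALL_ATTENTION_FUNCTIONS["eager_with_logging"] = eager_with_logging
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B", attn_implementation="eager_with_logging"
    )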