correctly set the _attn_implementation when adding other functions to it
Cyrilvallez committed Dec 19, 2024
1 parent f74a08e commit 6e5aac8
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/transformers/modeling_utils.py
@@ -1473,12 +1473,7 @@ def _autoset_attn_implementation(
                 ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
             )
 
-        if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
-            "eager",
-            "sdpa",
-            "flash_attention_2",
-            "flex_attention",
-        ]:
+        if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in ["eager"] + list(ALL_ATTENTION_FUNCTIONS.keys()):
             message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
             if cls._supports_flash_attn_2:
                 message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
@@ -1540,6 +1535,8 @@ def _autoset_attn_implementation(
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
+        elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.keys():
+            config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
         else:
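
For context on what this change enables, below is a minimal sketch (not part of the commit) of registering an extra attention function and then selecting it by name when loading a model. The wrapper name and the checkpoint are hypothetical, and the callback signature plus the import path of sdpa_attention_forward are assumptions based on the built-in implementations around this version of transformers; they may differ in other releases.

# Sketch only: register a custom attention callable under a new name so that
# attn_implementation="logged_sdpa" passes the validation check patched above
# and is stored verbatim in config._attn_implementation.
from transformers import AutoModelForCausalLM
from transformers.integrations.sdpa_attention import sdpa_attention_forward  # assumed import path
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def logged_sdpa_attention(module, query, key, value, attention_mask, **kwargs):
    # Thin wrapper around the stock SDPA implementation; a real custom kernel
    # would replace the delegated call below with its own computation.
    print(f"custom attention called in {module.__class__.__name__}")
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)


# At this point in the codebase ALL_ATTENTION_FUNCTIONS behaves like a plain
# dict, so a simple item assignment registers the new implementation.
ALL_ATTENTION_FUNCTIONS["logged_sdpa"] = logged_sdpa_attention

# Hypothetical checkpoint; any model already ported to the new attention
# interface would be selected the same way.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="logged_sdpa",
)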
