diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 0376493c2b4646..49d086c76e8683 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1537,7 +1537,7 @@ def _autoset_attn_implementation(
                     "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
                 )
                 torch.backends.cuda.enable_flash_sdp(False)
-        elif requested_attn_implementation in ALL_ATTENTION_FUNCTIONS.keys():
+        elif requested_attn_implementation in list(ALL_ATTENTION_FUNCTIONS.keys()):
             config._attn_implementation = requested_attn_implementation
         elif isinstance(requested_attn_implementation, dict):
             config._attn_implementation = None
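
For context, here is a minimal sketch of the membership check this hunk touches. The registry contents and the `resolve_attn_implementation` helper are illustrative assumptions, not the actual `transformers` internals: wrapping `.keys()` in `list()` materializes the key names before the `in` test, which leaves the result unchanged for a plain dict but guarantees a concrete list of names even when the registry is a custom mapping object.

```python
# Illustrative sketch only: a stand-in registry and helper, not the real
# transformers ALL_ATTENTION_FUNCTIONS or _autoset_attn_implementation.
ALL_ATTENTION_FUNCTIONS = {
    "eager": lambda *args, **kwargs: None,  # placeholder callables
    "sdpa": lambda *args, **kwargs: None,
    "flex_attention": lambda *args, **kwargs: None,
}


def resolve_attn_implementation(requested: str) -> str:
    """Return the requested implementation name if it is registered."""
    # Mirrors the membership check in the hunk above: list() materializes the
    # key view, so the `in` test runs against a plain list of names.
    if requested in list(ALL_ATTENTION_FUNCTIONS.keys()):
        return requested
    raise ValueError(
        f"Unknown attention implementation {requested!r}; "
        f"expected one of {sorted(ALL_ATTENTION_FUNCTIONS)}"
    )


print(resolve_attn_implementation("sdpa"))  # -> sdpa
```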