diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 7d41126390..1abbfab12d 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -617,7 +617,7 @@ def pre_attn_forward(
         else:
             past_key_value = None

-        if use_flash_attention and FusedSDPA:
+        if use_flash_attention and FusedSDPA is not None:
            import habana_frameworks.torch.hpu as ht

            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
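
For context, a minimal sketch of the pattern this one-line change tightens: guarding an optionally imported fused kernel with an explicit `is not None` check rather than relying on truthiness. The import path and helper function below are placeholders for illustration, not the actual modeling_llama.py code.

```python
# Sketch only (assumed names, not the real optimum-habana imports):
# FusedSDPA is bound to None when the optional kernel import fails,
# so the guard should test for that case directly.
try:
    from hypothetical_hpu_kernels import FusedSDPA  # placeholder for the real optional import
except ImportError:
    FusedSDPA = None


def pick_attention_path(use_flash_attention: bool) -> str:
    # Take the fused path only when the kernel module was actually imported;
    # `is not None` states that intent explicitly.
    if use_flash_attention and FusedSDPA is not None:
        return "fused"
    return "eager"


if __name__ == "__main__":
    print(pick_attention_path(use_flash_attention=True))  # "eager" unless the import succeeded
```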