Prevent Graph break in Llama when using flash attention (huggingface#…
pramodkumar-habanalabs authored Aug 30, 2024
1 parent 35e0145 commit a1a92c9
Showing 1 changed file with 1 addition and 1 deletion.
optimum/habana/transformers/models/llama/modeling_llama.py (1 addition, 1 deletion)

@@ -617,7 +617,7 @@ def pre_attn_forward(
         else:
             past_key_value = None
 
-        if use_flash_attention and FusedSDPA:
+        if use_flash_attention and FusedSDPA is not None:
             import habana_frameworks.torch.hpu as ht
 
             softmax_mode = "fast" if flash_attention_fast_softmax else "None"
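Why the changed check matters: below is a minimal, illustrative sketch, not the optimum-habana source. It assumes FusedSDPA is set by a guarded import (the import path is an assumption) and is therefore either the fused-kernel class or None when the Habana stack is unavailable. A bare truthiness test like `if FusedSDPA:` asks Python to evaluate bool() on an arbitrary object, which a graph compiler may not be able to fold away at trace time and can force a graph break; `FusedSDPA is not None` is a plain identity comparison that can be resolved while tracing.

# Illustrative sketch only (not the optimum-habana source). Assumes the model
# module sets FusedSDPA via a guarded import, so the name is either the
# fused-kernel class or None on machines without the Habana software stack.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed import path
except ImportError:
    FusedSDPA = None

def attention_backend(use_flash_attention: bool) -> str:
    # Old check: `if use_flash_attention and FusedSDPA:` calls bool() on an
    # arbitrary object, which the graph compiler may not constant-fold, so it
    # can fall back to eager execution and break the captured graph.
    # New check: `is not None` is a simple identity comparison that can be
    # resolved during tracing, so the graph stays intact.
    if use_flash_attention and FusedSDPA is not None:
        return "fused-sdpa"
    return "reference-sdpa"

print(attention_backend(True))  # "reference-sdpa" off-HPU, "fused-sdpa" when the kernel is available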
