diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index a49df831b46ea..3671c2f91e3b7 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -912,9 +912,8 @@ def check_and_convert(t, scale):
         p_descale = 1.0 / p_scale
         o_descale = 1.0 / o_scale

-        if is_navi():
-            max_seqlens_q = 0
-            max_seqlens_k = 0
+        arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q
+        arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k

         attn_fwd[grid](
             q,
@@ -944,8 +943,8 @@ def check_and_convert(t, scale):
             HQ=nheads_q,
             HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
-            MAX_SEQLENS_Q=max_seqlens_q,
-            MAX_SEQLENS_K=max_seqlens_k,
+            MAX_SEQLENS_Q=arg_max_seqlens_q,
+            MAX_SEQLENS_K=arg_max_seqlens_k,
             IS_CAUSAL=causal,
             VARLEN=True,
             BLOCK_DMODEL=padded_d_model,
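
For readers skimming the hunk: the old code zeroed `max_seqlens_q`/`max_seqlens_k` in place on Navi GPUs, overwriting those values for the rest of the function, while the new code computes call-site-only `arg_max_seqlens_q`/`arg_max_seqlens_k` and passes them to `attn_fwd`, leaving the originals intact. The sketch below is a hypothetical, self-contained illustration of that idiom only; `is_navi` is stubbed and `launch_attention` stands in for the real `attn_fwd[grid](...)` launch.

```python
# Minimal sketch of the pattern in the hunk above (hypothetical names;
# is_navi() is stubbed, launch_attention stands in for attn_fwd[grid](...)).


def is_navi() -> bool:
    # Stub: the real helper checks whether the ROCm device is a Navi GPU.
    return False


def launch_attention(max_seqlens_q: int, max_seqlens_k: int) -> dict:
    # Compute the values handed to the kernel without mutating the inputs,
    # so later code in the same function still sees the real max sequence
    # lengths.
    arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q
    arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k

    # Stand-in for attn_fwd[grid](..., MAX_SEQLENS_Q=..., MAX_SEQLENS_K=...).
    return {
        "MAX_SEQLENS_Q": arg_max_seqlens_q,
        "MAX_SEQLENS_K": arg_max_seqlens_k,
    }


print(launch_attention(max_seqlens_q=4096, max_seqlens_k=4096))
```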