Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU]Improve sdpa_opt kernel by skipping computes of causal mask #28260

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,11 @@ KERNEL(sdpa_opt)(
// Index helper macros used by the sdpa_opt kernel body.
// b0_idx / b1_idx: quotient and remainder of batch_idx divided by NUM_HEADS.
#define b0_idx (batch_idx / NUM_HEADS)
#define b1_idx (batch_idx % NUM_HEADS)
// target_seq_dim: global work-item id along dimension 1.
#define target_seq_dim ((uint)get_global_id(1))
// target_seq_idx: with IS_PAGED_ATTENTION it is block_start_pos minus the
// subsequence_begins entry selected through gws_seq_indexes_correspondence[target_seq_dim];
// otherwise it is the dimension-1 global id scaled by TARGET_SEQ_LEN_BLOCK_SIZE.
#if IS_PAGED_ATTENTION
#define target_seq_idx ((uint)block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]])
#else
#define target_seq_idx ((uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE)
#endif
// head_size_idx: local work-item id along dimension 2, wrapped modulo HEAD_SIZE.
#define head_size_idx ((uint)get_local_id(2) % HEAD_SIZE)
// sglid: work-item id within its sub-group; sgid: id of the sub-group.
#define sglid (uint)get_sub_group_local_id()
#define sgid (uint)get_sub_group_id()
Expand Down Expand Up @@ -994,8 +998,15 @@ KERNEL(sdpa_opt)(
__attribute__((opencl_unroll_hint(1)))
for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) {
const uint seq_len = start_partition_idx + sgid * SUBGROUP_SIZE;
#if IS_CAUSAL
const uint partition_seq_len = min((uint)SEQ_LEN_PARTITION_SIZE, (uint)max(0, (int)(target_seq_idx + seq_idx_end) - (int)start_partition_idx));
#else
const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
#endif

#if IS_CAUSAL
if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
#endif
#if IS_PAGED_ATTENTION
#ifdef BROADCAST_GROUP_SIZE
const uint heads_dim = num_heads_dim / BROADCAST_GROUP_SIZE;
Expand Down Expand Up @@ -1026,6 +1037,9 @@ KERNEL(sdpa_opt)(
#endif

int seq_len_calc_size = min((int)(SOURCE_SEQ_LEN) - (int)seq_len, (int)SUBGROUP_SIZE);
#if IS_CAUSAL
MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = OUTPUT_VAL_ZERO;
Copy link
Contributor

@zaixing-wang zaixing-wang Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we initialize qk_acc to -1e9 here, so that we don't have to assign qk_acc again at line 1196?

#else // !IS_CAUSAL
MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc;

qk_acc = FUNC_CALL(load_attn_mask)(OPTIONAL_SHAPE_INFO_TENSOR
Expand All @@ -1041,6 +1055,7 @@ KERNEL(sdpa_opt)(
ATTN_MASK_BUFFER
ATTN_SCALE_BUFFER
PA_BUFFERS);
#endif // !IS_CAUSAL

if (seq_len_calc_size >= SUBGROUP_SIZE) {
#if IS_KV_COMPRESSED
Expand Down Expand Up @@ -1157,6 +1172,10 @@ KERNEL(sdpa_opt)(
{
SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
#if IS_CAUSAL
// causal mask: keep the score only when source position (seq_len + i) <= target position (target_seq_idx + sglid)
if (seq_len + i <= target_seq_idx + sglid) {
#endif // IS_CAUSAL
#if !APPLY_SCALES_TO_QUERY
#if HAS_SCALE_INPUT
const OUTPUT_TYPE scale_val = *scale;
Expand All @@ -1172,12 +1191,21 @@ KERNEL(sdpa_opt)(
#endif

qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX);

#if IS_CAUSAL
} else {
qk_acc[i] = -1e9f;
}
#endif // IS_CAUSAL
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]));
slm_qk_vals[sglid][sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
}
slm_qk_max_vals[sglid][sgid] = qk_max;
}
#if IS_CAUSAL
} else { // skip triu
slm_qk_max_vals[sglid][sgid] = -1e9f;
}
#endif

barrier(CLK_LOCAL_MEM_FENCE);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ std::vector<Params> get_test_params() {
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});

/* -- causal mask -- */

p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({!with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});

// Beam search
p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 4, ov::element::Type_t::f16, 5, 16, 1, {0, 2, 1, 3}});

// Compressed
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
return p;
}

Expand Down
Loading