diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
index 140fb0b9263748..f6447ff12454f3 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -867,7 +867,11 @@ KERNEL(sdpa_opt)(
 #define b0_idx (batch_idx / NUM_HEADS)
 #define b1_idx (batch_idx % NUM_HEADS)
 #define target_seq_dim ((uint)get_global_id(1))
+#if IS_PAGED_ATTENTION
+    #define target_seq_idx ((uint)block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]])
+#else
 #define target_seq_idx ((uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE)
+#endif
 #define head_size_idx ((uint)get_local_id(2) % HEAD_SIZE)
 #define sglid (uint)get_sub_group_local_id()
 #define sgid (uint)get_sub_group_id()
@@ -994,8 +998,15 @@ KERNEL(sdpa_opt)(
         __attribute__((opencl_unroll_hint(1)))
         for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) {
             const uint seq_len = start_partition_idx + sgid * SUBGROUP_SIZE;
+#if IS_CAUSAL
+            const uint partition_seq_len = min((uint)SEQ_LEN_PARTITION_SIZE, (uint)max(0, (int)(target_seq_idx + seq_idx_end) - (int)start_partition_idx));
+#else
             const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
+#endif
 
+#if IS_CAUSAL
+            if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
+#endif
 #if IS_PAGED_ATTENTION
 #ifdef BROADCAST_GROUP_SIZE
             const uint heads_dim = num_heads_dim / BROADCAST_GROUP_SIZE;
@@ -1026,21 +1037,21 @@ KERNEL(sdpa_opt)(
 #endif
 
             int seq_len_calc_size = min((int)(SOURCE_SEQ_LEN) - (int)seq_len, (int)SUBGROUP_SIZE);
+#if IS_CAUSAL
+            MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
+#else // !IS_CAUSAL
             MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc;
             qk_acc = FUNC_CALL(load_attn_mask)(OPTIONAL_SHAPE_INFO_TENSOR
                     b0_idx,
                     b1_idx,
-#if IS_PAGED_ATTENTION
-                    block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid,
-#else
                     target_seq_idx + sglid,
-#endif
                     // TODO: pass seq_len_calc_size here
                     seq_len
                     ATTN_MASK_BUFFER
                     ATTN_SCALE_BUFFER
                     PA_BUFFERS);
+#endif // !IS_CAUSAL
 
             if (seq_len_calc_size >= SUBGROUP_SIZE) {
 #if IS_KV_COMPRESSED
@@ -1157,6 +1168,10 @@ KERNEL(sdpa_opt)(
     {
         SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
         unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
+#if IS_CAUSAL
+            // causal mask: valid only if m >= n
+            if (seq_len + i <= target_seq_idx + sglid) {
+#endif // IS_CAUSAL
 #if !APPLY_SCALES_TO_QUERY
 #if HAS_SCALE_INPUT
             const OUTPUT_TYPE scale_val = *scale;
@@ -1172,12 +1187,21 @@ KERNEL(sdpa_opt)(
 #endif
             qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX);
-
+#if IS_CAUSAL
+            } else {
+                qk_acc[i] = INPUT0_VAL_MIN;
+            }
+#endif // IS_CAUSAL
             qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]));
             slm_qk_vals[sglid][sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
         }
         slm_qk_max_vals[sglid][sgid] = qk_max;
     }
+#if IS_CAUSAL
+    } else { // skip triu
+        slm_qk_max_vals[sglid][sgid] = SOFTMAX_ACCUMULATOR_VAL_MIN;
+    }
+#endif
 
     barrier(CLK_LOCAL_MEM_FENCE);
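The IS_CAUSAL hunks above encode the standard causal rule: a query at position m may only attend to key positions n <= m, so scores above the diagonal are forced to the minimum representable value before softmax, and key blocks that lie entirely past the diagonal are skipped outright via the `seq_len <= target_seq_idx` guard and the clamped `partition_seq_len`. For reference, a minimal host-side sketch of that rule, assuming query and key positions are aligned at zero; `apply_causal_mask` is a hypothetical helper, not part of the patch:

// Hypothetical reference model of the IS_CAUSAL logic; not part of the patch.
#include <cstddef>
#include <limits>
#include <vector>

// qk holds target_seq_len x source_seq_len attention scores in row-major order.
void apply_causal_mask(std::vector<float>& qk, size_t target_seq_len, size_t source_seq_len) {
    const float min_val = std::numeric_limits<float>::lowest();
    for (size_t m = 0; m < target_seq_len; ++m) {         // query (row) index
        for (size_t n = m + 1; n < source_seq_len; ++n) { // keys past the diagonal (triu)
            qk[m * source_seq_len + n] = min_val;         // vanishes after softmax
        }
    }
}

In the kernel the same rule appears twice: per element as `seq_len + i <= target_seq_idx + sglid`, and per block, where subgroups fully above the diagonal only contribute `SOFTMAX_ACCUMULATOR_VAL_MIN` to the running max and skip the QK computation entirely.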
diff --git a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
index fe923135550e5b..89612039fb788f 100644
--- a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
+++ b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache_sdpa.cpp
@@ -351,6 +351,22 @@ std::vector<Params> get_test_params() {
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
     p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+
+    /* -- causal mask -- */
+
+    p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+    p.push_back({!with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+
+    // Beam search
+    p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 4, ov::element::Type_t::f16, 5, 16, 1, {0, 2, 1, 3}});
+
+    // Compressed
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
+    p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
 
     return p;
 }
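The new rows mirror the existing non-causal coverage for the causal path: causal both with and without an explicit attention mask, both head-order permutations ({0, 1, 2, 3} and {0, 2, 1, 3}), beam sizes 2 and 4, and the compressed-KV variants.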