From c8509718792f4c1c69120f37888a05a39997b788 Mon Sep 17 00:00:00 2001 From: Muhammad Tanvir Date: Thu, 14 Nov 2024 16:42:28 +0000 Subject: [PATCH] Add Flash Attention v2 example --- examples/sycl/pvc/CMakeLists.txt | 2 + .../pvc/flash_attention_v2/CMakeLists.txt | 33 ++ .../pvc/flash_attention_v2/online_softmax.hpp | 234 ++++++++ .../pvc/flash_attention_v2/pvc_flash_attn.cpp | 532 ++++++++++++++++++ .../pvc_flash_attn_epilogue.hpp | 307 ++++++++++ .../pvc_flash_attn_gemm_universal.hpp | 350 ++++++++++++ .../flash_attention_v2/pvc_flash_attn_mma.hpp | 414 ++++++++++++++ 7 files changed, 1872 insertions(+) create mode 100644 examples/sycl/pvc/flash_attention_v2/CMakeLists.txt create mode 100644 examples/sycl/pvc/flash_attention_v2/online_softmax.hpp create mode 100644 examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp create mode 100644 examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp create mode 100644 examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp create mode 100644 examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt index 322896e20e..f930939167 100644 --- a/examples/sycl/pvc/CMakeLists.txt +++ b/examples/sycl/pvc/CMakeLists.txt @@ -41,3 +41,5 @@ cutlass_example_add_executable( pvc_collective_builder pvc_collective_builder.cpp ) + +add_subdirectory(flash_attention_v2) diff --git a/examples/sycl/pvc/flash_attention_v2/CMakeLists.txt b/examples/sycl/pvc/flash_attention_v2/CMakeLists.txt new file mode 100644 index 0000000000..9712dd220f --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
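+# Registers the pvc_flash_attention example defined below. A hedged usage sketch
+# (the exact binary location depends on the build-tree layout; options follow --help):
+#   cmake --build . --target pvc_flash_attention
+#   ./examples/sycl/pvc/flash_attention_v2/pvc_flash_attention \
+#       --batch=4 --num_heads=8 --seq_len=4096 --head_size=64 --is_causal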
+ + +cutlass_example_add_executable( + pvc_flash_attention + pvc_flash_attn.cpp +) diff --git a/examples/sycl/pvc/flash_attention_v2/online_softmax.hpp b/examples/sycl/pvc/flash_attention_v2/online_softmax.hpp new file mode 100644 index 0000000000..c2046164aa --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/online_softmax.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing online softmax. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace flash { + +template +struct MaxOp { + CUTLASS_DEVICE T + operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template +struct SumOp { + CUTLASS_DEVICE T + operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Softmax { + struct Arguments { + Element scale; + }; + + using Params = Arguments; + + static constexpr Params + to_underlying_arguments(Arguments const& args) { + Arguments x{static_cast(args.scale) * static_cast(M_LOG2E)}; + return x; + } + + template < + bool CheckInf, + int SizeA, + int SizeB, + int SizeC, + class FragAcc, + class FragMax + > + CUTLASS_DEVICE static constexpr + void scale_exp_log2(FragAcc &acc, FragMax const max, Element const scale) { + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + Element max_scale = !CheckInf ? max(x, y) : max(x, y) == -INFINITY ? 
Element{0} : max(x, y); + CUTLASS_PRAGMA_UNROLL + for(int z = 0; z < SizeC; z++) { + acc(x, y, z) = sycl::native::exp2((acc(x, y, z) - max_scale) * scale); + } + } + } + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void workitem_reduce(FragSrc const &src, FragDst &dst, Op op) { + // reduction per work item + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + dst(x, y) = zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + CUTLASS_PRAGMA_UNROLL + for(int z = 1; z < SizeC; z++) { + dst(x, y) = op(dst(x, y), src(x, y, z)); + } + } + } + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragDst, + class Op + > + CUTLASS_DEVICE static void subgroup_allreduce(FragDst &dst, Op op) { + // reduce across the sub_group to get the final output + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + CUTLASS_PRAGMA_UNROLL + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); + } + } + } + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void reduce(FragSrc const &src, FragDst &dst, Op op) { + // reduce across all the N tiles in shape + workitem_reduce(src, dst, op); + subgroup_allreduce(dst, op); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragMax + > + CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { + MaxOp max_op; + reduce(src, max, max_op); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragSum + > + CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { + SumOp sum_op; + workitem_reduce(src, sum, sum_op); + } + + template < + bool is_first, + bool CheckInf, + int SizeA, + int SizeB, + int SizeC, + class FragAcc, + class FragMax, + class FragSum, + class FragOut + > + CUTLASS_DEVICE static typename std::enable_if::type + run(FragAcc &frag_s, FragMax& max, FragSum& sum, FragOut&, Params const ¶ms) { + reduce_max(frag_s, max); + scale_exp_log2(frag_s, max, params.scale); + reduce_sum(frag_s, sum); + } + + template < + bool is_first, + bool CheckInf, + int SizeA, + int SizeB, + int SizeC, + class FragAcc, + class FragMax, + class FragSum, + class FragOut + > + CUTLASS_DEVICE static typename std::enable_if::type + run(FragAcc &frag_s, FragMax& max, FragSum& sum, FragOut &out, Params const ¶ms) { + cute::Tensor max_prev = cute::make_fragment_like(max); + cute::copy(max, max_prev); + reduce_max(frag_s, max); + + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + Element curr_max = !CheckInf ? max(x, y) : max(x, y) == -INFINITY ? 
0.0f : max(x, y); + Element curr_scale = sycl::native::exp2((max_prev(x, y) - curr_max) * params.scale); + sum(x, y) *= curr_scale; + CUTLASS_PRAGMA_UNROLL + for(int z = 0; z < SizeC; z++) { + out(x, y, z) *= curr_scale; + } + } + } + + scale_exp_log2(frag_s, max, params.scale); + reduce_sum(frag_s, sum); + } + + Params params; +}; +} diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp new file mode 100644 index 0000000000..0454161912 --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp @@ -0,0 +1,532 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "pvc_flash_attn_gemm_universal.hpp" +#include "pvc_flash_attn_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "../common.h" + +using namespace cute; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool is_causal; + + int batch, num_heads, seq_len, head_size, iterations; + float softmax_scale; + + Options(): + help(false), + error(false), + is_causal(false), + batch(4), num_heads(8), seq_len(4096), head_size(64), iterations(20), + softmax_scale(1.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("is_causal")) { + is_causal = true; + } + + cmd.get_cmd_line_argument("batch", batch, 4); + cmd.get_cmd_line_argument("num_heads", num_heads, 8); + cmd.get_cmd_line_argument("seq_len", seq_len, 16384); + cmd.get_cmd_line_argument("head_size", head_size, 64); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + softmax_scale = 1 / sqrt(static_cast(head_size)); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "PVC Flash Attention v2 Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --is_causal Apply Causal Mask to the output of first Matmul\n" + << " --batch= Sets the Batch Size of the Multi-Head Self Attention module\n" + << " --num_heads= Sets the Number of Attention Heads of the Multi-Head Self Attention module\n" + << " --seq_len= Sets the Sequence length of the Multi-Head Self Attention module\n" + << " --head_size= Sets the Attention Head dimension of the Multi-Head Self Attention module\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class GemmKernel +> +struct ExampleRunner { + + using StrideQ = typename GemmKernel::StrideQ; + using StrideK = typename GemmKernel::StrideK; + using StrideV = typename GemmKernel::StrideV; + using StrideO = typename GemmKernel::StrideO; + using StrideLSE = typename GemmKernel::StrideLSE; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::RowMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + using LayoutLSE = cutlass::layout::RowMajor; + + using ElementQ = typename GemmKernel::ElementQ; + using ElementK = typename GemmKernel::ElementK; + using ElementV = typename GemmKernel::ElementV; + using ElementAcc = typename GemmKernel::ElementAccumulator; + + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename 
CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideLSE stride_LSE; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_lse; + cutlass::DeviceAllocation block_ref_O; + cutlass::DeviceAllocation block_ref_lse; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, bool is_causal) { + auto [batch, num_heads, seq_len, head_size] = problem_size; + + int mat_size = seq_len * head_size; + + for(int b = 0; b < batch; b++) { + + for(int h = 0; h < num_heads; h++) { + + int offset = (h + b * num_heads) * mat_size; + + cutlass::DeviceAllocation block_S; + block_S.reset(seq_len * seq_len); + + cutlass::TensorRef ref_Q(block_Q.get() + offset, LayoutQ::packed({seq_len, head_size})); + cutlass::TensorRef ref_K(block_K.get() + offset, LayoutK::packed({head_size, seq_len})); + cutlass::TensorRef ref_V(block_V.get() + offset, LayoutV::packed({seq_len, head_size})); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len, seq_len})); + cutlass::TensorRef ref_O(block_ref_O.get() + offset, LayoutO::packed({seq_len, head_size})); + + cutlass::reference::device::GemmComplex( + {seq_len, seq_len, head_size}, + 1.f, + ref_Q, + cutlass::ComplexTransform::kNone, + ref_K, + cutlass::ComplexTransform::kNone, + 0.f, + ref_S, + ref_S, + ElementAccumulator(0), + 1, // batch_count + seq_len * head_size, // batch_stride_Q + seq_len * head_size, // batch_stride_K + seq_len * seq_len, // batch_stride_S + seq_len * seq_len // batch_stride_S + ); + + syclcompat::wait(); + + std::vector host_S(seq_len * seq_len); + syclcompat::memcpy(host_S.data(), block_S.get(), host_S.size()); + syclcompat::wait(); + + if(is_causal) { + // apply mask to S + for (int row = 0; row < seq_len; row++) { + for (int col = 0; col < seq_len; col++) { + if (col > row) + host_S[col + row * seq_len] = -INFINITY; + } + } + } + + // compute max element per row of S + std::vector max_vec(seq_len); + for (int row = 0; row < seq_len; row++) { + int idx = row * seq_len; + max_vec[row] = host_S[idx]; + for (int col = 1; col < seq_len; col++, idx++) { + if (max_vec[row] < host_S[idx]) + max_vec[row] = host_S[idx]; + } + } + + // compute exp of S + for (int row = 0; row < seq_len; row++) { + int idx = row * seq_len; + for (int col = 0; col < seq_len; col++, idx++) { + host_S[idx] = expf((host_S[idx] - max_vec[row]) / sqrt(static_cast((head_size)))); + } + } + + // compute sum per row of S + std::vector sum_vec(seq_len, ElementOutput{0}); + for (int row = 0; row < seq_len; row++) { + int idx = row * seq_len; + for (int col = 0; col < seq_len; col++, idx++) { + sum_vec[row] += host_S[idx]; + } + + //scale each row with the sum to compute softmax + idx = row * seq_len; + for (int col = 0; col < seq_len; col++, idx++) { + host_S[idx] /= sum_vec[row]; + } + } + + std::vector host_P(host_S.size()); + for(int p = 0; p < host_P.size(); p++) host_P[p] = static_cast(host_S[p]); + + cutlass::DeviceAllocation block_P; + block_P.reset(host_P.size()); + + syclcompat::memcpy(block_P.get(), host_P.data(), host_P.size()); + syclcompat::wait(); + + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len, seq_len})); + + cutlass::reference::device::GemmComplex( + {seq_len, head_size, seq_len}, + 
1.f, + ref_P, + cutlass::ComplexTransform::kNone, + ref_V, + cutlass::ComplexTransform::kNone, + 0.f, + ref_O, + ref_O, + ElementAccumulator(0), + 1, // batch_count + seq_len * seq_len, // batch_stride_P + seq_len * head_size, // batch_stride_V + seq_len * head_size, // batch_stride_O + seq_len * head_size // batch_stride_O + ); + + syclcompat::wait(); + } + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( + block_ref_O.get(), block_O.get(), block_O.size(), 0.5f, 0.5f); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + // auto problem_shape = cute::append<4>(problem_size, 1); + auto [batch, num_heads, seq_len, head_size] = problem_size; + + stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len, head_size, batch * num_heads)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len, head_size, batch * num_heads)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(seq_len, head_size, batch * num_heads)); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len, head_size, batch * num_heads)); + stride_LSE = cutlass::make_cute_packed_stride(StrideLSE{}, cute::make_shape(seq_len, 1, batch * num_heads)); + + auto count = batch * num_heads * seq_len * head_size; + block_Q.reset(count); + block_K.reset(count); + block_V.reset(count); + block_O.reset(count); + block_ref_O.reset(count); + block_lse.reset(count); + block_ref_lse.reset(count); + + initialize_block(block_Q, seed + 2023); + initialize_block(block_K, seed + 2022); //assume K is already transposed + initialize_block(block_V, seed + 2021); + + } + + static void run(typename GemmKernel::Params params) { + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = GemmKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + using namespace syclcompat::experimental; + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size} + }, params); + + EventManager::getInstance().addEvent(event); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.batch, options.num_heads, options.seq_len, options.head_size}; + + initialize(problem_size); + + typename GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V}, + {options.softmax_scale}, + {{1}, block_O.get(), stride_O, block_lse.get(), stride_LSE}, + hw_info + }; + + // GemmKernel gemm_op; + + size_t workspace_size = GemmKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + GemmKernel::can_implement(arguments); + + // Initialize the workspace + auto status = GemmKernel::initialize_workspace(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return; + } + + typename GemmKernel::Params params = GemmKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the GEMM + run(params); + + syclcompat::wait(); + 
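+    // What follows is a functional check against the reference in verify() (a host-side
+    // softmax between two device reference GEMMs), then a timed loop reporting TFLOP/s.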
+ // Verify that the result is correct + bool passed = verify(problem_size, options.is_causal); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + run(params); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double flops_qk = 2.0 * options.batch * options.num_heads * options.seq_len * options.seq_len * options.head_size; + double flops_softmax = 5.0 * options.seq_len * options.seq_len; // seq_len^2 + seq_len^2 + 3 * seq_len^2 + double flops_pv = 2.0 * options.batch * options.num_heads * options.seq_len * options.head_size * options.seq_len; + double tflops = (flops_qk + flops_softmax + flops_pv) * 1e-12; + std::cout << "Problem Size: " << options.batch << 'x' << options.num_heads << 'x' << options.seq_len << 'x' << options.head_size << std::endl; + printf("Cutlass Flash Attention Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
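+  // Per (batch, head), the example computes O = softmax(Q * K^T * softmax_scale) * V,
+  // with softmax_scale = 1/sqrt(head_size) as set in Options::parse() above.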
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputQ = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputKV = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using GmemTiledCopyQ = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyK = XE_2D_U16x32x32_LD_V; + using GmemTiledCopyV = XE_2D_U16x32x32_LD_V; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::RowMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + using LayoutLSE = cutlass::layout::RowMajor; + + // Workgroup-level tile + using TileShape = Shape<_128, _64, _32>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_32, _32, _32>>; // Subgroup level-tile + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogueAttention< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_ST_N>; + + if(options.is_causal) { + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMmaAttention< + GEMMDispatchPolicy, + TileShape, + ElementInputQ, + cutlass::gemm::TagToStrideA_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + true + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversalAttention< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + ExampleRunner runner; + + runner.run(options, hw_info); + } else { + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMmaAttention< + GEMMDispatchPolicy, + TileShape, + ElementInputQ, + cutlass::gemm::TagToStrideA_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + false + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversalAttention< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + ExampleRunner runner; + + runner.run(options, hw_info); + } + + return 0; +} diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp new file mode 100644 index 0000000000..9c23207d90 --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_epilogue.hpp @@ -0,0 +1,307 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... 
Args +> +class CollectiveEpilogueAttention { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template < + class CtaTileMNK_, + class ElementO_, + class StrideO_, + class ElementLSE_, + class StrideLSE_, + class FusionCallbacks_, + class CopyOpO_ +> +class CollectiveEpilogueAttention< + IntelPVCEpilogue, + CtaTileMNK_, + ElementO_, + StrideO_, + ElementLSE_, + StrideLSE_, + FusionCallbacks_, + CopyOpO_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelPVCEpilogue; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementO = ElementO_; + using ElementAccumulator = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using StrideLSE = StrideLSE_; + using CopyOpO = CopyOpO_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = typename FusionCallbacks::ElementOutput; + using ElementCompute = typename FusionCallbacks::ElementCompute; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-4: [batch, num_heads, seq_len, head_size]"); + static_assert(cute::rank(StrideLSE{}) == 3, "StrideLSE must be rank-3: [batch, num_heads, seq_len]"); + + using Trait_O = Copy_Traits; + using XE_Copy_O = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}), + get<1>(typename Trait_O::Shape_MN{}) / Int{})))); +private: + constexpr static bool is_destination_supported = not cute::is_void_v; + +public: + + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementO const* ptr_O; + StrideO dO; + ElementLSE* ptr_LSE; + StrideLSE dLSE; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_O xe_store_o; + ElementLSE* ptr_LSE; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + auto [batch, num_heads, seq_len, head_size] = problem_shape; + + XE_Copy_O xe_store_o = {}; + xe_store_o = make_tiled_copy(Copy_Atom, ElementO>{}.with( + args.ptr_O, head_size, seq_len, head_size), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_O::Shape_MN{}), + get<1>(typename Trait_O::Shape_MN{}) / Int{}))); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_store_o, + args.ptr_LSE + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE 
static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogueAttention(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + template < + class ProblemShape, + class TileCoord, + class FragOut, + class FragMax, + class FragSum, + class TiledMma + > + CUTLASS_DEVICE void + operator() ( + ProblemShape problem_shape, + TileCoord tile_coord, + FragOut &out, + FragMax const &max, + FragSum const &sum, + TiledMma tiled_mma, + ElementLSE const& softmax_scale + ) { + + Tensor tLSEr = make_fragment_like(sum); + + using namespace cute; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, "assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + auto l_offset = l_coord; + + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < FragmentSize; x++) { + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + + ElementLSE curr_sum = sum(x, y); + ElementO scale = (curr_sum == 0.f || curr_sum != curr_sum) ? 1.f : 1.f / curr_sum; + + tLSEr(x, y) = curr_sum == 0.f ? 
-INFINITY : max(x, y) * softmax_scale + logf(curr_sum); + + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out(x, y, z) *= scale; + } + } + } + + // Indexing variables + auto [batch, num_heads, seq_len, head_size] = problem_shape; + + Tensor tOi = params.xe_store_o.get_pvc_tensor( + make_coord(m_offset, n_offset, 0), + make_shape(_, Int{}, Int{}, batch * num_heads), + make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); + + copy(params.xe_store_o, out, tOi(_,_,_,l_coord)); + + /*const int lse_offset = seq_coord + (num_heads_coord + batch_coord * num_heads) * seq_len; + + auto lse_ptr = params.ptr_LSE + lse_offset; + + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + const int lane_id = static_cast(sg.get_local_linear_id()); + + // use only 1 work item per sub_group to write lse since all + // work items within subgroup have the same sum() data stored + // in registers + if(lane_id == 0) { + int count = 0; + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < FragsM; x++) { + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragmentSize; y++) { + *(lse_ptr + count++) = tLSEr(y, x); + } + } + }*/ + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp new file mode 100644 index 0000000000..4f09828d53 --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_gemm_universal.hpp @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "online_softmax.hpp" +#include "pvc_flash_attn_mma.hpp" + +namespace cutlass::gemm::kernel { + +template < + class ProblemShape, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler_ = void +> +class GemmUniversalAttention; + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversalAttention +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + static_assert(rank(ProblemShape{}) == 4, + "ProblemShape{} should be "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using SoftmaxArguments = typename flash::Softmax::Arguments; + using SoftmaxParams = typename flash::Softmax::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "Intel PVC does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, WorkgroupTileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using StrideLSE = typename CollectiveEpilogue::StrideLSE; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + // MSVC requires the cast to fix a warning-as-error. 
+ static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + + static constexpr int BLK_M = CollectiveMainloop::BLK_M; + static constexpr int BLK_N = CollectiveMainloop::BLK_N; + static constexpr int BLK_K = CollectiveMainloop::BLK_K; + + static constexpr int ATOM_M = CollectiveMainloop::ATOM_M; + static constexpr int ATOM_N = CollectiveMainloop::ATOM_N; + static constexpr int ATOM_K = CollectiveMainloop::ATOM_K; + + static constexpr int SG_M = CollectiveMainloop::SG_M; + static constexpr int SG_N = CollectiveMainloop::SG_N; + static constexpr int SG_K = CollectiveMainloop::SG_K; + + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxArguments softmax; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + flash::Softmax::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool mode_implementable = args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + return mode_implementable && TileScheduler::can_implement(args.scheduler); + } + + static int + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + return dim3( + cute::size(cute::ceil_div(cute::shape<3>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<2>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::shape<0>(params.problem_shape) * cute::shape<1>(params.problem_shape)) + ); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + // Separate out problem shape for convenience + auto batch = get<0>(params.problem_shape); + auto num_heads = get<1>(params.problem_shape); + auto seq_len = 
get<2>(params.problem_shape); + auto head_size = get<3>(params.problem_shape); + + // Preconditions + static_assert(cute::rank(StrideQ{}) == 3, "StrideQ must be rank-4: [batch, num_heads, seq_len, head_size]."); + static_assert(cute::rank(StrideK{}) == 3, "StrideK must be rank-4: [batch, num_heads, seq_len, head_size]."); + static_assert(cute::rank(StrideV{}) == 3, "StrideV must be rank-4: [batch, num_heads, seq_len, head_size]."); + + int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + + auto blk_shape = TileShape{}; + auto blk_m_coord = BlockIdxY(); + auto blk_n_coord = BlockIdxX(); + auto blk_l_coord = BlockIdxZ(); + auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, blk_l_coord); + + Tensor mQ_mkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(seq_len, head_size, batch * num_heads), StrideQ{}); //(m,k,l) + Tensor mK_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(seq_len, head_size, batch * num_heads), StrideK{}); //(n,k,l) + Tensor mV_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(seq_len, head_size, batch * num_heads), StrideV{}); //(n,k,l) + + Tensor mQ_mk = mQ_mkl(_,_,blk_l_coord); // (m,k) + Tensor mK_nk = mK_nkl(_,_,blk_l_coord); // (n,k) + Tensor mV_nk = mV_nkl(_,_,blk_l_coord); // (n,k) + + auto gQ = local_tile(mQ_mk, blk_shape, make_coord(blk_m_coord, 0, _), Step<_1, X, _1>{}); + + const int seq_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M; + const int head_size_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N; + const int l_coord = BlockIdxZ(); + + // Compute tile residues for predication + auto m_max_coord = seq_len - get<0>(subgroup_shape) * seq_coord; // M - SUB_M * m_coord + auto n_max_coord = seq_len - get<1>(subgroup_shape) * seq_coord; // N - SUB_N * n_coord + auto k_residue = head_size - get<2>(subgroup_shape) * (head_size / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape + TiledMma tiled_mma; + + Tensor out_reg = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); + constexpr int VecA = 8; + constexpr int FragsM1 = 4; + constexpr int FragsN2 = 2; + + Tensor max_reg = make_tensor(Shape, Int>{}); + Tensor sum_reg = make_tensor(Shape, Int>{}); + + fill(max_reg, -INFINITY); + clear(sum_reg); + clear(out_reg); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + + const int causal_seq_len = seq_coord + get<0>(subgroup_shape); + const int non_causal_seq_len = seq_len; + + const int nblock_limit = CausalMask ? 
cute::ceil_div(causal_seq_len, get<1>(subgroup_shape)) + : cute::ceil_div(non_causal_seq_len, get<1>(subgroup_shape)); + + const int item_id = thread_idx % SubgroupSize; + + // loop over K and V, perform fused attention + online softmax + for (int nblock = 0, load_idx = 0; nblock < nblock_limit; nblock++, + load_idx += get<1>(subgroup_shape)) { + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + auto gK = local_tile(mK_nk, blk_shape, take<0, 3>(make_coord(0, load_idx, _, blk_l_coord)), Step< X, _1, _1>{}); + Tensor tSr = make_tensor(Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + auto tile_coord_QK = make_coord(seq_coord, load_idx, _, blk_l_coord); + collective_mma.mmaQK(tile_coord_QK, tSr, gQ, gK, tSr, head_size / get<1>(subgroup_shape), params.mainloop); + + // Apply causal mask + if constexpr (CausalMask) { + // mask the elements of each tile where j > i + int col_idx = item_id + load_idx; + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < FragsN2; n++, col_idx += get<1>(MmaAtomShape())) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < FragsM1; m++) { + int row_idx = m * VecA + seq_coord; + CUTLASS_PRAGMA_UNROLL + for(int row = 0; row < VecA; row++, row_idx++) { + if(col_idx > row_idx) + tSr(row, m, n) = -INFINITY; + } + } + } + } + + if (nblock == 0) + flash::Softmax::template run(tSr, + max_reg, sum_reg, out_reg, params.softmax); + else + flash::Softmax::template run(tSr, + max_reg, sum_reg, out_reg, params.softmax); + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = make_tensor(shape(tSr)); + CUTLASS_PRAGMA_UNROLL + for (int p_idx = 0; p_idx < size(tPr); p_idx++) { + tPr(p_idx) = static_cast(tSr(p_idx)); + } + + // 8) Scale out_reg with l + // 10) Perform GEMM O = + auto gV = local_tile(mV_nk, blk_shape, take<0, 3>(make_coord(0, load_idx, _, blk_l_coord)), Step< X, _1, _1>{}); + collective_mma.mmaPV(out_reg, tPr, gV, out_reg, 1, head_size_coord, params.mainloop); + } + + // Reduce the sum of exponents across the subgroup before scaling/normalizing output + flash::SumOp op; + flash::Softmax::template subgroup_allreduce(sum_reg, op); + + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + epilogue( + params.problem_shape, + blk_coord_mnkl, + out_reg, + max_reg, + sum_reg, + tiled_mma, + params.softmax.scale); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp new file mode 100644 index 0000000000..d8f913973c --- /dev/null +++ b/examples/sycl/pvc/flash_attention_v2/pvc_flash_attn_mma.hpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class TiledMma_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_> +struct CollectiveMmaAttention { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class TileShape_, + class ElementQ_, + class StrideQ_, + class ElementK_, + class StrideK_, + class ElementV_, + class StrideV_, + class TiledMma_, + class GmemTiledCopyQ_, + class GmemTiledCopyK_, + class GmemTiledCopyV_, + bool CausalMask_> +struct CollectiveMmaAttention< + MainloopIntelPVC, + TileShape_, + ElementQ_, + StrideQ_, + ElementK_, + StrideK_, + ElementV_, + StrideV_, + TiledMma_, + GmemTiledCopyQ_, + GmemTiledCopyK_, + GmemTiledCopyV_, + CausalMask_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelPVC; + using WorkgroupTileShape = TileShape_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr bool CausalMask = CausalMask_; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + 
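+  // The workgroup tile (BLK_*) is split across the TiledMMA subgroup layout (ATOM_*);
+  // SG_* = ceil_div(BLK_*, ATOM_*) below gives the per-subgroup tile (SubgroupTileShape).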
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + static constexpr size_t cacheline_bytes = 64; + static constexpr auto block_size_w_a = cute::min(SG_K, cacheline_bytes / sizeof(ElementQ)); + static constexpr auto block_size_w_b = cute::min(SG_N, cacheline_bytes / sizeof(ElementK)); + static constexpr auto nums_block_w_a = ceil_div(SG_K, block_size_w_a); + static constexpr auto nums_block_w_b = ceil_div(SG_N, block_size_w_b); + using PrefetchQThrShape = Shape, Int>; + using PrefetchKThrShape = Shape, Int>; + using PrefetchVThrShape = Shape, Int>; + using PrefetchQTileSize = decltype(ceil_div(Shape, Int>{},PrefetchQThrShape{})); + using PrefetchKTileSize = decltype(ceil_div(Shape, Int>{},PrefetchKThrShape{})); + using PrefetchVTileSize = decltype(ceil_div(Shape, Int>{},PrefetchVThrShape{})); + + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using XE_Copy_Q = decltype(make_tiled_copy(atom_load_Q{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_Q::Shape_MN{}), + get<1>(typename traits_load_Q::Shape_MN{}) / Int{})))); + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using XE_Copy_K = decltype(make_tiled_copy(atom_load_K{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}), + get<1>(typename traits_load_K::Shape_MN{}) / Int{})))); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}), + get<1>(typename traits_load_K::Shape_MN{}) / Int{})))); + + using XE_Prefetch_Q = decltype(cute::detail::prefetch_selector()); + using XE_Prefetch_K = decltype(cute::detail::prefetch_selector()); + using XE_Prefetch_V = decltype(cute::detail::prefetch_selector()); + + // Host side kernel arguments + struct Arguments { + ElementQ const* ptr_Q; + StrideQ dQ; + ElementK const* ptr_K; + StrideK dK; + ElementV const* ptr_V; + StrideV dV; + }; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k; + XE_Copy_V gmem_tiled_copy_v; + + XE_Prefetch_Q gmem_prefetch_q; + XE_Prefetch_K gmem_prefetch_k; + XE_Prefetch_V gmem_prefetch_v; + }; + + // + // Methods + // + + CollectiveMmaAttention() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + auto [batch, num_heads, seq_len, head_size] = problem_shape; + + XE_Copy_Q copyQ = make_tiled_copy(Copy_Atom, ElementQ>{}.with(args.ptr_Q, head_size, seq_len, head_size), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_Q::Shape_MN{}), + get<1>(typename traits_load_Q::Shape_MN{}) / Int{}))); + XE_Copy_K copyK = make_tiled_copy(Copy_Atom, ElementK>{}.with(args.ptr_K, seq_len, head_size, seq_len), + 
+                                      Layout<Shape<_1, Int<SubgroupSize>>>{},
+                                      make_layout(make_shape(get<0>(typename traits_load_K::Shape_MN{}),
+                                                             get<1>(typename traits_load_K::Shape_MN{}) / Int<SubgroupSize>{})));
+
+    XE_Copy_V copyV = make_tiled_copy(Copy_Atom<Copy_Traits<GmemTiledCopyV>, ElementV>{}.with(args.ptr_V, head_size, seq_len, head_size),
+                                      Layout<Shape<_1, Int<SubgroupSize>>>{},
+                                      make_layout(make_shape(get<0>(typename traits_load_V::Shape_MN{}),
+                                                             get<1>(typename traits_load_V::Shape_MN{}) / Int<SubgroupSize>{})));
+
+    XE_Prefetch_Q prefetchQ = cute::detail::prefetch_selector((void *)args.ptr_Q, head_size, seq_len, head_size);
+    XE_Prefetch_K prefetchK = cute::detail::prefetch_selector((void *)args.ptr_K, seq_len, head_size, seq_len);
+    XE_Prefetch_V prefetchV = cute::detail::prefetch_selector((void *)args.ptr_V, head_size, seq_len, head_size);
+    return Params{copyQ, copyK, copyV, prefetchQ, prefetchK, prefetchV};
+  }
+
+  // Computes an S = Q * K^T tile: streams Q and K blocks over the contraction (head) dimension,
+  // prefetching ahead by DispatchPolicy::Stages, and accumulates the partial products into accum.
+  template <
+  class TileCoord,
+  class FragAccum,
+  class TensorQ,
+  class TensorK,
+  class FragSrc
+  >
+  CUTLASS_DEVICE void
+  mmaQK(
+    TileCoord tile_coord,
+    FragAccum& accum,
+    TensorQ gA,
+    TensorK gB,
+    FragSrc const &frag_src,
+    int const &k_tile_count,
+    Params const &params) {
+
+    int thread_idx = static_cast<int>(ThreadIdxX());
+    // Instantiate the MMA object
+    TiledMma tiled_mma;
+    auto thread_mma = tiled_mma.get_slice(thread_idx);
+    Tensor tCrA_partition = thread_mma.partition_fragment_A(gA(_, _, 0));
+    Tensor tCrA = make_tensor(static_cast<decltype(tCrA_partition) &&>(tCrA_partition).data(),
+                              tCrA_partition.shape());
+    Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0));
+    Tensor tCrB = make_tensor(static_cast<decltype(tCrB_partition) &&>(tCrB_partition).data(),
+                              make_shape(size<0>(tCrB_partition.shape()),
+                                         size<2>(tCrB_partition.shape()),
+                                         size<1>(tCrB_partition.shape())));
+    // Partition the copying of A and B tiles across the threads
+    auto gmem_thr_copy_A = params.gmem_tiled_copy_q.get_slice(thread_idx);
+    auto gmem_thr_copy_B = params.gmem_tiled_copy_k.get_slice(thread_idx);
+
+    auto tCrA_copy_view = gmem_thr_copy_A.retile_D(tCrA);
+    auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB);
+
+  #if CUTLASS_ENABLE_DEBUG_PRINTS
+    if (thread(LOG_THREAD, LOG_GROUP)) {
+      print("======================= A: \n");
+      print("  gA : "); print(gA); print("\n");
+      print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n");
+      print("  tCrA : "); print(tCrA); print("\n");
+
+      print("===================== B :\n");
+      print("  gB : "); print(gB); print("\n");
+      print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n");
+      print("  tCrB : "); print(tCrB); print("\n");
+
+      print("===================== Config: \n");
+      print("  threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
+      print("  SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
+
+      print("  PrefetchQThrShape : "); print(PrefetchQThrShape{}); print("\n");
+      print("  PrefetchKThrShape : "); print(PrefetchKThrShape{}); print("\n");
+      print("  PrefetchQTileSize : "); print(PrefetchQTileSize{}); print("\n");
+      print("  PrefetchKTileSize : "); print(PrefetchKTileSize{}); print("\n");
+    }
+  #endif
+
+    //
+    // Mainloop
+    //
+    int sub_group_id = get_sub_group_id();
+    auto [m_coord, n_coord, k_coord, l_coord] = tile_coord;
+    Tensor iter_a = params.gmem_tiled_copy_q.get_pvc_tensor(
+      make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
+      append<3>(typename XE_Copy_Q::Shape_MN{}, BLK_K), seq<0,1,1>{});
+    Tensor iter_b = params.gmem_tiled_copy_k.get_pvc_tensor(
+      make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count),
+      append<3>(typename XE_Copy_K::Shape_MN{}, BLK_K), seq<0,1,0>{});
+
+    Tensor prefetch_iter_a = params.gmem_prefetch_q.get_pvc_tensor(
+      make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchQThrShape{}) * get<0>(PrefetchQTileSize{}),
+                 (sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{}) * get<1>(PrefetchQTileSize{}), l_coord),
+      append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
+      append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{});
+    Tensor prefetch_iter_b = params.gmem_prefetch_k.get_pvc_tensor(
+      make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchKThrShape{}) * get<0>(PrefetchKTileSize{}),
+                 n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchKThrShape{}) * get<1>(PrefetchKTileSize{}), l_coord),
+      append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
+      append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});
+
+CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < DispatchPolicy::Stages; i++) {
+      if constexpr(cute::detail::has_prefetch<GmemTiledCopyQ>) {
+        prefetch(params.gmem_tiled_copy_q, prefetch_iter_a(_,_,_,i));
+      }
+      if constexpr(cute::detail::has_prefetch<GmemTiledCopyK>) {
+        prefetch(params.gmem_tiled_copy_k, prefetch_iter_b(_,_,_,i));
+      }
+    }
+CUTLASS_PRAGMA_UNROLL
+    for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
+      // Copy gmem to rmem for the first k_tile
+      copy(params.gmem_tiled_copy_q, iter_a(_,_,_,k_tile), tCrA_copy_view);
+      copy(params.gmem_tiled_copy_k, iter_b(_,_,_,k_tile), tCrB_copy_view);
+
+      if(k_tile + DispatchPolicy::Stages < k_tile_count) {
+        if constexpr(cute::detail::has_prefetch<GmemTiledCopyQ>) {
+          prefetch(params.gmem_tiled_copy_q, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages));
+        }
+        if constexpr(cute::detail::has_prefetch<GmemTiledCopyK>) {
+          prefetch(params.gmem_tiled_copy_k, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages));
+        }
+      }
+      for (int i = 0; i < SG_K / SubgroupSize; i++) {
+        cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), frag_src);
+      }
+    }
+  }
+
+  // Computes the O += P * V update for one head-size block; tPr holds the P fragment
+  // produced by the online softmax applied to the S tile from mmaQK.
+  template <
+  class FragAccum,
+  class FragP,
+  class TensorV,
+  class FragSrc
+  >
+  CUTLASS_DEVICE void
+  mmaPV(
+    FragAccum& accum,
+    FragP const &tPr,
+    TensorV gB,
+    FragSrc const &frag_src,
+    int const &k_tile_count,
+    int const &head_size_coord,
+    Params const &params) {
+
+    int thread_idx = static_cast<int>(ThreadIdxX());
+    // Instantiate the MMA object
+    TiledMma tiled_mma;
+    auto thread_mma = tiled_mma.get_slice(thread_idx);
+
+    Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, head_size_coord));
+    Tensor tCrB = make_tensor(static_cast<decltype(tCrB_partition) &&>(tCrB_partition).data(),
+                              make_shape(size<0>(tCrB_partition.shape()),
+                                         size<2>(tCrB_partition.shape()),
+                                         size<1>(tCrB_partition.shape())));
+    // Partition the copying of A and B tiles across the threads
+    auto gmem_thr_copy_B = params.gmem_tiled_copy_k.get_slice(thread_idx);
+
+    auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB);
+
+  #if CUTLASS_ENABLE_DEBUG_PRINTS
+    if (thread(LOG_THREAD, LOG_GROUP)) {
+      print("===================== B :\n");
+      print("  gB : "); print(gB); print("\n");
+      print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n");
+      print("  tCrB : "); print(tCrB); print("\n");
+
+      print("===================== Config: \n");
+      print("  threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
+      print("  SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");
+
+      print("  PrefetchVThrShape : "); print(PrefetchVThrShape{}); print("\n");
+      print("  PrefetchVTileSize : "); print(PrefetchVTileSize{}); print("\n");
+    }
+  #endif
+
+    //
+    // Mainloop
+    //
+    int sub_group_id = get_sub_group_id();
+    const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
+    const int l_coord = BlockIdxZ();
+
+    Tensor iter_b = params.gmem_tiled_copy_v.get_pvc_tensor(
+      make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count),
+      append<3>(typename XE_Copy_K::Shape_MN{}, BLK_K), seq<0,1,0>{});
+
+    Tensor prefetch_iter_b = params.gmem_prefetch_v.get_pvc_tensor(
+      make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchVThrShape{}) * get<0>(PrefetchVTileSize{}),
+                 n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchVThrShape{}) * get<1>(PrefetchVTileSize{}), l_coord),
+      append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
+      append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});
+
+    prefetch(params.gmem_tiled_copy_v, prefetch_iter_b(_,_,_,0));
+
+    copy(params.gmem_tiled_copy_v, iter_b(_,_,_, 0), tCrB_copy_view);
+
+    for (int i = 0; i < SG_K / SubgroupSize; i++) {
+      cute::gemm(tiled_mma, accum, tPr(_, _, i), tCrB(_, i, _), frag_src);
+    }
+
+  }
+
+};
+
+} // namespace cutlass::gemm::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
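
Reviewer note (outside the diff): mmaQK and mmaPV above only produce the raw S = Q*K^T tiles and the P*V accumulation; the rescaling between the two steps is supplied by the Softmax functor in online_softmax.hpp, which folds the softmax scale into base-2 exponentials. The standalone sketch below illustrates the same online-softmax update with plain exp() so the arithmetic is easy to follow. Every name in it is illustrative; none of it appears in the patch or in the CUTLASS API.

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Illustrative only: per-row online-softmax update as used by flash attention.
    // For each new block of scores s[0..n), fold it into the running statistics
    // (row_max, row_sum) and rescale the accumulator acc[0..d) so that
    // acc / row_sum equals softmax(scores seen so far) * V at any point.
    void online_softmax_update(const std::vector<float>& s,                // scores of the new block (one row)
                               const std::vector<std::vector<float>>& v,   // V rows for this block, each of width d
                               float& row_max, float& row_sum,
                               std::vector<float>& acc) {
      float new_max = row_max;
      for (float x : s) new_max = std::fmax(new_max, x);

      // Rescale everything accumulated so far to the new running maximum.
      float correction = std::exp(row_max - new_max);
      for (float& a : acc) a *= correction;
      row_sum *= correction;

      // Accumulate the new block with the shared maximum subtracted.
      for (std::size_t j = 0; j < s.size(); ++j) {
        float p = std::exp(s[j] - new_max);
        row_sum += p;
        for (std::size_t k = 0; k < acc.size(); ++k) acc[k] += p * v[j][k];
      }
      row_max = new_max;
    }

    int main() {
      // Two K/V positions processed as two blocks of one; head size d = 2, V = identity.
      std::vector<std::vector<float>> V = {{1.f, 0.f}, {0.f, 1.f}};
      float row_max = -INFINITY, row_sum = 0.f;
      std::vector<float> acc(2, 0.f);
      online_softmax_update({0.5f}, {V[0]}, row_max, row_sum, acc);
      online_softmax_update({1.5f}, {V[1]}, row_max, row_sum, acc);
      // Prints softmax([0.5, 1.5]) * V = (0.268941, 0.731059).
      std::printf("out = (%f, %f)\n", acc[0] / row_sum, acc[1] / row_sum);
      return 0;
    }

Processing the sequence block by block this way is what lets the kernel keep only one Q tile, one running (max, sum) pair, and one output accumulator in registers instead of materializing the full attention matrix.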