diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp index 93c662a288..e5ded0ef3b 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp @@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, else if(t.permute.compare("0,1,3,4,2,5") == 0) { constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + matrix_core_permute_style::b_nr_kr_kw_nw_kv; using Kernel = matrix_core_swizzle_kernel; @@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t, else if(t.permute.compare("0,1,3,4,2,5") == 0) { constexpr matrix_core_permute_style pstyle = - matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv; + matrix_core_permute_style::b_nr_kr_kw_nw_kv; using Kernel = matrix_core_swizzle_kernel; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 60ac103ec3..28f4c452bc 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -42,8 +42,8 @@ enum class matrix_core_permute_style { permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 - permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 - permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, }; // assume this is B matrix, originally we have batch*n*k @@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel else { // clang-format off - // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten + // b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten constexpr index_t Kv = Alignment; constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; @@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel make_tuple(sequence<0>{}, sequence<1>{})); return tmp_1; #else - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv, + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, constexpr index_t kv = Alignment; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; @@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel else { #if MERGE_2D_013425 - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv return make_tile_window(dst_view, make_tuple(number{}, number{}), {i_n * NPerBlock, i_k * KPerBlock}, get_dst_dist()); #else - // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv + // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv constexpr index_t kv = Alignment; constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; diff --git a/example/ck_tile/06_permute/permute.cpp b/example/ck_tile/06_permute/permute.cpp index af95b64e69..477ae370b9 100644 --- a/example/ck_tile/06_permute/permute.cpp +++ b/example/ck_tile/06_permute/permute.cpp @@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) { - // permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 + // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 matrix_core_swizzle_traits t; t.data_type = 
data_type; t.permute = arg_parser.get_str("perm"); diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 91b54932ce..0cb393f7de 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -5,7 +5,7 @@ #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -#include "ck_tile/ops/moe_sorting.hpp" +#include "ck_tile/ops/fused_moe.hpp" struct moe_sorting_trait { diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt new file mode 100644 index 0000000000..a716eef19e --- /dev/null +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -0,0 +1,19 @@ +set(TILE_EXAMPLE_FUSED_MOE "tile_example_fused_moe") +# not using add_example_executable() to add this target, since we don't want it +# to be included in "make all/install/check" +message("adding ${TILE_EXAMPLE_FUSED_MOE}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_EXAMPLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) +target_include_directories(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let the source compile without explicitly declared function specializations +list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a +list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta +# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) +# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +target_compile_options(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md new file mode 100644 index 0000000000..dd566c1667 --- /dev/null +++ b/example/ck_tile/15_fused_moe/README.md @@ -0,0 +1,69 @@ +# fused-moe +Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance. +![](misc/moe-0.png) + +The benefits of this fused-moe: +* 1.5~2x perf boost compared with the current vllm solution +* zero workspace to reduce memory footprint +* far fewer kernel instances, easy to maintain + +# Implementation and feature support +## moe-sorting +This is a common pre-processing step before the actual moe-gemm. The purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). Besides, we extend this op to zero the output buffer (used as the reduction buffer for the atomics). + +## moe-gemm +`moe-gemm` is a group-gemm based back-to-back gemm, where the row-ids of the input tokens come from another buffer.
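+Concretely, the expert-by-expert loop after `moe-sorting` can be sketched as below (host-side pseudo-code for illustration only, not the actual kernel; buffer names follow the indexing note at the end of this README):
+```
+// one M_a-row tile per workgroup; every row in a tile belongs to the same expert
+for(int tile = 0; tile < num_sorted_tiles; tile++) {
+    int e = sorted_expert_ids[tile];                  // which expert (B matrix) this tile uses
+    for(int i = 0; i < M_a; i++) {
+        int   row = sorted_token_ids[tile * M_a + i]; // gathered row-id of the input token
+        float w   = sorted_weight[tile * M_a + i];    // matching topk weight
+        if(row is a padding slot) continue;
+        // y      = act(A[row] @ G[e])                // 1st gemm + activation
+        // O[row] += w * (y @ D[e])                   // 2nd gemm, accumulated with atomics
+    }
+}
+```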
A naive understanding of fused-moe takes a token-by-token view, as in the picture below: +![](misc/moe-1.png) +After `moe-sorting`, we can view this algorithm expert-by-expert, as below: +![](misc/moe-2.png) + +## optimization +Summary of the key designs of this fused-moe operator: +* fuse 2 group-gemms + activation + `topk-weight` multiply into a single kernel, using atomics for the 2nd gemm accumulation +* fuse buffer-zeroing into `moe-sorting`, so users no longer need to call an extra torch.zero() on the output buffer +* fused scatter-gather for row indices (same as vllm) +* pre-shuffle the B matrix (weight) to maximize memory throughput. The input (activation) keeps its original layout `[batch, hidden]`. +* extremely optimized pipeline using block inline-asm (we call it a `micro-kernel` or `uk`), without breaking the *composable* design of ck + +## indexing +``` +// [indexing implementation-1] +// using M_a as constexpr block_size to partition all tokens into different slices +// each slice maps to one expert, and one expert can have multiple slices +// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5 +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers) +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +// +// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// * this could be larger than the actual size, since the actual token count is only known on the GPU +// +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] +// * length is (max_num_tokens_padded + block_size - 1) / block_size +// +// num_tokens_post_padded_ptr : [28] +// num_sorted_tiles_ptr : [7] +// +// * different from vLLM +// 1) token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id +// 2) need sorted_weight_ptr +// 3) use num_sorted_tiles_ptr, already divided by M_a +// +// * below used for indexing +// 1) sorted_token_ids_ptr [max_num_tokens_padded] +// 2) sorted_weight_ptr +// 3) sorted_expert_ids_ptr +// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) +// +// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1) +``` \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp new file mode 100644 index 0000000000..6bd7688d8a --- /dev/null +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "fused_moesorting.hpp" +#include "fused_moegemm.hpp" + +struct fused_moe_args +{ + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + void* o_ptr; // [m, k], output token (no need to do zeroing) + + const void* topk_ids_ptr; // [tokens, topk] + const void* topk_weight_ptr; // [tokens, topk] + void* sorted_token_ids_ptr; // [max_num_tokens_padded] + void* sorted_weight_ptr; // [max_num_tokens_padded] + void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size] + void* num_sorted_tiles_ptr; // [1] + + ck_tile::index_t block_m; // block_m, used to divide the input + ck_tile::index_t hidden_size; // k + ck_tile::index_t intermediate_size; // n / TP, for Gate; if Gate+Up, Down needs to divide by 2 + ck_tile::index_t num_tokens; // input number of tokens for the current iteration + ck_tile::index_t num_experts; // number of groups + ck_tile::index_t topk; // need this? + + ck_tile::index_t stride_token; // for input/output, stride for each row, should be >= hidden_size +}; + +// This is the public API, will be generated by a script +struct fused_moe_traits +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int gate_only; + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moegemm.hpp b/example/ck_tile/15_fused_moe/fused_moegemm.hpp new file mode 100644 index 0000000000..b8e51475ad --- /dev/null +++ b/example/ck_tile/15_fused_moe/fused_moegemm.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fused_moe.hpp" +#include + +// this is only a convenient structure for creating an example +// this is not part of the host API +template +struct FusedMoeGemmTypeConfig; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using GDataType = ck_tile::bf16_t; + using DDataType = ck_tile::bf16_t; + using AccDataType = float; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::fp16_t; + using GDataType = ck_tile::fp16_t; + using DDataType = ck_tile::fp16_t; + using AccDataType = float; + using ODataType = ck_tile::fp16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FusedMoeGemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using GDataType = ck_tile::int8_t; + using DDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +// runtime args +struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs +{ +}; + +// This is the public API, will be generated by script +struct fused_moegemm_traits +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int gate_only; + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp new file mode 100644 index 0000000000..57dace9b41 --- /dev/null +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +struct fused_moesorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float +}; + +struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp new file mode 100644 index 0000000000..bfc0ce4096 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "fused_moe.hpp" + +float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) +{ + auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; + + auto o_data_bytes = [&]() { + if(t.prec_o == "fp32") + return 4; + else if(t.prec_o == "fp16" || t.prec_o == "bf16") + return 2; + else if(t.prec_o == "int8" || t.prec_o == "fp8") + return 1; + return 1; + }(); + + auto t0 = fused_moesorting_trait{"int32", "fp32"}; + auto a0 = fused_moesorting_args{ + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes; + }; + + auto t1 = fused_moegemm_traits{t.prec_i, + t.prec_w, + t.prec_o, + t.prec_st, + t.prec_sw, + t.prec_sq, + t.prec_kw, + t.block_m, + t.gate_only, + t.fused_quant}; + auto a1 = fused_moegemm_args{ + a.a_ptr, // const void* a_ptr; + a.a_scale_ptr, // const void* a_scale_ptr; + a.g_ptr, // const void* g_ptr; + a.d_ptr, // const void* d_ptr; + a.g_scale_ptr, // const void* g_scale_ptr; + a.d_scale_ptr, // const void* d_scale_ptr; + a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr; + a.o_ptr, // void* o_ptr; + a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr; + a.sorted_weight_ptr, // const void* sorted_weight_ptr; + a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr; + a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr; + a.hidden_size, // index_t hidden_size; + a.intermediate_size, // index_t intermediate_size; + a.num_tokens, // index_t num_tokens; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.stride_token // index_t stride_token; + }; + + float r0 = -1; + float r1 = -1; + + float r = ck_tile::launch_kernel( + s, + [=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); }, + [=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); }); + + // keep unsupported case return negative + if(r0 < 0 || r1 < 0) + return -1; + + return r; +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp new file mode 100644 index 0000000000..c1a4c495c3 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" + +// Note: this internal API is only declared here, not defined, otherwise it will block `make -j` +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a); + +template +using S = ck_tile::sequence; + +float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s) +{ + // clang-format off + float r = -1; + if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) + { + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + r = fused_moegemm_(s, a); + } + else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && + t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) + { + using t_ = fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; + r = fused_moegemm_(s, a); + } + // clang-format on + return r; +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp new file mode 100644 index 0000000000..5872179ef7 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "fused_moegemm_api_traits.hpp" +#include "ck_tile/ops/fused_moe.hpp" +#include + +template +using S = ck_tile::sequence; + +// do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j +template +float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) +{ + using f_traits = ck_tile::FusedMoeGemmTraits; + using f_shape = ck_tile::FusedMoeGemmShape; + using f_problem = + ck_tile::FusedMoeGemmPipelineProblem; + + // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; + using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk; + using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear; + using f_kernel = ck_tile::FusedMoeGemmKernel; + + const dim3 grids = f_kernel::GridSize(a); + constexpr dim3 blocks = f_kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + static int printed = 0; + + auto kargs = f_kernel::MakeKargs(a); + if(s.log_level_ > 0 && printed == 0) + { + std::cout << ", " << f_kernel::GetName() << std::flush; + printed = 1; + } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp new file mode 100644 index 0000000000..cc476685de --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include + +// this is used to pattern-match the internal kernel implementation, not to instantiate kernels +template + typename WarpPerBlock_, + typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t GateOnly_ = 0, + ck_tile::index_t FusedQuant_ = 0> +struct fmoe_ // traits, ugly name, only used internally +{ + using TypeConfig = FusedMoeGemmTypeConfig; + + using ADataType = ck_tile::remove_cvref_t; + using GDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token + static constexpr ck_tile::index_t BI_ = + BlockTIle_::at(ck_tile::number<1>{}); // block intermediate + static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden + static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down + + using BlockTile_0 = ck_tile::sequence; + using WarpPerBlock_0 = ck_tile::remove_cvref_t; + using WarpTile_0 = ck_tile::remove_cvref_t; + + using BlockTile_1 = ck_tile::sequence; + using WarpPerBlock_1 = ck_tile::remove_cvref_t; + using WarpTile_1 = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t GateOnly = GateOnly_; + static constexpr ck_tile::index_t FusedQuant = FusedQuant_; +}; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp new file mode 100644 index 0000000000..93f9c77869 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" +#include "fused_moegemm_api_internal.hpp" + +// clang-format off +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +// clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp new file mode 100644 index 0000000000..b8a823e8ed --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "fused_moegemm.hpp" +#include "fused_moegemm_api_traits.hpp" +#include "fused_moegemm_api_internal.hpp" + +// clang-format off +template float fused_moegemm_< + fmoe_, S<1, 4, 1>, S<16, 16, 32>, 1, 0> +>(const ck_tile::stream_config& s, fused_moegemm_args a); + +// clang-format on diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp new file mode 100644 index 0000000000..75aaf86b74 --- /dev/null +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#include "fused_moesorting.hpp" + +#define MOE_SORTING_DISPATCH(unroll_num_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + using ms_problem = ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + if(a.num_experts > 127) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(7): { + MOE_SORTING_DISPATCH(7); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(9): { + MOE_SORTING_DISPATCH(9); + } + case(10): { + MOE_SORTING_DISPATCH(10); + } + case(11): { + MOE_SORTING_DISPATCH(11); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } + } + return -1; +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp new file mode 100644 index 0000000000..2f44f903e9 --- /dev/null +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -0,0 +1,603 @@ +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "fused_moe.hpp" + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +// mfma_type, 0:32x32, 1:16x16 +// TODO: padding? 
+template +auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 3); + int b_ = t.get_lengths()[0]; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[2]; + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + return t; +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("t", "128", "num input tokens") + .insert("e", "32", "num of experts") + .insert("k", "5", "topk") + .insert("h", "8192", "hidden_size of this model") + .insert("i", "8192", "intermediate_size between 2 gemms of FFN") + .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") + .insert("bm", "32", "blocking factor for sorted tokens") + .insert("tp", "8", "tensor parallel size") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "bf16", "input precision") + .insert("prec_w", "bf16", "weight precision") + .insert("prec_o", "bf16", "output precision") + .insert("prec_st", "auto", "token scale data type. auto will set to fp32") + .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") + .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") + .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert( + "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") + .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") + .insert("balance", + "0", + "if set to 1, will try balance the expert in topk-ids(convenient for testing)") + .insert("init", + "2", + "init method. 0:random stepped float(fast). 
1: random uniform, 2: rand normalized (slow)") + .insert("seed", "11939", "seed used for randomization") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type, +// SQ:smooth-quant-type, KW:topk-weight-type +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t tokens = arg_parser.get_int("t"); + ck_tile::index_t experts = arg_parser.get_int("e"); + ck_tile::index_t topk = arg_parser.get_int("k"); + ck_tile::index_t hidden_size = arg_parser.get_int("h"); + ck_tile::index_t intermediate_size = arg_parser.get_int("i"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + ck_tile::index_t block_m = arg_parser.get_int("bm"); + if(stride < 0) + stride = hidden_size; + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int api = arg_parser.get_int("api"); + int balance = arg_parser.get_int("balance"); + int tp = arg_parser.get_int("tp"); + int init = arg_parser.get_int("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + // w0 (Gate+Up or Gate only, N size) + ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ?
1 : 2) / tp; + // w1 (Down, N size) + ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp; + + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_w) + base_str += "x" + prec_w; + if(prec_i != prec_o) + base_str += "=" + prec_o; + if(fused_quant != 0) + { + base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")"; + } + return base_str; + }(); + auto api_str = [&]() { + if(api == 0) + return std::string("fmoe"); + else if(api == 1) + return std::string("moeg"); + else if(api == 2) + return std::string("moes"); + return std::string(""); + }(); + + auto stride_str = [&]() { + if(stride == hidden_size) + return std::string(""); + else + return std::string(", st:") + std::to_string(stride); + }(); + + std::cout << "[" << api_str << "|" << prec_str << "]" + << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str + << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp + << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 + << ", go:" << gate_only << ", q:" << fused_quant << std::flush; + + using TypeConfig = FusedMoeGemmTypeConfig; + using ADataType = typename TypeConfig::ADataType; + using GDataType = typename TypeConfig::GDataType; + using DDataType = typename TypeConfig::DDataType; + using AccDataType = typename TypeConfig::AccDataType; + using ODataType = typename TypeConfig::ODataType; + using AScaleDataType = typename TypeConfig::AScaleDataType; + using GScaleDataType = typename TypeConfig::GScaleDataType; + using DScaleDataType = typename TypeConfig::DScaleDataType; + using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType; + using TopkWeightDataType = typename TypeConfig::TopkWeightDataType; + using IndexDataType = typename TypeConfig::IndexDataType; + + // host verify + ck_tile::HostTensor a_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor g_host({experts, shared_intermediate_size_0, hidden_size}); + ck_tile::HostTensor d_host({experts, hidden_size, shared_intermediate_size_1}); + ck_tile::HostTensor o_host({tokens, hidden_size}, {stride, 1}); + ck_tile::HostTensor sa_host({tokens}); + ck_tile::HostTensor sg_host({shared_intermediate_size_0}); + ck_tile::HostTensor sd_host({shared_intermediate_size_1}); + ck_tile::HostTensor sy_host({shared_intermediate_size_1}); // smooth-quant + ck_tile::HostTensor topk_ids_host({tokens, topk}); // to be sort + ck_tile::HostTensor topk_weight_host({tokens, topk}); // to be sort + + int max_num_tokens_padded = topk * tokens + experts * block_m - topk; + ck_tile::HostTensor sorted_token_ids_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_weight_host({max_num_tokens_padded}); + ck_tile::HostTensor sorted_expert_ids_host( + {(max_num_tokens_padded + block_m - 1) / block_m}); + ck_tile::HostTensor num_sorted_tiles_host({1}); + + if(init == 0) + { + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(a_host); + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(g_host); + ck_tile::FillStepRange{.5f, -.5f, -0.01f}(d_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sa_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sg_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sd_host); + ck_tile::FillStepRange{0.f, 1.f, 0.01f}(sy_host); + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(topk_weight_host); + } + else if(init == 1) + { + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(g_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, 
true}(d_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sa_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sg_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sd_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(sy_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}( + topk_weight_host); + } + else if(init == 2) + { + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(g_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(d_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sa_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sg_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sd_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(sy_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(topk_weight_host); + } + + // permute weight + ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + + // do moe sorting + if(balance) + { + int e_cnt = 0; + for(int i = 0; i < static_cast(topk_ids_host.mData.size()); i++) + { + topk_ids_host.mData[i] = e_cnt; + e_cnt++; + if(e_cnt >= experts) + e_cnt = 0; + } + } + else + { + topid_unique_gen(topk_ids_host.mData, tokens, topk, experts, 11913); + } + +// leave it here for future debug purpose +#if 0 + a_host.loadtxt("../../ater/input_torch.txt"); + + topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int"); + // topk_ids_host.savetxt("topk_ids_2.txt"); + topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + + g_host.loadtxt("../../ater/w1_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + d_host.loadtxt("../../ater/w2_torch.txt", "float"); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + + ck_tile::HostTensor g_perm_host = shuffle_moe_weight(g_host, prec_w, 1); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; + ck_tile::HostTensor d_perm_host = shuffle_moe_weight(d_host, prec_w, 1); + std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl; +#endif + +#if 0 + std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl; + std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl; + std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl; + std::cout << "topk_weight_host:" << topk_weight_host << std::endl; + std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl; +#endif + auto cal_tflops = [&](auto ms) { + double flop_gemm_0 = + 2 * static_cast(tokens) * topk * shared_intermediate_size_0 * hidden_size; + double flop_gemm_1 = + 2 * static_cast(tokens) * topk * shared_intermediate_size_1 * hidden_size; + return (flop_gemm_0 + flop_gemm_1) / (static_cast(ms) * 1e-3) / 1e12; + }; + + // TODO: this method we use expert-by-expert view, just for reference + auto cal_tbps = [&](auto ms) { + double token_bytes = + static_cast(tokens) * topk / experts * hidden_size * sizeof(ADataType); + double w0_bytes = static_cast(shared_intermediate_size_0) * experts * hidden_size * + sizeof(GDataType); + double w1_bytes = static_cast(shared_intermediate_size_1) * experts * hidden_size * + sizeof(DDataType); + double o_bytes = + static_cast(tokens) * topk / experts * hidden_size * sizeof(ODataType); + double 
topk_weights_bytes = static_cast(tokens) * topk * sizeof(TopkWeightDataType); + // ignore index, they are too small + + return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) / + (static_cast(ms) * 1e-3) / 1e12; + }; + + if(api == 0) + { + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem g_perm_buf(g_perm_host); + ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem sa_buf(sa_host); + ck_tile::DeviceMem sg_buf(sg_host); + ck_tile::DeviceMem sd_buf(sd_host); + ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem topk_ids_buf(topk_ids_host); + ck_tile::DeviceMem topk_weight_buf(topk_weight_host); + + ck_tile::DeviceMem sorted_token_ids_buf( + sorted_token_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_buf( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem num_sorted_tiles_buf( + num_sorted_tiles_host.get_element_space_size_in_bytes()); + + fused_moe_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + gate_only, + fused_quant}; + + fused_moe_args args{a_buf.GetDeviceBuffer(), + fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, + g_perm_buf.GetDeviceBuffer(), + d_perm_buf.GetDeviceBuffer(), + fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, + o_buf.GetDeviceBuffer(), + topk_ids_buf.GetDeviceBuffer(), + topk_weight_buf.GetDeviceBuffer(), + sorted_token_ids_buf.GetDeviceBuffer(), + sorted_weight_buf.GetDeviceBuffer(), + sorted_expert_ids_buf.GetDeviceBuffer(), + num_sorted_tiles_buf.GetDeviceBuffer(), + block_m, + hidden_size, + shared_intermediate_size_0, + tokens, + experts, + topk, + stride}; + float ave_time = fused_moe( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + // float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " + << cal_tbps(ave_time) << " TB/s" << std::flush; + bool pass = true; + + if(do_validation) + { + ck_tile::reference_moe_sorting( + topk_ids_host, + topk_weight_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host.mData[0], + experts, + block_m); + + ck_tile::reference_fused_moe( + a_host, + g_host, + d_host, + sa_host, + sg_host, + sd_host, + sy_host, + o_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host, + topk_ids_host, + block_m, + tokens, + experts, + hidden_size, + shared_intermediate_size_0, + topk, + gate_only); + + auto o_dev = o_buf.ToHost(); + // o_dev.savetxt("gpu-out.txt", "float"); + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err( + o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush; + } + std::cout << std::flush << std::endl; + return pass; + } + else if(api == 1) + { + ck_tile::reference_moe_sorting( + topk_ids_host, + topk_weight_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host.mData[0], + experts, + block_m); + + // done, preparing GPU buffer + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem g_perm_buf(g_perm_host); + ck_tile::DeviceMem d_perm_buf(d_perm_host); + ck_tile::DeviceMem sa_buf(sa_host); + ck_tile::DeviceMem sg_buf(sg_host); + ck_tile::DeviceMem sd_buf(sd_host); + ck_tile::DeviceMem sy_buf(sy_host); + ck_tile::DeviceMem o_buf(o_host); + + // manually clear output buffer for atomic + o_buf.SetZero(); + // + + ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host); + ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host); + ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host); + ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host); + + fused_moegemm_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + gate_only, + fused_quant}; + + fused_moegemm_args args{a_buf.GetDeviceBuffer(), + fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr, + g_perm_buf.GetDeviceBuffer(), + d_perm_buf.GetDeviceBuffer(), + fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr, + o_buf.GetDeviceBuffer(), + sorted_token_ids_buf.GetDeviceBuffer(), + sorted_weight_buf.GetDeviceBuffer(), + sorted_expert_ids_buf.GetDeviceBuffer(), + num_sorted_tiles_buf.GetDeviceBuffer(), + hidden_size, + shared_intermediate_size_0, + tokens, + experts, + topk, + stride}; + + float ave_time = fused_moegemm( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + // float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, " + << cal_tbps(ave_time) << " TB/s" << std::flush; + bool pass = true; + + if(do_validation) + { + ck_tile::reference_fused_moe( + a_host, + g_host, + d_host, + sa_host, + sg_host, + sd_host, + sy_host, + o_host, + sorted_token_ids_host, + sorted_weight_host, + sorted_expert_ids_host, + num_sorted_tiles_host, + topk_ids_host, + block_m, + tokens, + experts, + hidden_size, + shared_intermediate_size_0, + topk, + gate_only); + + auto o_dev = o_buf.ToHost(); + // o_dev.savetxt("gpu-out.txt", "float"); + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err( + o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol); + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + } + std::cout << std::flush << std::endl; + + return pass; + } + return false; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? 
"fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + + // no dynamic quant case + if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 0 + : -2; + } + else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 0 + : -2; + } + + return -3; +} diff --git a/example/ck_tile/15_fused_moe/misc/moe-0.png b/example/ck_tile/15_fused_moe/misc/moe-0.png new file mode 100644 index 0000000000..aed1964f28 Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-0.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-1.png b/example/ck_tile/15_fused_moe/misc/moe-1.png new file mode 100644 index 0000000000..91a1f2d9dd Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-1.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-2.png b/example/ck_tile/15_fused_moe/misc/moe-2.png new file mode 100644 index 0000000000..98d83866fa Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-2.png differ diff --git a/example/ck_tile/15_fused_moe/misc/moe-3.png b/example/ck_tile/15_fused_moe/misc/moe-3.png new file mode 100644 index 0000000000..77c6d9b6e4 Binary files /dev/null and b/example/ck_tile/15_fused_moe/misc/moe-3.png differ diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index b6a44f76b7..29305405bc 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -14,3 +14,5 @@ add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) add_subdirectory(13_moe_sorting) add_subdirectory(14_moe_smoothquant) +add_subdirectory(15_fused_moe) + diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 3b198502d0..3cf0c2595d 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -52,6 +52,7 @@ #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" +#include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" @@ -62,6 +63,7 @@ #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/reduce_operator.hpp" +#include "ck_tile/core/utility/static_counter.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 3feede4d2e..bebf035e9c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); +} + +template +struct buffer_atomic_add_if; + +template +struct buffer_atomic_add_if +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + 
using mbuf_t = float; + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n" + "s_mov_b64 exec %5" + : + : "v"(v_offset), + "v"(bit_cast(value)), + "s"(res.xy), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template +struct buffer_atomic_add; + +template +struct buffer_atomic_add +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t /*s_offset*/, + index_t i_offset /*max 0xFFF*/, + index_t /*flag = 1*/) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3" + : + : "v"(v_offset), "v"(bit_cast(value)), "s"(res.xy), "n"(i_offset) + : "memory"); + } +}; + namespace impl { // below type indicate the data type used for buffer load inline asm // clang-format off @@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + // buffer load i8 CK_TILE_DEVICE_EXTERN int8_t llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, @@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer& src_thread_ #endif } +template +CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const index_t dst_linear_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size, + bool_constant = {}) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T); + + if constexpr(oob_conditional_check) + { + buffer_atomic_add_if{}(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_linear_addr_offset, + dst_thread_element_valid); + } + else + { + buffer_atomic_add{}(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_linear_addr_offset, + 1); + } +} + // buffer_atomic_max requires: // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. 
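+// Editorial note: a minimal usage sketch of amd_buffer_atomic_add_raw above (illustrative
+// only; `p_out`, `v_off`, `lin_off`, `valid` and `space` are hypothetical caller-side values,
+// and the exact template arguments depend on the surrounding tile code):
+//
+//   ck_tile::thread_buffer<ck_tile::bf16_t, 2> acc = /* one packed bf16x2 partial result */;
+//   amd_buffer_atomic_add_raw<ck_tile::bf16_t, 2>(
+//       acc, p_out, v_off, lin_off, valid, space); // atomic packed-bf16 add into global mem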
diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 65a3a4e2ff..afcf982a63 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds() #endif } +CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) +{ +#ifdef __gfx12__ + asm volatile("s_wait_loadcnt %0 \n" + "s_barrier_signal -1 \n" + "s_barrier_wait -1" + : + : "n"(cnt) + : "memory"); +#else + asm volatile("s_waitcnt vmcnt(%0) \n" + "s_barrier" + : + : "n"(cnt) + : "memory"); +#endif +} + CK_TILE_DEVICE void block_sync_lds_direct_load() { asm volatile("\ diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index a88780459b..df0f54c5ed 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane) #endif } +template +CK_TILE_DEVICE auto flag_to_exec(const T& v_flag) +{ + static_assert(sizeof(T) == 4); + // per-thread v_flag store into 2x sgpr + uint32x2_t exec_flag; + asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1" + : [s_exec_flag] "=s"(exec_flag) + : [v_flag] "v"(v_flag)); + return exec_flag; +} + +template +CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y) +{ + static_assert(sizeof(X) == 4 && sizeof(Y) == 4); + // per-thread cmp store into 2x sgpr + uint32x2_t exec_flag; + asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]" + : [s_exec_flag] "=s"(exec_flag) + : [v_x] "v"(x), [v_y] "v"(y)); + return exec_flag; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 2cc788d422..7dffa0e555 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -437,34 +437,74 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x) + CK_TILE_DEVICE void update(index_t i, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}) { if constexpr(Op == memory_operation_enum::set) { - this->template set(i, linear_offset, is_valid_element, x); + this->template set(i, linear_offset, is_valid_element, x); } else if constexpr(Op == memory_operation_enum::atomic_add) { - this->template atomic_add(i, linear_offset, is_valid_element, x); + this->template atomic_add( + i, linear_offset, is_valid_element, x); } else if constexpr(Op == memory_operation_enum::atomic_max) { - this->template atomic_max(i, linear_offset, is_valid_element, x); + this->template atomic_max( + i, linear_offset, is_valid_element, x); } // FIXME: remove memory_operation_enum::add else if constexpr(Op == memory_operation_enum::add) { - auto tmp = this->template get(i, linear_offset, is_valid_element); - this->template set(i, linear_offset, is_valid_element, x + tmp); + auto tmp = + this->template get(i, linear_offset, is_valid_element); + this->template set( + i, linear_offset, is_valid_element, x + tmp); // tmp += x; // this->template set(i, is_valid_element, tmp); } } + // i is offset of T, not X. 
i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void update_raw(index_t i, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}, + bool_constant = {}) + { + if constexpr(Op == memory_operation_enum::set) + { + this->template set_raw(i, linear_offset, is_valid_element, x); + } + else if constexpr(Op == memory_operation_enum::atomic_add) + { + this->template atomic_add_raw( + i, linear_offset, is_valid_element, x); + } + else if constexpr(Op == memory_operation_enum::atomic_max) + { + // this->template atomic_max_raw(i, linear_offset, is_valid_element, x); + } + } + // i is offset of T, not X. i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, @@ -585,6 +626,39 @@ struct buffer_view>::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void + atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x) + { + // using scalar_t = typename vector_traits>::scalar_type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(get_address_space() == address_space_enum::global, "only support global mem"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add_raw, + t_per_x, + Coherence, + oob_conditional_check, + pre_nop>( + x, p_data_, i, linear_offset, is_valid_element, buffer_size_); + } + + template >::scalar_type, typename vector_traits>::scalar_type>::value, diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index f150fc54ca..b280a1725d 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -22,28 +22,32 @@ template CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, + number = {}, bool_constant = {}) { - return tile_window.load(number<-1>{}, bool_constant{}); + return tile_window.load(number{}, bool_constant{}); } template CK_TILE_DEVICE auto load_tile(const tile_window_linear& tile_window, + number = {}, bool_constant = {}) { - return tile_window.load(number<-1>{}, bool_constant{}); + return tile_window.load(number{}, bool_constant{}); } template CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, const tile_window_with_static_distribution& tile_window, + number = {}, bool_constant = {}) { - return tile_window.load(dst_tile, bool_constant{}); + return tile_window.load(dst_tile, number{}, bool_constant{}); +} + +template +CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, + const tile_window_linear& tile_window, + number = {}, + bool_constant = {}) +{ + return tile_window.load(dst_tile, number{}, bool_constant{}); } /** @@ -76,6 +100,7 @@ template CK_TILE_DEVICE auto load_tile_raw(T& tile, @@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, WindowLengths_, TileDistribution_, NumCoord>& tile_window, + number = {}, bool_constant = {}, bool_constant = {}) { tile_window.load_raw( - tile, number<-1>{}, bool_constant{}, bool_constant{}); + tile, number{}, bool_constant{}, bool_constant{}); } template CK_TILE_DEVICE auto load_tile_raw(T& tile, @@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, 
WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window, + number = {}, bool_constant = {}, bool_constant = {}) { tile_window.load_raw( - tile, number<-1>{}, bool_constant{}, bool_constant{}); + tile, number{}, bool_constant{}, bool_constant{}); } template CK_TILE_DEVICE auto @@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile, WindowLengths_, TileDistribution_, NumCoord>& tile_window, + number = {}, bool_constant = {}, bool_constant = {}) { - return tile_window.async_load_raw( - lds_tile, number<-1>{}, bool_constant{}, bool_constant{}); + return tile_window.async_load_raw(lds_tile, + number{}, + bool_constant{}, + bool_constant{}); } template CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, @@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, WindowLengths_, TileDistribution_, LinearBottomDims_>& tile_window, + number = {}, bool_constant = {}, bool_constant = {}) { - return tile_window.async_load_raw( - lds_tile, number<-1>{}, bool_constant{}, bool_constant{}); + return tile_window.async_load_raw(lds_tile, + number{}, + bool_constant{}, + bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 29c20bed00..568d618ec2 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number return unpacks; } +namespace detail { + +// check if 2 static_distributed_tensor has same data type and size of element +// but only difference in distribution +template +struct is_similiar_distributed_tensor +{ + static constexpr bool value = false; +}; + +template +struct is_similiar_distributed_tensor, + static_distributed_tensor> +{ + using Tx = static_distributed_tensor; + using Ty = static_distributed_tensor; + static constexpr bool value = std::is_same_v && + Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size(); +}; + +template +inline constexpr bool is_similiar_distributed_tensor_v = + is_similiar_distributed_tensor::value; + +} // namespace detail + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 698ce5378d..4c72ed0859 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -333,6 +333,48 @@ struct tensor_view coord.get_offset(), linear_offset, is_valid_element, x); } + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. 
"coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + update_vectorized_elements_raw(const TensorCoord& coord, + index_t linear_offset, + const X& x, + bool_constant = {}, + bool_constant = {}) + { + buf_.template update_raw( + coord.get_offset(), + linear_offset, + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + update_vectorized_elements_raw(const TensorCoord& coord, + index_t linear_offset, + bool is_valid_element, + const X& x, + bool_constant = {}, + bool_constant = {}) + { + buf_.template update_raw( + coord.get_offset(), linear_offset, is_valid_element, x); + } + CK_TILE_HOST_DEVICE void print() const { printf("tensor_view{"); diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index e410246983..caeb038521 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -292,12 +292,15 @@ struct tile_window_with_static_distribution { constexpr auto tile_dstr = TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); - load(dst_tensor, bool_constant{}); + load(dst_tensor, number{}, bool_constant{}); return dst_tensor; } - template + template CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, bool_constant = {}) const { using Traits = load_store_traits; @@ -785,6 +788,73 @@ struct tile_window_with_static_distribution }); } + template + CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? 
(idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements_raw( + bottom_tensor_thread_coord, + 0, + vec_value, + bool_constant{}, + bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + // move thread's botom tensor coordiante // [x0', x1', ... ] ==> [offset] // also move window-origin diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 4b921ec5b9..96a8352c04 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -432,23 +432,38 @@ struct tile_window_linear CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number) { constexpr auto linear_coord = get_bottom_linear_coordinate(number{}); - // since this is linear offset, we assum bottom X tensor is always linear - constexpr index_t linear_offset = [&]() { - constexpr auto x_idx_ = linear_coord; - constexpr auto x_len_ = TileDstr{}.get_lengths(); - static_assert(x_idx_.size() == x_len_.size()); - constexpr index_t x_dims_ = x_idx_.size(); - index_t cu_stride_ = 1; - index_t cu_offset_ = 0; - static_for<0, x_dims_, 1>{}([&](auto i_) { - auto r_i_ = number{}; - cu_offset_ += x_idx_[r_i_] * cu_stride_; - cu_stride_ *= x_len_[r_i_]; - }); - return cu_offset_; - }(); - - return linear_offset; + constexpr auto is_pure_linear_tensor = + reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{}); + if constexpr(is_pure_linear_tensor) + { + // this case usually is a LDS window, everything is known at compile tile. + // we directly use BottomTensorView transform to compute the offset, in case padding + auto bottom_tensor_coord = + make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord); + return bottom_tensor_coord.get_offset(); + } + else + { + // this case usually is a global window, where last dim can be linear + // we hack here, that use the original TileDstr to compute the linear offset + // ... 
hoping that there is no extra padding between other dims, which make sense + // since that would introduce runtime length (so can't use linear offset) + constexpr index_t linear_offset = [&]() { + constexpr auto x_idx_ = linear_coord; + constexpr auto x_len_ = TileDstr{}.get_lengths(); + static_assert(x_idx_.size() == x_len_.size()); + constexpr index_t x_dims_ = x_idx_.size(); + index_t cu_stride_ = 1; + index_t cu_offset_ = 0; + static_for<0, x_dims_, 1>{}([&](auto i_) { + auto r_i_ = number{}; + cu_offset_ += x_idx_[r_i_] * cu_stride_; + cu_stride_ *= x_len_[r_i_]; + }); + return cu_offset_; + }(); + return linear_offset; + } } CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } @@ -509,6 +524,64 @@ struct tile_window_linear return dst_tensor; } + template + CK_TILE_DEVICE auto load(DstTile& dst_tensor, + number = {}, + bool_constant = {}) const + { + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // auto dst_tensor = make_static_distributed_tensor(tile_dstr); + + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + auto bottom_tensor_flag = cached_flags_[IAccess]; + + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + linear_offset, + bottom_tensor_flag, + bool_constant{}); +#if 1 + // data index [y0, y1, ...] + constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); + // write into distributed tensor + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + }, + number{}); + + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j]; + }); +#else + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + }; + + WINDOW_DISPATCH_ISSUE(); + + return dst_tensor; + } + template + CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + + using vector_t = typename traits::vector_t; + using SFC_Ys = typename traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + auto issue = [&](auto i_access_) { + constexpr auto IAccess = number{}; + constexpr auto non_linear_id = number{}; + auto bottom_tensor_thread_coord = cached_coords_[non_linear_id]; + constexpr auto linear_offset = get_bottom_linear_offset(IAccess); + auto bottom_tensor_flag = cached_flags_[IAccess]; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); + + // read from distributed tensor + vector_t vec_value; + + static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == traits::VectorDimY ? 
(idx_ys_start[jj] + j) : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view().template update_vectorized_elements_raw( + bottom_tensor_thread_coord, + linear_offset, + bottom_tensor_flag, + vec_value, + bool_constant{}, + bool_constant{}); + }; + + WINDOW_DISPATCH_ISSUE(); + } + // move thread's botom tensor coordiante // [x0', x1', ... ] ==> [offset] // also move window-origin diff --git a/include/ck_tile/core/tensor/tile_window_utils.hpp b/include/ck_tile/core/tensor/tile_window_utils.hpp new file mode 100644 index 0000000000..71a72329f8 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window_utils.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#pragma once +namespace ck_tile { + +// input a lds store tile, extract some information from it +// used to set m0 value for gfx9 serious +template +CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile) +{ + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + + return make_tuple(m0_init_value, size_per_issue); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index fbce7c4083..570abde189 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -41,15 +41,65 @@ template + typename DataType_, + index_t i_access = -1, + bool oob_conditional_check = true> CK_TILE_DEVICE void update_tile(tile_window_with_static_distribution& tile_window, - const static_distributed_tensor& dstr_tensor) + const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) { - tile_window.update(dstr_tensor); + tile_window.update(dstr_tensor, number{}, bool_constant{}); +} + +template +CK_TILE_DEVICE void 
+update_tile_raw(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) +{ + tile_window.update_raw(dstr_tensor, + number{}, + bool_constant{}, + bool_constant{}); +} + +template +CK_TILE_DEVICE auto update_tile_raw( + tile_window_linear& + tile_window, + const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) +{ + tile_window.update_raw(dstr_tensor, + number{}, + bool_constant{}, + bool_constant{}); } } // namespace ck_tile diff --git a/include/ck_tile/core/utility/static_counter.hpp b/include/ck_tile/core/utility/static_counter.hpp new file mode 100644 index 0000000000..84af3dd52f --- /dev/null +++ b/include/ck_tile/core/utility/static_counter.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +template +struct static_counter +{ + public: + template + static constexpr index_t next() + { + return next(0) * Step + Start; + } + + template + static constexpr index_t next() + { + struct Unique + { + }; + return next(0) * Step + Start; + } + + template + static constexpr index_t current() + { + return current(0) * Step + Start; + } + + template + static constexpr index_t current() + { + struct Unique + { + }; + return current(0) * Step + Start; + } + + private: + template + struct slot + { + _Pragma("GCC diagnostic push"); + _Pragma("GCC diagnostic ignored \"-Wundefined-internal\""); + friend constexpr bool slot_allocated(slot); + _Pragma("GCC diagnostic pop"); + }; + + template + struct allocate_slot + { + friend constexpr bool slot_allocated(slot) { return true; } + enum + { + value = I + }; + }; + + // If slot_allocated(slot) has NOT been defined, then SFINAE will keep this function out of + // the overload set... + template ())> + static constexpr index_t next(index_t) + { + return next(0); + } + + // ...And this function will be used, instead, which will define slot_allocated(slot) via + // allocate_slot. + template + static constexpr index_t next(double) + { + return allocate_slot::value; + } + + // If slot_allocated(slot) has NOT been defined, then SFINAE will keep this function out of + // the overload set... + template ())> + static constexpr index_t current(index_t) + { + return current(0); + } + + // ...And this function will be used, instead, which will return the current counter, or assert + // in case next() hasn't been called yet. 
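+    //
+    // A minimal sketch of the resulting behavior (mirrors the usage notes at
+    // the end of this file; every NEXT_SC expands a fresh __COUNTER__, which
+    // is what forces a new slot to be allocated per call site):
+    //
+    //   constexpr auto c = MAKE_SC();
+    //   constexpr ck_tile::index_t i0 = NEXT_SC(c); // 0, allocates a slot
+    //   constexpr ck_tile::index_t i1 = NEXT_SC(c); // 1, slot 0 already taken
+    //   static_assert(i0 == 0 && i1 == 1, "all resolved at compile time");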
+ template + static constexpr index_t current(double) + { + static_assert(I != 0, "You must invoke next() first"); + + return I - 1; + } +}; + +namespace impl { +template +struct static_counter_uniq_; +} + +#define MAKE_SC() \ + ck_tile::static_counter> {} +#define MAKE_SC_WITH(start_, step_) \ + ck_tile::static_counter, start_, step_> {} +#define NEXT_SC(c_) c_.next<__COUNTER__>() +#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>() + +// Usage: +// constexpr auto c = MAKE_SC() +// NEXT_SC(c) // -> constexpr 0 +// NEXT_SC(c) // -> constexpr 1 +// NEXT_SC(c) // -> constexpr 2 +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 2e96009ace..2f3a302eea 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,6 +11,7 @@ #include "ck_tile/host/fill.hpp" #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" @@ -20,6 +21,7 @@ #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" +#include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 7c8549f74f..13684c0e24 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -7,6 +7,7 @@ #include #include #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/host_tensor.hpp" namespace ck_tile { template @@ -36,6 +37,19 @@ struct DeviceMem mpDeviceBuf = nullptr; } } + template + DeviceMem(const HostTensor& t) : mMemSize(t.get_element_space_size_in_bytes()) + { + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } + ToDevice(t.data()); + } void Realloc(std::size_t mem_size) { if(mpDeviceBuf) @@ -92,6 +106,27 @@ struct DeviceMem HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); } } + + // construct a host tensor with type T + template + HostTensor ToHost(std::size_t cpySize) + { + // TODO: host tensor could be slightly larger than the device tensor + // we just copy all data from GPU buffer + std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T); + HostTensor h_({host_elements}); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } + return h_; + } + template + HostTensor ToHost() + { + return ToHost(mMemSize); + } + void SetZero() const { if(mpDeviceBuf) diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 335911860a..f24c338755 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -13,6 +13,7 @@ #include #include "ck_tile/core.hpp" +#include "ck_tile/host/joinable_thread.hpp" namespace ck_tile { @@ -22,13 +23,44 @@ struct FillUniformDistribution float a_{-5.f}; float b_{5.f}; std::optional seed_{11939}; + // ATTENTION: threaded does not guarantee the distribution between thread + bool threaded = false; template void operator()(ForwardIter first, ForwardIter last) const 
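+    // When threaded == true, the range is split into chunks of
+    // ceil(total / num_thread) elements and each chunk gets its own generator
+    // seeded with (*seed_ + iw_begin), so a chunk is reproducible run-to-run
+    // but the overall stream differs from the single-threaded path (see the
+    // ATTENTION note above). For example, 10 elements on 4 threads gives
+    // work_per_thread = 3 and chunks [0,3) [3,6) [6,9) [9,10).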
{ - std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); - std::uniform_real_distribution dis(a_, b_); - std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + if(threaded) + { + uint32_t num_thread = std::thread::hardware_concurrency(); + auto total = static_cast(std::distance(first, last)); + auto work_per_thread = static_cast((total + num_thread - 1) / num_thread); + + std::vector threads(num_thread); + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, total); + auto thread_f = [this, total, iw_begin, iw_end, &first] { + if(iw_begin > total || iw_end > total) + return; + // need to make each thread unique, add an offset to current seed + std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) + : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { + return ck_tile::type_convert(dis(gen)); + }); + }; + threads[it] = joinable_thread(thread_f); + } + } + else + { + std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + } } template @@ -115,13 +147,44 @@ struct FillNormalDistribution float mean_{0.f}; float variance_{1.f}; std::optional seed_{11939}; + // ATTENTION: threaded does not guarantee the distribution between thread + bool threaded = false; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); - std::normal_distribution dis(mean_, std::sqrt(variance_)); - std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + if(threaded) + { + uint32_t num_thread = std::thread::hardware_concurrency(); + auto total = static_cast(std::distance(first, last)); + auto work_per_thread = static_cast((total + num_thread - 1) / num_thread); + + std::vector threads(num_thread); + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, total); + auto thread_f = [this, total, iw_begin, iw_end, &first] { + if(iw_begin > total || iw_end > total) + return; + // need to make each thread unique, add an offset to current seed + std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) + : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { + return ck_tile::type_convert(dis(gen)); + }); + }; + threads[it] = joinable_thread(thread_f); + } + } + else + { + std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); + std::normal_distribution dis(mean_, std::sqrt(variance_)); + std::generate( + first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + } } template @@ -235,6 +298,44 @@ struct FillMonotonicSeq } }; +template +struct FillStepRange +{ + float start_value_{0}; + float end_value_{3}; + float step_{1}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, n = start_value_]() mutable { + auto tmp = n; + n += step_; + if constexpr(IsAscending) + { + if(n > end_value_) + n = start_value_; + } + else + { + if(n < end_value_) + n = start_value_; + } + + return type_convert(tmp); + }); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillConstant { diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 5610ba324d..3902cad178 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -8,12 +8,13 @@ #include #include #include -#include #include #include #include +#include #include "ck_tile/core.hpp" +#include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/ranges.hpp" namespace ck_tile { @@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old return HostTensorDescriptor(new_lengths, new_strides); } -struct joinable_thread : std::thread -{ - template - joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) - { - } - - joinable_thread(joinable_thread&&) = default; - joinable_thread& operator=(joinable_thread&&) = default; - - ~joinable_thread() - { - if(this->joinable()) - this->join(); - } -}; - template struct ParallelTensorFunctor { @@ -590,6 +574,107 @@ struct HostTensor size() * FromSize / ToSize}; } + friend std::ostream& operator<<(std::ostream& os, const HostTensor& t) + { + os << t.mDesc; + os << "["; + for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + if constexpr(std::is_same_v || std::is_same_v) + { + os << type_convert(t.mData[idx]) << " #### "; + } + else + { + os << t.mData[idx]; + } + } + os << "]"; + return os; + } + + // read data from a file, as dtype + // the file could dumped from torch as (targeting tensor is t here) + // numpy.savetxt("f.txt", t.view(-1).numpy()) + // numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save + // numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int + // will output f.txt, each line is a value + // dtype=float or int, internally will cast to real type + void loadtxt(std::string file_name, std::string dtype = "float") + { + std::ifstream file(file_name); + + if(file.is_open()) + { + std::string line; + + index_t cnt = 0; + while(std::getline(file, line)) + { + if(cnt >= static_cast(mData.size())) + { + throw std::runtime_error(std::string("data read from file:") + file_name + + " is too big"); + } + + if(dtype == "float") + { + mData[cnt] = type_convert(std::stof(line)); + } + else if(dtype == "int" || dtype == "int32") + { + mData[cnt] = type_convert(std::stoi(line)); + } + cnt++; + } + file.close(); + if(cnt < static_cast(mData.size())) + { + std::cerr << "Warning! 
reading from file:" << file_name + << ", does not match the size of this tensor" << std::endl; + } + } + else + { + // Print an error message to the standard error + // stream if the file cannot be opened. + throw std::runtime_error(std::string("unable to open file:") + file_name); + } + } + + // can save to a txt file and read from torch as: + // torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous() + void savetxt(std::string file_name, std::string dtype = "float") + { + std::ofstream file(file_name); + + if(file.is_open()) + { + for(auto& itm : mData) + { + if(dtype == "float") + file << type_convert(itm) << std::endl; + else if(dtype == "int") + file << type_convert(itm) << std::endl; + else + // TODO: we didn't implement operator<< for all custom + // data types, here fall back to float in case compile error + file << type_convert(itm) << std::endl; + } + file.close(); + } + else + { + // Print an error message to the standard error + // stream if the file cannot be opened. + throw std::runtime_error(std::string("unable to open file:") + file_name); + } + } + Descriptor mDesc; Data mData; }; diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp new file mode 100644 index 0000000000..a822f967dc --- /dev/null +++ b/include/ck_tile/host/joinable_thread.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck_tile { + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_fused_moe.hpp b/include/ck_tile/host/reference/reference_fused_moe.hpp new file mode 100644 index 0000000000..bf89f92759 --- /dev/null +++ b/include/ck_tile/host/reference/reference_fused_moe.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { +// [indexing implementation-1] +// using M_a as constexpr block_size to partition all tokens into different slices +// each slice map to one expert, and one expert can have multiple slices +// e.g. 
num_experts = 6, topk=3, M_a = 4, input_tokens = 5 +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float +// number) +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +// +// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) +// * this could be larger than actual, since actual tokens are on GPU +// +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, +// 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 +// -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, +// c, f, i, o] +// +// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] +// * length is (max_num_tokens_padded + block_size - 1) / block_size +/// +// num_tokens_post_padded_ptr : [28] +// num_sorted_tiles_ptr : [7] + +template +void reference_fused_moe( + const ck_tile::HostTensor& a_host, // [tokens, hidden_size] + const ck_tile::HostTensor& g_host, // [experts, interme_size_0, hidden_size] + const ck_tile::HostTensor& d_host, // [experts, hidden_size, interme_size_1] + const ck_tile::HostTensor& sa_host, // [tokens, 1], + const ck_tile::HostTensor& sg_host, // [experts, 1, interme_size_0] + const ck_tile::HostTensor& sd_host, // [experts, 1, hidden_size], + const ck_tile::HostTensor& sy_host, // [experts, 1, interme_size_0] + ck_tile::HostTensor& o_host, // [tokens, hidden_size] + const ck_tile::HostTensor& sorted_token_ids_host, // [max_num_tokens_padded] + const ck_tile::HostTensor& sorted_weight_host, // [max_num_tokens_padded] + const ck_tile::HostTensor& + sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size] + const ck_tile::HostTensor& num_sorted_tiles_host, // [1] + + const ck_tile::HostTensor& + token_ids_host, // [tokens, topk] --> ugly!!! remove in the future + + ck_tile::index_t block_m, + ck_tile::index_t tokens, + ck_tile::index_t experts, + ck_tile::index_t hidden_size, + ck_tile::index_t intermediate_size, // this size is for gate/up + ck_tile::index_t topk, + ck_tile::index_t gate_only) +{ + assert(sorted_token_ids_host.get_num_of_dimension() == 1); + assert(sorted_weight_host.get_num_of_dimension() == 1); + assert(sorted_expert_ids_host.get_num_of_dimension() == 1); + assert(num_sorted_tiles_host.get_element_size() == 1); + ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m; + ck_tile::index_t intermediate_size_0 = intermediate_size; + ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2); + + // TODO: better remove this in the future, or modify the token_id value + auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) { + for(ck_tile::index_t i_ = 0; i_ < topk; i_++) + { + if(token_ids_host(token_id_, i_) == expert_id_) + return i_; + } + throw std::runtime_error("not correct token/expert pair\n"); + return -1; // TODO: not correct!! 
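+        // e.g. with the topk_ids example above: token 0 was routed to
+        // experts {0, 3, 5}, so get_topk_id(0, 5) == 2 and the weight applied
+        // to that (token, expert) pair is topk_weight[0][2] == c.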
+ }; + + ck_tile::HostTensor out_topk_tokens({tokens, topk, hidden_size}); + + int max_num_tokens_padded = topk * tokens + experts * block_m - topk; + // assert(); + auto f = [&](auto i_flatten) { + ck_tile::index_t i_tile = i_flatten / block_m; + if(i_tile >= num_sorted_tiles) + return; + ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile]; + ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten]; + if(i_token >= tokens) + return; + ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly + auto weight = sorted_weight_host.mData[i_flatten]; + + ck_tile::HostTensor acc_0({1, intermediate_size_0}); + // first gemm + for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++) + { + AccDataType acc = static_cast(0); + for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++) + { + acc += type_convert(a_host(i_token, i_k)) * + type_convert(g_host(i_expert, i_n, i_k)); + } + acc_0(0, i_n) = acc; + // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc); + } + + ck_tile::HostTensor y({1, intermediate_size_1}); + if(gate_only) + { + if(intermediate_size_1 != intermediate_size_0) + throw std::runtime_error( + "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) + + ", 1:" + std::to_string(intermediate_size_1)); + for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++) + { + Activation{}(y(0, i_n), acc_0(0, i_n)); + // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n)); + } + } + else + { + if(intermediate_size_1 * 2 != intermediate_size_0) + throw std::runtime_error( + "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) + + ", 1:" + std::to_string(intermediate_size_1)); + for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++) + { + AccDataType tmp; + Activation{}(tmp, acc_0(0, i_n)); + y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul + } + } + + // second gemm, loop along gemm-n + ck_tile::HostTensor acc_1({1, hidden_size}); + for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++) + { + AccDataType acc = static_cast(0); + for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++) + { + acc += y(0, i_k) * type_convert(d_host(i_expert, i_n, i_k)); + } + acc_1(0, i_n) = acc * weight; // multiple weight here + } + + for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++) + { + out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n); + } + }; + + // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, max_num_tokens_padded)(1); + + // reduce + auto r = [&](auto i_token) { + for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++) + { + AccDataType acc = type_convert(0); + for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++) + { + acc += out_topk_tokens(i_token, i_topk, i_n); + } + o_host(i_token, i_n) = type_convert(acc); + } + }; + make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency()); + + (void)num_sorted_tiles_host; + (void)sa_host; + (void)sg_host; + (void)sd_host; + (void)sy_host; +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_permute.hpp b/include/ck_tile/host/reference/reference_permute.hpp index 14ed4f815e..4e0f1a877e 100644 --- a/include/ck_tile/host/reference/reference_permute.hpp +++ b/include/ck_tile/host/reference/reference_permute.hpp @@ -16,7 +16,7 @@ namespace ck_tile { */ template CK_TILE_HOST void -reference_permute(const HostTensor& x, HostTensor& y, std::vector dims) 
+reference_permute(const HostTensor& x, HostTensor& y, std::vector perm) { const auto x_len = x.mDesc.get_lengths(); const auto y_len = y.mDesc.get_lengths(); @@ -43,7 +43,7 @@ reference_permute(const HostTensor& x, HostTensor& y, std::v std::vector tmp(rank, 0); for(index_t i = 0; i < rank; i++) { - tmp[dims[i]] = y_coord[i]; + tmp[perm[i]] = y_coord[i]; } return tmp; }(); @@ -54,4 +54,23 @@ reference_permute(const HostTensor& x, HostTensor& y, std::v make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency()); } + +template +CK_TILE_HOST auto reference_permute(const HostTensor& x, std::vector perm) +{ + auto x_shape = x.get_lengths(); + ck_tile::index_t rank = perm.size(); + std::vector y_shape = [&]() { + std::vector tmp(rank, 0); + for(int i = 0; i < static_cast(rank); i++) + { + tmp[i] = x_shape[perm[i]]; + } + return tmp; + }(); + + HostTensor y(y_shape); + reference_permute(x, y, perm); + return y; +} } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 01217e16ce..e24b1ba767 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -572,6 +572,105 @@ struct FastGelu } }; +struct FastGeluAsm +{ + template + CK_TILE_HOST void operator()(Y& y, const X& x) const; + + template + CK_TILE_DEVICE void operator()(Y& y, const X& x) const; + + template <> + CK_TILE_HOST void operator()(float& y, const float& x) const + { + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); + } + + // device code, use lower precision "__ocml_exp_f32" and "rcp" + template <> + CK_TILE_DEVICE void operator()(float& y, const float& x) const + { + const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v; + float tmp; + + asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n" + "v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n" + "v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n" + "v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n" + "v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n" + "s_nop 0 ; hazard for exp\n" + "v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n" + "v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n" + "s_nop 0 ; hazard for rcp \n" + "v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n" + : [v_y] "=v"(y), [v_tmp] "+v"(tmp) + : [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_) + :); + } + + template <> + CK_TILE_HOST void operator()(fp32x2_t& y, const fp32x2_t& x) const + { + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u0 = x.x * (c1 * x.x * x.x + c2); + const float emu0 = exp(u0); + y.x = x.x / (1.f + emu0); + const float u1 = x.y * (c1 * x.y * x.y + c2); + const float emu1 = exp(u1); + y.y = x.y / (1.f + emu1); + } + + // this is packed verion to remove data hazard for trans + template <> + CK_TILE_DEVICE void operator()(fp32x2_t& y, const fp32x2_t& x) const + { + const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f; + float c2 = -2.0 * 0.797885f; + const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v; + float tmp0, tmp1; + float y0 = x.x, y1 = x.y; + + asm volatile( + "v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n" 
+ "v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n" + "v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n" + "v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n" + "v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n" + "v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n" + "v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n" + "v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n" + "v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n" + "v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n" + "v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n" + "v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n" + "v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n" + "v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n" + "v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n" + "v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n" + : [v_y0] "+v"(y0), + [v_y1] "+v"(y1), + [v_c2] "+v"(c2), + // NOTE! it is totally possible that c2/y0/y1 share same register, they are all local + // tmp variables we need to expicitly hint compiler they may read+write, to allow + // allocate different register , the side effect is c2=** may issue for every such + // inline asm block + [v_tmp0] "+v"(tmp0), + [v_tmp1] "+v"(tmp1) + : [s_c1] "s"(c1), [s_log2e] "s"(log2e_) + :); + y.x = y0; + y.y = y1; + } +}; + // https://paperswithcode.com/method/gelu // y = 0.5*x*(1+erf(x/sqrt(2))) struct Gelu diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp new file mode 100644 index 0000000000..eee80cda4a --- /dev/null +++ b/include/ck_tile/ops/flatmm.hpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp new file mode 100644 index 0000000000..f5c7caf7df --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -0,0 +1,615 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" + +namespace ck_tile { + +// A async load to LDS, B direct to AGPR +// B matrix preshuffled in br*kr*w +// require 4 wave, occupancy=1c +// agpr useage:256 +// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112) +// +// for this gemm, 4 16x16x16 transposed layout +// input A vpgpr layout +// v0-v15: [ 0:15](gemm_m)x128(gemm_k) +// v16-v31: [16:31](gemm_m)x128(gemm_k) + +// input B vpgpr layout +// v0-v15: [ 0: 15](gemm_n)x128(gemm_k) +// v16-v31: [ 64: 79](gemm_n)x128(gemm_k) +// ...................... +// v111-v127: [448:463](gemm_n)x128(gemm_k) + +// output C vpgpr layout +// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n) +// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n) +// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n) +// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n) +// ...................... 
+// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n) +// v60-v63: [16:31](gemm_m)x[448:463](gemm_n) +struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 +{ + static constexpr index_t Block_M = 32; + static constexpr index_t Block_N = 512; + static constexpr index_t Block_K = 128; + + static constexpr index_t WarpPerBlock_M = 1; + static constexpr index_t WarpPerBlock_N = 4; + static constexpr index_t WarpPerBlock_K = 1; + + static constexpr index_t NumWarps = 4; + + static constexpr index_t Warp_M = 16; + static constexpr index_t Warp_N = 16; + static constexpr index_t Warp_K = 32; // 16 * SubKPacks + + static constexpr index_t BlockSize = 256; + + static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4 + + // TODO: note Nr/Kr/W need consider SubKPacks + static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element + static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave + static constexpr index_t Block_Kr = Block_K / Warp_K; // 4 + + static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2 + static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8 + static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4 + + static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, // !! note here is different + sequence<0, 0>>{}; + + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + static CK_TILE_DEVICE constexpr auto MakeCBlockTile() + { + using CDataType = float; + constexpr auto c_block_dstr = MakeCBlockDist(); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + // constexpr index_t Block_M = Problem::BlockShape::Block_M0; + // constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack_; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK > NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds store vector(actually no explicit store) + 
number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + + // template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A() + { + // load from LDS to register, every wave has same layout + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KPad = KPack_; // pad between warps + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + constexpr index_t kKIter = 2; + static_assert(KPack_ == (kABKPerLane * kKIter)); + + constexpr auto lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // m0 y + number{}, // m1 p + number{}, // k0 y + number{}, // k1 p + number{}), // k2 y-vector + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + + static constexpr auto GetGemm_AWarpEnc() + { + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + constexpr index_t kKIter = 2; + + using enc_ = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + return enc_{}; + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 32 * (128 + 8) * sizeof(bf16_t); + } +}; + +struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base +{ + using ADataType = bf16_t; + using BDataType = bf16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! 
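+    //
+    // Shape bookkeeping for the micro-kernel below (all values follow from
+    // the Base struct above):
+    //   ACoords::size() == Block_M * Block_K / BlockSize / 2
+    //                   == 32 * 128 / 256 / 2 == 8  (one dword carries 2 elems)
+    //   BCoords::size() == Repeat_N == Block_N / (Warp_N * WarpPerBlock_N)
+    //                   == 512 / (16 * 4) == 8
+    // i.e. per K-unroll of Block_K = 128 each lane issues 8 async A-loads
+    // (through LDS) and 8 direct-to-AGPR B-loads.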
+ template + CK_TILE_DEVICE auto + operator()(const ARes& res_a, + const ACoords& cached_coords_a, + const BRes& res_b, + const BCoords& cached_coords_b, + CK_TILE_LDS_ADDR void* smem, + index_t k, + index_t tile_offset_a, // for each tile, the offset to move for each unroll + index_t tile_offset_b) // for each tile, the offset to move for each unroll + { + static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8 + static_assert(BCoords::size() == Repeat_N); + + auto a_sst = make_tile_window( + make_tensor_view( + reinterpret_cast(smem), MakeLdsStoreDesc_A()), + MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + + auto a_sld = [&]() { + constexpr auto a_warp_enc_ = GetGemm_AWarpEnc(); + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = + detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_); + return make_tile_window_linear( + make_tensor_view( + reinterpret_cast(smem), MakeLdsLoadDesc_A()), + MakeLdsLoadDesc_A().get_lengths(), + {0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType); + const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType); + + const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst); + constexpr auto smem_buf_size = + MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType); + static_assert(a_sld.get_num_of_access() == 8); + constexpr auto sld_os = generate_tuple( + [&](auto i_access) { + return number{}; + }, + number{}); + + index_t loop_cnt = k / Block_K; + + // this is the acc thread buffer + fp32x4_t v_acc[16]{.0f}; + + // B nr->kr +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + // clang-format off + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + : [s_loop_cnt]"+s"(loop_cnt), + [v_acc_0]"+v"(v_acc[0]), + [v_acc_1]"+v"(v_acc[1]), + [v_acc_2]"+v"(v_acc[2]), + [v_acc_3]"+v"(v_acc[3]), + [v_acc_4]"+v"(v_acc[4]), + [v_acc_5]"+v"(v_acc[5]), + [v_acc_6]"+v"(v_acc[6]), + [v_acc_7]"+v"(v_acc[7]), + [v_acc_8]"+v"(v_acc[8]), + [v_acc_9]"+v"(v_acc[9]), + [v_acc_10]"+v"(v_acc[10]), + [v_acc_11]"+v"(v_acc[11]), + [v_acc_12]"+v"(v_acc[12]), + [v_acc_13]"+v"(v_acc[13]), + [v_acc_14]"+v"(v_acc[14]), + [v_acc_15]"+v"(v_acc[15]), + [s_mem_]"+r"(smem) + : [s_res_a0]"s"(res_a[0]), + [s_res_a1]"s"(res_a[1]), + [s_res_a2]"s"(res_a[2]), + [s_res_a3]"s"(res_a[3]), + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_a0]"v"(static_cast(cached_coords_a[number<0>{}] * sizeof(ADataType))), + [v_os_a1]"v"(static_cast(cached_coords_a[number<1>{}] * sizeof(ADataType))), + [v_os_a2]"v"(static_cast(cached_coords_a[number<2>{}] * sizeof(ADataType))), + [v_os_a3]"v"(static_cast(cached_coords_a[number<3>{}] * sizeof(ADataType))), + [v_os_a4]"v"(static_cast(cached_coords_a[number<4>{}] * sizeof(ADataType))), + [v_os_a5]"v"(static_cast(cached_coords_a[number<5>{}] * sizeof(ADataType))), + [v_os_a6]"v"(static_cast(cached_coords_a[number<6>{}] * sizeof(ADataType))), + [v_os_a7]"v"(static_cast(cached_coords_a[number<7>{}] * sizeof(ADataType))), + + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + 
[v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + + [v_os_slda]"v"(static_cast(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))), + [s_m0_init]"s"(m0_init_value), + [s_size_per_issue]"s"(size_per_issue), + [smem_sz]"n"(smem_buf_size), //(smem_buf_size), + [sld_os_0]"n"(sld_os[number<0>{}].value), + [sld_os_1]"n"(sld_os[number<1>{}].value), + [sld_os_2]"n"(sld_os[number<2>{}].value), + [sld_os_3]"n"(sld_os[number<3>{}].value), + [sld_os_4]"n"(sld_os[number<4>{}].value), + [sld_os_5]"n"(sld_os[number<5>{}].value), + [sld_os_6]"n"(sld_os[number<6>{}].value), + [sld_os_7]"n"(sld_os[number<7>{}].value), + [s_tile_os_a]"s"(tile_offset_a_bytes), + [s_tile_os_b]"s"(tile_offset_b_bytes) + : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", + "s86", // s86 as tmp + "v64", "v65", "v66", "v67", "v68", "v69", + "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", + "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", + "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", 
"v99", + "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", + "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", + "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", + "v124", "v125", "v126", "v127" + ); + // clang-format on +#pragma clang diagnostic pop + + // return local scratch + auto c = MakeCBlockTile(); + for(auto i = 0; i < 16; i++) + { + c.get_thread_buffer()[4 * i + 0] = v_acc[i].x; + c.get_thread_buffer()[4 * i + 1] = v_acc[i].y; + c.get_thread_buffer()[4 * i + 2] = v_acc[i].z; + c.get_thread_buffer()[4 * i + 3] = v_acc[i].w; + } + return c; + } +}; + +struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base +{ + using ADataType = fp16_t; + using BDataType = fp16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! + template + CK_TILE_DEVICE auto + operator()(const ARes& res_a, + const ACoords& cached_coords_a, + const BRes& res_b, + const BCoords& cached_coords_b, + CK_TILE_LDS_ADDR void* smem, + index_t k, + index_t tile_offset_a, // for each tile, the offset to move for each unroll + index_t tile_offset_b) // for each tile, the offset to move for each unroll + { + static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8 + static_assert(BCoords::size() == Repeat_N); + + auto a_sst = make_tile_window( + make_tensor_view( + reinterpret_cast(smem), MakeLdsStoreDesc_A()), + MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + + auto a_sld = [&]() { + constexpr auto a_warp_enc_ = GetGemm_AWarpEnc(); + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = + detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_); + return make_tile_window_linear( + make_tensor_view( + reinterpret_cast(smem), MakeLdsLoadDesc_A()), + MakeLdsLoadDesc_A().get_lengths(), + {0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType); + const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType); + + const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst); + constexpr auto smem_buf_size = + MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType); + static_assert(a_sld.get_num_of_access() == 8); + constexpr auto sld_os = generate_tuple( + [&](auto i_access) { + return number{}; + }, + number{}); + + index_t loop_cnt = k / Block_K; + + // this is the acc thread buffer + fp32x4_t v_acc[16]{.0f}; + + // B nr->kr +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + // clang-format off + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16 +#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + : [s_loop_cnt]"+s"(loop_cnt), + [v_acc_0]"+v"(v_acc[0]), + [v_acc_1]"+v"(v_acc[1]), + [v_acc_2]"+v"(v_acc[2]), + [v_acc_3]"+v"(v_acc[3]), + [v_acc_4]"+v"(v_acc[4]), + [v_acc_5]"+v"(v_acc[5]), + [v_acc_6]"+v"(v_acc[6]), + [v_acc_7]"+v"(v_acc[7]), + [v_acc_8]"+v"(v_acc[8]), + [v_acc_9]"+v"(v_acc[9]), + [v_acc_10]"+v"(v_acc[10]), + [v_acc_11]"+v"(v_acc[11]), + [v_acc_12]"+v"(v_acc[12]), + [v_acc_13]"+v"(v_acc[13]), + [v_acc_14]"+v"(v_acc[14]), + [v_acc_15]"+v"(v_acc[15]), + [s_mem_]"+r"(smem) + : [s_res_a0]"s"(res_a[0]), + [s_res_a1]"s"(res_a[1]), + [s_res_a2]"s"(res_a[2]), + 
[s_res_a3]"s"(res_a[3]), + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_a0]"v"(static_cast(cached_coords_a[number<0>{}] * sizeof(ADataType))), + [v_os_a1]"v"(static_cast(cached_coords_a[number<1>{}] * sizeof(ADataType))), + [v_os_a2]"v"(static_cast(cached_coords_a[number<2>{}] * sizeof(ADataType))), + [v_os_a3]"v"(static_cast(cached_coords_a[number<3>{}] * sizeof(ADataType))), + [v_os_a4]"v"(static_cast(cached_coords_a[number<4>{}] * sizeof(ADataType))), + [v_os_a5]"v"(static_cast(cached_coords_a[number<5>{}] * sizeof(ADataType))), + [v_os_a6]"v"(static_cast(cached_coords_a[number<6>{}] * sizeof(ADataType))), + [v_os_a7]"v"(static_cast(cached_coords_a[number<7>{}] * sizeof(ADataType))), + + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + [v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + + [v_os_slda]"v"(static_cast(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))), + [s_m0_init]"s"(m0_init_value), + [s_size_per_issue]"s"(size_per_issue), + [smem_sz]"n"(smem_buf_size), //(smem_buf_size), + [sld_os_0]"n"(sld_os[number<0>{}].value), + [sld_os_1]"n"(sld_os[number<1>{}].value), + [sld_os_2]"n"(sld_os[number<2>{}].value), + [sld_os_3]"n"(sld_os[number<3>{}].value), + [sld_os_4]"n"(sld_os[number<4>{}].value), + [sld_os_5]"n"(sld_os[number<5>{}].value), + [sld_os_6]"n"(sld_os[number<6>{}].value), + [sld_os_7]"n"(sld_os[number<7>{}].value), + [s_tile_os_a]"s"(tile_offset_a_bytes), + [s_tile_os_b]"s"(tile_offset_b_bytes) + : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", 
"a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", + "s86", // s86 as tmp + "v64", "v65", "v66", "v67", "v68", "v69", + "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", + "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", + "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", + "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", + "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", + "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", + "v124", "v125", "v126", "v127" + ); + // clang-format on +#pragma clang diagnostic pop + + // return local scratch + auto c = MakeCBlockTile(); + for(auto i = 0; i < 16; i++) + { + c.get_thread_buffer()[4 * i + 0] = v_acc[i].x; + c.get_thread_buffer()[4 * i + 1] = v_acc[i].y; + c.get_thread_buffer()[4 * i + 2] = v_acc[i].z; + c.get_thread_buffer()[4 * i + 3] = v_acc[i].w; + } + return c; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp new file mode 100644 index 0000000000..203c87b9c6 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" + +namespace ck_tile { + +// "S"tream update output along "N" +// A in smem, B load from global +// require 4 wave, occupancy=1c +struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base +{ + static constexpr index_t Block_M = 32; + static constexpr index_t Block_N = 128; + static constexpr index_t Block_K = 512; + + static constexpr index_t WarpPerBlock_M = 1; + static constexpr index_t WarpPerBlock_N = 4; + static constexpr index_t WarpPerBlock_K = 1; + + static constexpr index_t Warp_M = 16; + static constexpr index_t Warp_N = 16; + static constexpr index_t Warp_K = 32; + + static constexpr index_t BlockSize = 256; + + // static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4 + + // TODO: note Nr/Kr/W need consider KPack + static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element + static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave + static constexpr index_t Block_Kr = Block_K / Warp_K; // 4 + + static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2 + static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2 + static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16 + + static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, // !! 
note here is different + sequence<0, 0>>{}; + + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // y y p p p y + // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4) + // but order is N0*M0*Nv + // in LDS we need store as + // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4) + // y y wave-id lid/16 lid%16 v + return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t); + } +}; + +struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base +{ + using BDataType = bf16_t; + using ODataType = bf16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! + // template + template + CK_TILE_DEVICE auto + operator()(const BRes& res_b, + const BCoords& cached_coords_b, + const ORes& res_o, + const OCoords& cached_coords_o, + const OFlags& o_flags, // this should be in sgpr + CK_TILE_LDS_ADDR void* smem, + index_t n, // loop along n dim + const ScaleTensor& scale_, + index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust + index_t tile_offset_o) + { + static_assert(BCoords::size() == 8); // 8 + static_assert(OCoords::size() == 8); + + const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType); + const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType); + + static_assert(ScaleTensor::size() == 2); + float s0 = scale_[number<0>{}]; + float s1 = scale_[number<1>{}]; + + index_t loop_cnt = n / Block_N; + + register float v_c0 asm("v64"); + register float v_c1 asm("v65"); + register float v_c2 asm("v66"); + register float v_c3 asm("v67"); + register float v_c4 asm("v68"); + register float v_c5 asm("v69"); + register float v_c6 asm("v70"); + register float v_c7 asm("v71"); + register float v_c8 asm("v72"); + register float v_c9 asm("v73"); + register float v_c10 asm("v74"); + register float v_c11 asm("v75"); + register float v_c12 asm("v76"); + register float v_c13 asm("v77"); + register float v_c14 asm("v78"); + register float v_c15 asm("v79"); + register float v_c16 asm("v80"); + register float v_c17 asm("v81"); + register float v_c18 asm("v82"); + register float v_c19 asm("v83"); + register float v_c20 asm("v84"); + register float v_c21 asm("v85"); + register float v_c22 asm("v86"); + register float v_c23 asm("v87"); + register float v_c24 asm("v88"); + register float v_c25 asm("v89"); + register float v_c26 asm("v90"); + register float v_c27 asm("v91"); + register float v_c28 asm("v92"); + register float v_c29 asm("v93"); + register float v_c30 asm("v94"); + register float v_c31 asm("v95"); + int32_t nan_hi = 0x7fff0000; + int32_t nan_lo = 0x00007fff; + + // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4) + // every threads need 8xK in contiguous register + // ... 
and every wave need the same data + int lane_id = threadIdx.x % 64; + int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128; + sld_y_os *= 2; + + // y y p p p y + // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4) + // but order is N0*M0*Nv + // in LDS we need store as + // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4) + // y y wave-id lid/16 lid%16 v + // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4 + int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4); + sfl_sst *= 2; + + // from LDS we need load as + // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4) + // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2) + // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4 + int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4; + sfl_sld *= 2; + + // B nr->kr + // clang-format off +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + :[smem_]"+r"(smem), + [s_loop_cnt]"+s"(loop_cnt), + [c0]"+v" (v_c0), + [c1]"+v" (v_c1), + [c2]"+v" (v_c2), + [c3]"+v" (v_c3), + [c4]"+v" (v_c4), + [c5]"+v" (v_c5), + [c6]"+v" (v_c6), + [c7]"+v" (v_c7), + [c8]"+v" (v_c8), + [c9]"+v" (v_c9), + [c10]"+v"(v_c10), + [c11]"+v"(v_c11), + [c12]"+v"(v_c12), + [c13]"+v"(v_c13), + [c14]"+v"(v_c14), + [c15]"+v"(v_c15), + [c16]"+v"(v_c16), + [c17]"+v"(v_c17), + [c18]"+v"(v_c18), + [c19]"+v"(v_c19), + [c20]"+v"(v_c20), + [c21]"+v"(v_c21), + [c22]"+v"(v_c22), + [c23]"+v"(v_c23), + [c24]"+v"(v_c24), + [c25]"+v"(v_c25), + [c26]"+v"(v_c26), + [c27]"+v"(v_c27), + [c28]"+v"(v_c28), + [c29]"+v"(v_c29), + [c30]"+v"(v_c30), + [c31]"+v"(v_c31) + : + [sld_a_base]"n"(0), + [shfl_base]"n"(0), + [v_sld_y_os]"v"(sld_y_os), + [v_sfl_sld]"v"(sfl_sld), + [v_sfl_sst]"v"(sfl_sst), + [s_res_o0]"s"(res_o[0]), + [s_res_o1]"s"(res_o[1]), + //[s_res_o2]"s"(res_o[2]), + //[s_res_o3]"s"(res_o[3]), + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_o0]"v"(static_cast(cached_coords_o[number<0>{}] * sizeof(ODataType))), + [v_os_o1]"v"(static_cast(cached_coords_o[number<1>{}] * sizeof(ODataType))), + [v_os_o2]"v"(static_cast(cached_coords_o[number<2>{}] * sizeof(ODataType))), + [v_os_o3]"v"(static_cast(cached_coords_o[number<3>{}] * sizeof(ODataType))), + [v_os_o4]"v"(static_cast(cached_coords_o[number<4>{}] * sizeof(ODataType))), + [v_os_o5]"v"(static_cast(cached_coords_o[number<5>{}] * sizeof(ODataType))), + [v_os_o6]"v"(static_cast(cached_coords_o[number<6>{}] * sizeof(ODataType))), + [v_os_o7]"v"(static_cast(cached_coords_o[number<7>{}] * sizeof(ODataType))), + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + [v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + + [s_tile_os_o]"s"(tile_stride_o_bytes), + [s_tile_os_b]"s"(tile_stride_b_bytes), + [scale_0]"v"(s0), + [scale_1]"v"(s1), + [v_nan_lo]"v"(nan_lo), + 
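+            // v_nan_lo/v_nan_hi feed the bf16 pack in _UK_PK_CVT_: adding v_nan_lo+1
+            // (0x8000) realizes the round-to-nearest-away float->bf16 conversion, and
+            // v_cndmask substitutes the 0x7fff0000 NaN pattern whenever v_cmp_u_f32
+            // flags an unordered input. The s_execflag_* values below are per-store
+            // exec masks: the micro-kernel saves exec once, then applies one mask
+            // before each global_atomic add so lanes whose output row is not valid
+            // are skipped.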
[v_nan_hi]"v"(nan_hi), + [s_execflag_0]"s"(o_flags[number<0>{}]), + [s_execflag_1]"s"(o_flags[number<1>{}]), + [s_execflag_2]"s"(o_flags[number<2>{}]), + [s_execflag_3]"s"(o_flags[number<3>{}]), + [s_execflag_4]"s"(o_flags[number<4>{}]), + [s_execflag_5]"s"(o_flags[number<5>{}]), + [s_execflag_6]"s"(o_flags[number<6>{}]), + [s_execflag_7]"s"(o_flags[number<7>{}]) + : + "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86", + "s36", "s37", + "v50", "v54", "v55", + "v64","v65","v66","v67","v68","v69","v70","v71", + "v72","v73","v74","v75","v76","v77","v78","v79", + "v80","v81","v82","v83","v84","v85","v86","v87", + "v88","v89","v90","v91","v92","v93","v94","v95", + "v128", "v129", "v130", "v131", + "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139", + "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147", + "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155", + "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163", + "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171", + "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179", + "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187", + "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195", + "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203", + "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211", + "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219", + "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227", + "v228", "v229", "v230", "v231", 
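+            // v128..v255 are clobbered because the micro-kernel preloads the whole A
+            // tile from LDS into v[128:255] with the ds_read_b64 burst up front and
+            // keeps it resident for the entire N loop; only B is refetched per tile.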
"v232", "v233", "v234", "v235", + "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243", + "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251", + "v252", "v253", "v254", "v255" + ); +#pragma clang diagnostic pop + // clang-format on + } +}; + +struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base +{ + using BDataType = bf16_t; + using ODataType = bf16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! + // template + template + CK_TILE_DEVICE auto + operator()(const BRes& res_b, + const BCoords& cached_coords_b, + const ORes& res_o, + const OCoords& cached_coords_o, + const OFlags& o_flags, // this should be in sgpr + CK_TILE_LDS_ADDR void* smem, + index_t n, // loop along n dim + const ScaleTensor& scale_, + index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust + index_t tile_offset_o) + { + static_assert(BCoords::size() == 8); // 8 + static_assert(OCoords::size() == 8); + + const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType); + const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType); + + static_assert(ScaleTensor::size() == 2); + float s0 = scale_[number<0>{}]; + float s1 = scale_[number<1>{}]; + + index_t loop_cnt = n / Block_N; + + register float v_c0 asm("v64"); + register float v_c1 asm("v65"); + register float v_c2 asm("v66"); + register float v_c3 asm("v67"); + register float v_c4 asm("v68"); + register float v_c5 asm("v69"); + register float v_c6 asm("v70"); + register float v_c7 asm("v71"); + register float v_c8 asm("v72"); + register float v_c9 asm("v73"); + register float v_c10 asm("v74"); + register float v_c11 asm("v75"); + register float v_c12 asm("v76"); + register float v_c13 asm("v77"); + register float v_c14 asm("v78"); + register float v_c15 asm("v79"); + register float v_c16 asm("v80"); + register float v_c17 asm("v81"); + register float v_c18 asm("v82"); + register float v_c19 asm("v83"); + register float v_c20 asm("v84"); + register float v_c21 asm("v85"); + register float v_c22 asm("v86"); + register float v_c23 asm("v87"); + register float v_c24 asm("v88"); + register float v_c25 asm("v89"); + register float v_c26 asm("v90"); + register float v_c27 asm("v91"); + register float v_c28 asm("v92"); + register float v_c29 asm("v93"); + register float v_c30 asm("v94"); + register float v_c31 asm("v95"); + int32_t nan_hi = 0x7fff0000; + int32_t nan_lo = 0x00007fff; + + // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4) + // every threads need 8xK in contiguous register + // ... 
and every wave need the same data + int lane_id = threadIdx.x % 64; + int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128; + sld_y_os *= 2; + + // y y p p p y + // reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4) + // but order is N0*M0*Nv + // in LDS we need store as + // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4) + // y y wave-id lid/16 lid%16 v + // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4 + int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4); + sfl_sst *= 2; + + // from LDS we need load as + // M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4) + // ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2) + // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4 + int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4; + sfl_sld *= 2; + + // B nr->kr + // clang-format off +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16 +#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + :[smem_]"+r"(smem), + [s_loop_cnt]"+s"(loop_cnt), + [c0]"+v" (v_c0), + [c1]"+v" (v_c1), + [c2]"+v" (v_c2), + [c3]"+v" (v_c3), + [c4]"+v" (v_c4), + [c5]"+v" (v_c5), + [c6]"+v" (v_c6), + [c7]"+v" (v_c7), + [c8]"+v" (v_c8), + [c9]"+v" (v_c9), + [c10]"+v"(v_c10), + [c11]"+v"(v_c11), + [c12]"+v"(v_c12), + [c13]"+v"(v_c13), + [c14]"+v"(v_c14), + [c15]"+v"(v_c15), + [c16]"+v"(v_c16), + [c17]"+v"(v_c17), + [c18]"+v"(v_c18), + [c19]"+v"(v_c19), + [c20]"+v"(v_c20), + [c21]"+v"(v_c21), + [c22]"+v"(v_c22), + [c23]"+v"(v_c23), + [c24]"+v"(v_c24), + [c25]"+v"(v_c25), + [c26]"+v"(v_c26), + [c27]"+v"(v_c27), + [c28]"+v"(v_c28), + [c29]"+v"(v_c29), + [c30]"+v"(v_c30), + [c31]"+v"(v_c31) + : + [sld_a_base]"n"(0), + [shfl_base]"n"(0), + [v_sld_y_os]"v"(sld_y_os), + [v_sfl_sld]"v"(sfl_sld), + [v_sfl_sst]"v"(sfl_sst), + [s_res_o0]"s"(res_o[0]), + [s_res_o1]"s"(res_o[1]), + //[s_res_o2]"s"(res_o[2]), + //[s_res_o3]"s"(res_o[3]), + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_o0]"v"(static_cast(cached_coords_o[number<0>{}] * sizeof(ODataType))), + [v_os_o1]"v"(static_cast(cached_coords_o[number<1>{}] * sizeof(ODataType))), + [v_os_o2]"v"(static_cast(cached_coords_o[number<2>{}] * sizeof(ODataType))), + [v_os_o3]"v"(static_cast(cached_coords_o[number<3>{}] * sizeof(ODataType))), + [v_os_o4]"v"(static_cast(cached_coords_o[number<4>{}] * sizeof(ODataType))), + [v_os_o5]"v"(static_cast(cached_coords_o[number<5>{}] * sizeof(ODataType))), + [v_os_o6]"v"(static_cast(cached_coords_o[number<6>{}] * sizeof(ODataType))), + [v_os_o7]"v"(static_cast(cached_coords_o[number<7>{}] * sizeof(ODataType))), + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + [v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + + [s_tile_os_o]"s"(tile_stride_o_bytes), + [s_tile_os_b]"s"(tile_stride_b_bytes), + [scale_0]"v"(s0), + [scale_1]"v"(s1), + [v_nan_lo]"v"(nan_lo), + 
[v_nan_hi]"v"(nan_hi), + [s_execflag_0]"s"(o_flags[number<0>{}]), + [s_execflag_1]"s"(o_flags[number<1>{}]), + [s_execflag_2]"s"(o_flags[number<2>{}]), + [s_execflag_3]"s"(o_flags[number<3>{}]), + [s_execflag_4]"s"(o_flags[number<4>{}]), + [s_execflag_5]"s"(o_flags[number<5>{}]), + [s_execflag_6]"s"(o_flags[number<6>{}]), + [s_execflag_7]"s"(o_flags[number<7>{}]) + : + "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86", + "s36", "s37", + "v50", "v54", "v55", + "v64","v65","v66","v67","v68","v69","v70","v71", + "v72","v73","v74","v75","v76","v77","v78","v79", + "v80","v81","v82","v83","v84","v85","v86","v87", + "v88","v89","v90","v91","v92","v93","v94","v95", + "v128", "v129", "v130", "v131", + "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139", + "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147", + "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155", + "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163", + "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171", + "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179", + "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187", + "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195", + "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203", + "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211", + "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219", + "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227", + "v228", "v229", "v230", "v231", 
"v232", "v233", "v234", "v235", + "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243", + "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251", + "v252", "v253", "v254", "v255" + ); +#pragma clang diagnostic pop + // clang-format on + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp b/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp new file mode 100644 index 0000000000..003335c0e7 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#define CK_TILE_FLATMM_UK_MFMA_FP16 0 +#define CK_TILE_FLATMM_UK_MFMA_BF16 1 +#define CK_TILE_FLATMM_UK_MFMA_INT8 2 +#define CK_TILE_FLATMM_UK_MFMA_FP8 3 +#define CK_TILE_FLATMM_UK_MFMA_BF8 4 diff --git a/include/ck_tile/ops/flatmm/block/uk/README.md b/include/ck_tile/ops/flatmm/block/uk/README.md new file mode 100644 index 0000000000..84fa132296 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/uk/README.md @@ -0,0 +1 @@ +the files under this folder should not be included directly! \ No newline at end of file diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc new file mode 100644 index 0000000000..8b57611f06 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc @@ -0,0 +1,613 @@ +#ifndef CK_TILE_FLATMM_UK_MFMA +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#endif + +#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16 +# define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16" + +# define _UK_PK_CVT_(x0_, x1_, y_) \ + " v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \ + " v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \ + " v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \ + " v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \ + " v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \ + " v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \ + " v_perm_b32 " y_ ", v55, v54, s52 \n" + +# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16" + +#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16 +#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16" + +# define _UK_PK_CVT_(x0_, x1_, y_) \ + " v_cvt_f16_f32 v54, " x0_ " \n" \ + " v_cvt_f16_f32 v55, " x1_ " \n" \ + " v_pack_b32_f16 " y_ ", v54, v55 \n" + +# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16" + +#endif + + +";-------------------------------------------------------------\n" +" s_mov_b32 s52, 0x07060302 ; v_perm\n" +" s_mov_b64 s[38:39], exec ; save current exec\n" +" s_mov_b32 s8, %[s_res_o0] \n" +" s_mov_b32 s9, %[s_res_o1] \n" +" s_mov_b32 s12, %[s_res_b0] \n" +" s_mov_b32 s13, %[s_res_b1] \n" +" s_mov_b32 s14, %[s_res_b2] \n" +" s_mov_b32 s15, %[s_res_b3] \n" +" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n" +" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n" +" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n" +" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n" +" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n" +" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n" +" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n" +" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n" +" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] 
\n" +" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n" +" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n" +" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n" +" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n" +" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n" +" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n" +" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n" +" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n" +" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n" +" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n" +" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n" +" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n" +" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n" +" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n" +" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n" +" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n" +" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n" +" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n" +" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n" +" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n" +" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n" +" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n" +" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n" +" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n" +" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n" +" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n" +" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n" +" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n" +" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n" +" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n" +" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n" +" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n" +" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n" +" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n" +" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n" +" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n" +" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n" +" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n" +" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n" +" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n" +" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n" +" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n" +" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n" +" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n" +" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n" +" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n" +" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n" +" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n" +" 
ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n" +" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n" +" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n" +" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n" +" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n" +" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n" +" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n" +" s_waitcnt 0 \n" +" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" +" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n" +" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" +" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" +" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" +" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +" s_cselect_b32 s86, %[s_tile_os_b], 0 \n" +" s_add_u32 s12, s86, s12 \n" +" s_addc_u32 s13, 0, s13 \n" +" s_waitcnt 0 \n" +"L_start%=: \n" +" s_waitcnt vmcnt(32) \n" +" s_barrier \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n" +" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], 
%[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n" +" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n" +" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n" +" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n" +" s_waitcnt vmcnt(32) \n" 
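+// pipelining note: s_waitcnt vmcnt(32) waits only until at most 32 buffer_loads are
+// still outstanding, so the mfma cluster below can consume the acc group that has
+// already landed while the prefetch of the next group is still in flight.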
+_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 
acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], 
acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n" +" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n" +" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n" +" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] \n" 
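+// after the remaining mfma lines this pass drains: partials are scaled by
+// scale_0/scale_1, packed to 16-bit via _UK_PK_CVT_, staged through LDS
+// (ds_write_b64/ds_read_b32) to regroup lanes, then flushed with the exec-masked
+// global_atomic adds; this is the "stream update output along N" step.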
+_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], %[c13], %[c14], %[c15]] \n" +" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], %[c13], %[c14], %[c15]] \n" +_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], %[c15]]\n" +" v_mul_f32 %[c0], %[scale_0], %[c0] \n" +" v_mul_f32 %[c1], %[scale_0], %[c1] \n" +" v_mul_f32 %[c2], %[scale_0], %[c2] \n" +" v_mul_f32 %[c3], %[scale_0], %[c3] \n" +" v_mul_f32 %[c4], %[scale_1], %[c4] \n" +" v_mul_f32 %[c5], %[scale_1], %[c5] \n" +" v_mul_f32 %[c6], %[scale_1], %[c6] \n" +" v_mul_f32 %[c7], %[scale_1], %[c7] \n" +" v_mul_f32 %[c8], %[scale_0], %[c8] \n" +" v_mul_f32 %[c9], %[scale_0], %[c9] \n" +" v_mul_f32 %[c10], %[scale_0], %[c10] \n" +" v_mul_f32 %[c11], %[scale_0], %[c11] \n" +" v_mul_f32 %[c12], %[scale_1], %[c12] \n" +" v_mul_f32 %[c13], %[scale_1], %[c13] \n" +" v_mul_f32 %[c14], %[scale_1], %[c14] \n" +" v_mul_f32 %[c15], %[scale_1], %[c15] \n" +_UK_PK_CVT_("%[c0]", "%[c1]", "%[c0]") +_UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]") +_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]") +_UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]") +_UK_PK_CVT_("%[c8]", "%[c9]", "%[c4]") +_UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]") +_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]") +_UK_PK_CVT_("%[c14]", "%[c15]", "%[c7]") +" ;------------------------------ \n" +" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] \n" +" s_waitcnt lgkmcnt(0) \n" +" s_barrier \n" +" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] \n" +" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] \n" +" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] \n" +" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] \n" +" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] \n" +" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] \n" +" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n" +" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n" +" s_waitcnt lgkmcnt(0) \n" +" s_mov_b64 exec, %[s_execflag_0] \n" +_UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_1] \n" +_UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_2] \n" +_UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_3] \n" +_UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n" +" 
s_mov_b64 exec, %[s_execflag_4] \n" +_UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_5] \n" +_UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_6] \n" +_UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_7] \n" +_UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n" +" s_mov_b64 exec, s[38:39] \n" +" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n" +" s_cmp_gt_i32 %[s_loop_cnt] 0 \n" +" s_cbranch_scc0 L_end%= \n" +" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +" s_cselect_b32 s86, %[s_tile_os_b], 0 \n" +" s_add_u32 s12, s86, s12 \n" +" s_addc_u32 s13, 0, s13 \n" +" s_add_u32 s8, %[s_tile_os_o], s8 \n" +" s_addc_u32 s9, 0, s9 \n" +" s_waitcnt vmcnt(32) \n" +" s_barrier \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n" +" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n" +" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n" +" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " 
[%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n" +" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]] \n" 
+_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[208:209], 
v[160:161], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]] \n" +" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]] \n" +" buffer_load_dwordx4 
acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]] \n" +" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]] \n" +" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]] \n" +_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]]\n" +" v_mul_f32 %[c16], %[scale_0], %[c16] \n" +" v_mul_f32 %[c17], %[scale_0], %[c17] \n" +" v_mul_f32 %[c18], %[scale_0], %[c18] \n" +" v_mul_f32 %[c19], %[scale_0], %[c19] \n" +" v_mul_f32 %[c20], %[scale_1], %[c20] \n" +" v_mul_f32 %[c21], %[scale_1], %[c21] \n" +" v_mul_f32 %[c22], %[scale_1], %[c22] \n" +" v_mul_f32 %[c23], %[scale_1], %[c23] \n" +" v_mul_f32 %[c24], %[scale_0], %[c24] \n" +" v_mul_f32 %[c25], %[scale_0], %[c25] \n" +" v_mul_f32 %[c26], %[scale_0], %[c26] \n" +" v_mul_f32 %[c27], %[scale_0], %[c27] \n" +" v_mul_f32 %[c28], %[scale_1], %[c28] \n" +" v_mul_f32 %[c29], %[scale_1], %[c29] \n" +" v_mul_f32 %[c30], %[scale_1], %[c30] \n" +" v_mul_f32 %[c31], %[scale_1], %[c31] \n" + +_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]") +_UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]") +_UK_PK_CVT_("%[c20]", "%[c21]", "%[c18]") +_UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]") +_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]") +_UK_PK_CVT_("%[c26]", "%[c27]", "%[c21]") +_UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]") +_UK_PK_CVT_("%[c30]", "%[c31]", "%[c23]") + +" ;------------------------------ \n" +" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], 
[%[c18],%[c19]] offset:4352 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n" +" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n" +" s_waitcnt lgkmcnt(0) \n" +" s_barrier \n" +" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n" +" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n" +" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n" +" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n" +" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n" +" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n" +" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n" +" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n" +" s_waitcnt lgkmcnt(0) \n" +" s_mov_b64 exec, %[s_execflag_0] \n" +_UK_ATOMIC_ADD_ " %[v_os_o0], %[c16], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_1] \n" +_UK_ATOMIC_ADD_ " %[v_os_o1], %[c17], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_2] \n" +_UK_ATOMIC_ADD_ " %[v_os_o2], %[c18], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_3] \n" +_UK_ATOMIC_ADD_ " %[v_os_o3], %[c19], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_4] \n" +_UK_ATOMIC_ADD_ " %[v_os_o4], %[c20], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_5] \n" +_UK_ATOMIC_ADD_ " %[v_os_o5], %[c21], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_6] \n" +_UK_ATOMIC_ADD_ " %[v_os_o6], %[c22], s[8:9] \n" +" s_mov_b64 exec, %[s_execflag_7] \n" +_UK_ATOMIC_ADD_ " %[v_os_o7], %[c23], s[8:9] \n" +" s_mov_b64 exec, s[38:39] \n" +" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n" +" s_cmp_gt_i32 %[s_loop_cnt] 0 \n" +" s_cbranch_scc0 L_end%= \n" +" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +" s_cselect_b32 s86, %[s_tile_os_b], 0 \n" +" s_add_u32 s12, s86, s12 \n" +" s_addc_u32 s13, 0, s13 \n" +" s_add_u32 s8, %[s_tile_os_o], s8 \n" +" s_addc_u32 s9, 0, s9 \n" +" s_branch L_start%= \n" +"L_end%=: \n" + +#undef _UK_MFMA_ +#undef _UK_PK_CVT_ +#undef _UK_ATOMIC_ADD_ diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc new file mode 100644 index 0000000000..a34a21d39f --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc @@ -0,0 +1,516 @@ +#ifndef CK_TILE_FLATMM_UK_MFMA +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#endif + +#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16 +#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16" +#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16 +#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16" +#endif + +"s_mov_b32 s16, %[s_res_a0] \n" +"s_mov_b32 s17, %[s_res_a1] \n" +"s_mov_b32 s18, %[s_res_a2] \n" +"s_mov_b32 s19, %[s_res_a3] \n" +"s_mov_b32 s20, %[s_res_b0] \n" +"s_mov_b32 s21, %[s_res_b1] \n" +"s_mov_b32 s22, %[s_res_b2] \n" +"s_mov_b32 s23, %[s_res_b3] \n" +// "s_nop 4\n" +"; -- prefetch A0\n" +"s_add_u32 m0, 0, %[s_m0_init] \n" +"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a5], 
s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n" +"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n" +"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n" +"s_add_u32 s16, s86, s16 ; move a with cond \n" +"s_addc_u32 s17, 0, s17 ; move a with cond \n" +"; -- prefetch A1\n" +"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" +"s_add_u32 m0, %[s_size_per_issue], m0 \n" +"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" +"s_add_u32 m0, 0, %[s_m0_init] \n" +"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" +"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n" +"s_add_u32 s16, s86, s16 ; move a with cond \n" +"s_addc_u32 s17, 0, s17 ; move a with cond \n" +"; -- prefetch B0\n" +"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" 
+"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" +"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" +"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" +"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" +"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" +"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n" +"s_add_u32 s20, s86, s20 ; move b with cond \n" +"s_addc_u32 s21, 0, s21 ; move b with cond \n" +"s_waitcnt vmcnt(40) \n" +"s_barrier \n" +"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride +"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n" +"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n" +"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n" +"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n" +"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n" +"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n" +"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n" +"L_start%=: \n" +" s_waitcnt vmcnt(24) & lgkmcnt(0) \n" +" s_barrier \n" +_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n" +" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n" +" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n" +" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n" +" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n" +" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n" +" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n" +" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n" +" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n" +" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n" +" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" +" s_add_u32 m0, 
%[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n" +" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n" +" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n" +" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n" +" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n" +" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n" +" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n" +" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n" +" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n" +_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n" +" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n" +" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n" +_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n" +" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n" +" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n" +_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n" +" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n" +" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n" +_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n" +" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n" +" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n" +_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n" +" buffer_load_dwordx4 acc[180:183], 
%[v_os_b3], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n" +" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n" +_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n" +" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n" +" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n" +_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n" +" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n" +" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n" +" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n" +" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n" +" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n" +" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n" +" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n" +" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n" +" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n" +" 
buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n" +" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n" +" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n" +" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n" +" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n" +" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n" +" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n" +" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n" +" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n" +_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n" +" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n" +" s_cmp_gt_i32 %[s_loop_cnt] 0 \n" +" s_cbranch_scc0 L_end%= \n" +" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" +" s_cselect_b32 s86, %[s_tile_os_a], 0 \n" +" s_add_u32 s16, s86, s16 \n" +" s_addc_u32 s17, 0, s17 \n" +" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +" s_cselect_b32 s86, %[s_tile_os_b], 0 \n" +" s_add_u32 s20, s86, s20 \n" +" s_addc_u32 s21, 0, s21 \n" +" ;------------------------------------------ \n" +" s_waitcnt vmcnt(24) & lgkmcnt(0) \n" +" s_barrier \n" +_UK_MFMA_ " %[v_acc_0], 
acc[128:129], v[96:97], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n" +" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n" +" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n" +" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n" +_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n" +" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n" +" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n" +" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n" +" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n" +_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n" +" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n" +" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n" +" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n" +" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n" +_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n" +" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n" +" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n" +" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" +" s_add_u32 m0, %[s_size_per_issue], m0 \n" +_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n" +" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n" +_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n" +" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" +" s_add_u32 m0, 0, %[s_m0_init] \n" +" s_waitcnt vmcnt(32) \n" 
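+" ; mirror half of the acc ping-pong: the MFMAs consume acc[128:255] prefetched last pass while acc[0:127] is refilled \n"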
+_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n" +" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n" +" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" +_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n" +" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n" +_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n" +" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n" +_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n" +" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n" +" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n" +_UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n" +" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n" +_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n" +" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n" +_UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n" +" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n" +" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n" +_UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n" +" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n" +_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n" +" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n" +_UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n" +" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n" +" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n" +_UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n" +" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n" +_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n" +" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n" +" buffer_load_dwordx4 acc[64:67], %[v_os_b4], 
s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n" +" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n" +_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n" +" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n" +" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n" +_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n" +" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n" +" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n" +_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n" +" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n" +" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n" +_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n" +" s_waitcnt vmcnt(32) \n" +_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n" +" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n" +" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n" +_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n" +" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_13], 
acc[228:229], v[116:117], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n" +" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n" +_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n" +" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" +_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n" +" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" +_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n" +_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n" +" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" +_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n" +" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" +_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n" +_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n" +" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n" +" s_cmp_gt_i32 %[s_loop_cnt] 0 \n" +" s_cbranch_scc0 L_end%= \n" +" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" +" s_cselect_b32 s86, %[s_tile_os_a], 0 \n" +" s_add_u32 s16, s86, s16 \n" +" s_addc_u32 s17, 0, s17 \n" +" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" +" s_cselect_b32 s86, %[s_tile_os_b], 0 \n" +" s_add_u32 s20, s86, s20 \n" +" s_addc_u32 s21, 0, s21 \n" +" s_branch L_start%= \n" +"L_end%=: \n" +" s_nop 2 \n" + +#undef _UK_MFMA_ diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 10bb01168f..173887513e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -331,7 +331,8 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -355,6 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), k_dram_window, + number<-1>{}, k_oob_ck, k_pre_np); if constexpr(i_k0 < k0_loops - 1) @@ -386,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_s_barrier(); const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = 
load_tile(v_dram_window, bool_constant{}); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0(s_acc, @@ -514,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync move_tile_window( v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } __builtin_amdgcn_sched_barrier(0); @@ -618,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - v_buf = load_tile(v_dram_window, bool_constant{}); // load next v_buf + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf } block_sync_lds(); gemm_1(o_acc, @@ -665,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw( - k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index b74607f061..d23af0af8d 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -3,7 +3,15 @@ #pragma once +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp new file mode 100644 index 0000000000..2d25d44f3c --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -0,0 +1,421 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include +#include + +// clang-format off +// [indexing implementation-1] +// using M_a as constexpr block_size to partition all tokens into different slices +// each slice map to one expert, and one expert can have multiple slices +// e.g. 
num_experts = 6, topk=3, M_a = 4, input_tokens = 5
+// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
+//                              tok-0      tok-1      tok-2      tok-3      tok-4
+// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
+//
+// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
+// (only for reference)     exp-0   exp-1     exp-2       exp-3      exp-4    exp-5
+// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
+//
+// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// * this may overestimate the real size, since the actual token count is only known on the GPU
+//
+// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
+//                        |- exp-0 -|- exp-1 -|- exp-2 -|-       exp-3       -|- exp-4 -|- exp-5 -|
+// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
+//
+// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
+//
+// * Note on token_id_per_expert/sorted_token_ids_ptr data:
+// currently we do not have topk information in the data of token_id_per_expert/sorted_token_ids_ptr.
+// In some cases (like smooth-quant) we need the topk information to index into the token quant data
+// coming from different expert smooth-quant scales. So we modify the number stored inside
+// token_id_per_expert/sorted_token_ids_ptr (see the small sketch after this comment block):
+//
+//      32bit    0........23 24.....31 bit
+//     (data) -> (token_id  | topk_id)
+// low 24 bits are the token id, top 8 bits are the topk id
+//
+// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
+// the input scale for token is [topk, token, 1], the smooth-quant scale for the first gemm is [expert, interm_dim]
+//
+// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
+// * length is (max_num_tokens_padded + block_size - 1) / block_size
+//
+// num_tokens_post_padded_ptr : [28]
+// num_sorted_tiles_ptr : [7]
+//
+// * different from vLLM
+// 1) token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
+// 2) need sorted_weight_ptr
+// 3) use num_sorted_tiles_ptr, already divided by M_a
+//
+// * below used for indexing
+// 1) sorted_token_ids_ptr [max_num_tokens_padded]
+// 2) sorted_weight_ptr
+// 3) sorted_expert_ids_ptr
+// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
+//
+// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
+//
+// [indexing implementation-2]
+// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
+//                              tok-0      tok-1      tok-2      tok-3      tok-4
+// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
+//
+// we generate the original row/col id as
+// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
+// let x be one element of the above, then:
+// topk_row_id(token_id)  = x % num_tokens(5)
+// topk_col_id(expert_id) = x / num_tokens
+// topk_row_id/col_id can be used to access the original topk_ids/topk_weight
+//
+// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
+// (only for reference)     exp-0   exp-1     exp-2       exp-3      exp-4    exp-5
+// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
+//
+// we can get permuted_rc_ids:
+// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
+//
+// clang-format on
+//
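+// A minimal sketch of the (token_id | topk_id) packing described above, for
+// illustration only (not part of this kernel; the helper names are hypothetical):
+//
+//   uint32_t pack_id(uint32_t token_id, uint32_t topk_id)
+//   {
+//       return (topk_id << 24) | (token_id & 0x00ffffff); // topk in top 8 bits
+//   }
+//   uint32_t unpack_token_id(uint32_t x) { return x & 0x00ffffff; } // low 24 bits
+//   uint32_t unpack_topk_id(uint32_t x)  { return x >> 24; }        // top 8 bits
+//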
(TP slice this) +// e: num experts +// if doing pre-shuffle +// nr : n / Block_Nr +// kr : k / Block_Kr +// w : flattened 1d wave buffer +struct FusedMoeGemmHostArgs +{ + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + void* o_ptr; // [m, k], output token + + const void* sorted_token_ids_ptr; // [max_num_tokens_padded] + const void* sorted_weight_ptr; // [max_num_tokens_padded] + const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size] + const void* num_sorted_tiles_ptr; // [1] + + index_t hidden_size; // k + index_t intermediate_size; // n / TP, for Gate; if Gate+Up, Down needs to divide this by 2 + index_t num_tokens; // input number of tokens for current iteration + index_t num_experts; // number of groups + index_t topk; // need this? + + index_t stride_token; // for input/output, stride for each row, should be >= hidden_size +}; + +// This is a scatter/gather b2b group-gemm +template <typename Partitioner_, typename Pipeline_, typename Epilogue_> +struct FusedMoeGemmKernel +{ + using Partitioner = remove_cvref_t<Partitioner_>; + using Pipeline = remove_cvref_t<Pipeline_>; + using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used + // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + // static_assert(kBlockPerCu > 0); + + using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape + static constexpr index_t BlockSize_ = BlockShape::BlockSize; + + using ADataType = typename Pipeline::Problem::ADataType; + using GDataType = typename Pipeline::Problem::GDataType; + using DDataType = typename Pipeline::Problem::DDataType; + using AccDataType = typename Pipeline::Problem::AccDataType; + using ODataType = typename Pipeline::Problem::ODataType; + using AScaleDataType = typename Pipeline::Problem::AScaleDataType; + using GScaleDataType = typename Pipeline::Problem::GScaleDataType; + using DScaleDataType = typename Pipeline::Problem::DScaleDataType; + using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType; + using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType; + using IndexDataType = typename Pipeline::Problem::IndexDataType; + using YDataType = typename Pipeline::Problem::YDataType; + + using Traits = typename Pipeline::Problem::Traits; + static constexpr bool UseUK = true; + + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + + // clang-format off + template <typename T> struct t2s; + template <> struct t2s<float> { static constexpr const char * name = "fp32"; }; + template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; }; + template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; }; + template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; }; + template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; }; + template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { +#define _SS_ std::string +#define _TS_ std::to_string + // clang-format off + using S_ = BlockShape; + + auto prec_str = [&] () {
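+ // builds a short precision tag from the data types: e.g. "bf16", or "bf16_int8" + // when the A and G data types differ (illustrative examples, not an exhaustive list)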
std::string base_str = _SS_(t2s<ADataType>::name); + if (!std::is_same_v<ADataType, GDataType>) { + base_str += _SS_("_") + _SS_(t2s<GDataType>::name); + } + return base_str; + }(); + + return _SS_("fused_moe_") + _SS_(prec_str) + "_" + + _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" + + _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" + + _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name); +#undef _SS_ +#undef _TS_ + // clang-format on + } + + struct FusedMoeGemmKargs + { + const void* a_ptr; // [m, k], input token + const void* a_scale_ptr; // [m, 1], token scale + const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) + const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) + const void* g_scale_ptr; // [e, 1, n], gate(up) scale + const void* d_scale_ptr; // [e, 1, k], down scale + const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input + void* o_ptr; // [m, k], output token + + const void* sorted_token_ids_ptr; + const void* sorted_weight_ptr; + const void* sorted_expert_ids_ptr; + const void* num_sorted_tiles_ptr; + + index_t hidden_size; // k + index_t intermediate_size; // n / TP, for Gate; if Gate+Up, Down needs to divide this by 2 + index_t num_tokens; // input number of tokens for current iteration + index_t num_experts; // number of groups + index_t topk; // need this? + + index_t stride_token; // for input/output, stride for each row, should be >= hidden_size + }; + + // TODO: switch karg based on + using Kargs = FusedMoeGemmKargs; + using Hargs = FusedMoeGemmHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + // TODO: hargs/kargs are not guaranteed to be the same + return bit_cast<Kargs>(hargs); + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + constexpr index_t block_m = BlockShape::Block_M0; + int max_num_tokens_padded = + hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk; + // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded); + return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if constexpr(UseUK) + { + __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; + IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( + *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); + + num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0; + + const auto [sorted_tile_id, intermediate_tile_id] = + Partitioner{}(num_sorted_tiles, kargs.intermediate_size); + // if(threadIdx.x == 0) + // printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d), + // intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x), + // static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >= + // num_sorted_tiles? 1 : 0, intermediate_tile_id); + if(sorted_tile_id >= num_sorted_tiles) + return; + + Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id); + } + else + { + // allocate LDS + // __shared__ char smem_ptr[GetSmemSize()]; + IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( + *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr)); + constexpr index_t hidden_radio_0 = IsGateOnly ?
1 : 2; + + index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0; + index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0; + index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0 + index_t kr_1 = + kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0 + + index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size; + index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size; + + __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; + + // note this is in unit of tile, need multiple tile size to get the index + const auto [sorted_tile_id, intermediate_tile_id] = + Partitioner{}(num_sorted_tiles, kargs.intermediate_size); + if(sorted_tile_id >= num_sorted_tiles) + return; + + const IndexDataType expert_id = + __builtin_amdgcn_readfirstlane(reinterpret_cast( + kargs.sorted_expert_ids_ptr)[sorted_tile_id]); + + // index along intermediate_size + // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id * + // BlockShape::Block_N0); + index_t interm_idx_nr = + __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0); + + const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col] + const auto sorted_token_id = + a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0; + + index_t token_id = + reinterpret_cast(kargs.sorted_token_ids_ptr)[sorted_token_id]; + auto topk_weight = reinterpret_cast( + kargs.sorted_weight_ptr)[sorted_token_id]; + + const auto a_window = [&]() { + // A is already pre-padded in previous kernel + const ADataType* a_ptr = reinterpret_cast(kargs.a_ptr); + const auto a_view_ = make_naive_tensor_view( + a_ptr, + make_tuple(kargs.num_tokens, kargs.hidden_size), + make_tuple(kargs.stride_token, 1), + number{}, + number<1>{}); + + // gather is here use indexing transform + const auto a_gather_view_ = transform_tensor_view( + a_view_, + make_tuple(make_indexing_transform(kargs.num_tokens, token_id), + make_pass_through_transform(kargs.hidden_size)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto a_window_ = make_tile_window( + a_gather_view_, + make_tuple(number{}, number{}), + {0, 0}); + return a_window_; + }(); + + // TODO: gtile using NSub to have less register pressure + const auto g_window = [&]() { + const GDataType* g_ptr = reinterpret_cast(kargs.g_ptr) + + static_cast(expert_id) * expert_stride_0 + + interm_idx_nr * kr_0 * BlockShape::Block_W0; + const auto g_view_ = make_naive_tensor_view( + g_ptr, + make_tuple(nr_0, kr_0, number{}), + make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), + number{}, + number<1>{}); + const auto g_view_1_ = + pad_tensor_view(g_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); + + const auto g_window_ = make_tile_window(g_view_1_, + make_tuple(number{}, + number{}, + number{}), + {0, 0, 0}); + return g_window_; + }(); + + const auto d_window = [&]() { + const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(expert_id) * expert_stride_1 + + interm_idx_nr * BlockShape::Block_W1; + // note interm_idx_nr is along the gemm-k dim of 2nd gemm + + const auto d_view_ = make_naive_tensor_view( + d_ptr, + make_tuple(nr_1, kr_1, BlockShape::Block_W1), + make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1), + number{}, + number<1>{}); + const auto d_view_1_ = + pad_tensor_view(d_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); + + const auto d_window_ = 
make_tile_window(d_view_1_, + make_tuple(number{}, + number{}, + number{}), + {0, 0, 0}); + return d_window_; + }(); + + auto o_window = [&]() { + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr); + auto o_view_ = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.num_tokens, kargs.hidden_size), + make_tuple(kargs.stride_token, 1), + number{}, + number<1>{}); + + // gather is here + auto o_scatter_view_ = transform_tensor_view( + o_view_, + make_tuple(make_indexing_transform(kargs.num_tokens, token_id), + make_pass_through_transform(kargs.hidden_size)), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto o_window_ = make_tile_window( + o_scatter_view_, + make_tuple(number{}, number{}), + {0, 0}); + return o_window_; + }(); + + // do compute yeah + Pipeline{}(a_window, + g_window, + d_window, + o_window, + topk_weight, + smem, + kargs.hidden_size, + kargs.intermediate_size, + kargs.stride_token); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp new file mode 100644 index 0000000000..4f3f8bb7d3 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +/* +tensors: +1. act (A): input feature map +2. gate (G): B matrix for first gemm, output will do activation(Silu) +3. up (U): B matrix for first gemm +4. down (D): B matrix for second gemm + N1 + / \ + +----------+ | + | Down | | + x----------x | + hidden hidden K1 | | | + N0 N0 x----------x | + | +------x-----x------+------x-----x------+ | | | + dim | | Gate | | | Up | | | | | | + contiguous | | | | | | | | | | | + | | | | | | | | | | | + v +------x-----x------+------x-----x------+ +----------+ V + K0 | | | | | contiguous + / \ v v v v | + +---------+ +------x-----x------+------x-----x------+ | +M0 | A | | | | | | | | | + +---------+ +------x-----x------+------x-----x------+ | + ----------> | | | + contiguous | V V + | x-----x +----------+ + +------------> M1 | Y | ---------> | Out(O) | + ACT x-----x +----------+ + K1 = N0 dim + +* Note: Act could be Gelu/Silu/... 
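+* Note: as a rough illustration (the argument spelling below is a sketch, not +* copied from this PR), the flatmm micro-kernels later in this change pair +* tile shapes like gemm-0 32x512x128 and gemm-1 32x128x512 with 1x4x1 warps +* per block and 16x16x32 warp tiles, i.e. roughly: +* FusedMoeGemmShape<sequence<32, 512, 128>, sequence<1, 4, 1>, sequence<16, 16, 32>, +* sequence<32, 128, 512>, sequence<1, 4, 1>, sequence<16, 16, 32>>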
+* Note: some models do not have Up +*/ +template <typename BlockTile_0_, typename WarpPerBlock_0_, typename WarpTile_0_, typename BlockTile_1_, typename WarpPerBlock_1_, typename WarpTile_1_> +struct FusedMoeGemmShape +{ + using BlockTile_0 = remove_cvref_t<BlockTile_0_>; + using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>; + using WarpTile_0 = remove_cvref_t<WarpTile_0_>; + using BlockTile_1 = remove_cvref_t<BlockTile_1_>; + using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>; + using WarpTile_1 = remove_cvref_t<WarpTile_1_>; + + static constexpr index_t NumWarps = + reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{}); + + // TODO: we don't support rounding half a warp up to 1 warp here + static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{})); + + static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{}); + static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{}); + static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{}); + static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{}); + static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{}); + static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{}); + static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{}); + static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{}); + static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{}); + + static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0; + static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0; + static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0; + static_assert(Block_M0 % ThreadPerBlock_M0 == 0); + static_assert(Block_N0 % ThreadPerBlock_N0 == 0); + static_assert(Block_K0 % ThreadPerBlock_K0 == 0); + static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0; + static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0; + static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0; + + static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{}); + static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{}); + static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{}); + static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{}); + static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{}); + static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{}); + static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{}); + static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{}); + static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{}); + + static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1; + static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1; + static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1; + static_assert(Block_M1 % ThreadPerBlock_M1 == 0); + static_assert(Block_N1 % ThreadPerBlock_N1 == 0); + static_assert(Block_K1 % ThreadPerBlock_K1 == 0); + static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1; + static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1; + static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1; + + static constexpr index_t BlockSize = warpSize * NumWarps; + + // some asserts + static_assert(Block_M0 == Block_M1); + static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up + + // pre-shuffle tile size compute (assume only for B matrix) + // we flatten each wave tile to a 1d linear tensor (at model loading time) + // e.g.
originally we have Block_N*Block_K tile size, after pre-shuffle + // we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K, + // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K + static constexpr index_t Block_W0 = Warp_N0 * Warp_K0; + static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0; + static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0; + static constexpr index_t Block_W1 = Warp_N1 * Warp_K1; + static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1; + static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1; + + static_assert(Block_W0 == Block_W1); + // static_assert(Block_Nr0 == Block_Kr1); +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp new file mode 100644 index 0000000000..381edb650d --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct FusedMoeGemmTilePartitioner_Linear +{ + // FusedMoeGemmShape + using BlockShape = ck_tile::remove_cvref_t; + + static constexpr const char* name = "lin"; + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/, + ck_tile::index_t /*intermediate_size*/) + { + index_t i_n = blockIdx.x; + index_t i_m = blockIdx.y; + + return ck_tile::make_tuple(i_m, i_n); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size) + { + // TODO: this may need tuning + index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0); + index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0); + return dim3(ns, ms, 1); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp new file mode 100644 index 0000000000..e9577e2304 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp @@ -0,0 +1,651 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" + +namespace ck_tile { + +/* +This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight) +we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave) + + <----- gemm-N ------> + +----+----+----+----+ + | w0 | w1 | w2 | w3 | gemm-m + +----+----+----+----+ +*/ +template +struct FusedMoeGemmPipeline_FlatmmEx +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape + + using ADataType = typename Problem::ADataType; + using GDataType = typename Problem::GDataType; + using DDataType = typename Problem::DDataType; + using AccDataType = typename Problem::AccDataType; + using ODataType = typename Problem::ODataType; + using AScaleDataType = typename Problem::AScaleDataType; + using GScaleDataType = typename Problem::GScaleDataType; + using DScaleDataType = typename Problem::DScaleDataType; + using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType; + using TopkWeightDataType = typename Problem::TopkWeightDataType; + using IndexDataType = typename Problem::IndexDataType; + using YDataType = typename Problem::YDataType; + + using Traits = typename Problem::Traits; + + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + + static constexpr index_t kAlignmentA = Policy::template GetAlignment_A(); + static constexpr index_t kAlignmentG = Policy::template GetAlignment_G(); + static constexpr index_t kAlignmentD = Policy::template GetAlignment_D(); + static constexpr index_t kAlignmentO = Policy::template GetAlignment_O(); + + static constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + static constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + static constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + static constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + return 2; + } + }(); + + static constexpr const char* name = "fused_moe_flatmm"; + + // TODO: there are multiple buffers + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() + { + return Policy::template GetSmemSize_A(); + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetACoord() + { + constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A(); + const auto a_coord = a_dist.calculate_index(); + return a_coord; + } + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetOCoord() + { + constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution(); + const auto o_coord = o_dist.calculate_index(); + return o_coord; + } + + template + CK_TILE_DEVICE auto operator()(const AWindow& a_window_, + const GWindow& g_window_, + const DWindow& d_window_, + OWindow& o_window_, + TopkWeightDataType /*topk_weight*/, + 
CK_TILE_LDS_ADDR void* smem, + index_t hidden_size, + index_t intermediate_size) + { + _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\""); + constexpr auto NEG1 = number<-1>{}; + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + constexpr auto TRUE = bool_constant{}; + constexpr auto FALSE = bool_constant{}; + + CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); + CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast( + reinterpret_cast(smem) + + Policy::template GetSmemSize_A()); + + auto g_view = g_window_.get_bottom_tensor_view(); + + auto u_view = [&]() { + if constexpr(IsGateOnly) + { + return g_view; + } + else + { + index_t nr_0 = intermediate_size / BlockShape::Block_Nr0; + index_t kr_0 = hidden_size / BlockShape::Block_Kr0; + + const GDataType* g_ptr = + g_window_.get_bottom_tensor_view().get_buffer_view().p_data_; + const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number{}; + + const auto u_view_ = make_naive_tensor_view( + u_ptr, + make_tuple(nr_0, kr_0, number{}), + make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), + number{}, + number<1>{}); + const auto u_view_1_ = + pad_tensor_view(u_view_, + make_tuple(number{}, + number{}, + number{}), + sequence{}); + return u_view_1_; + } + }(); + + auto a_win = make_tile_window_linear( + a_window_, Policy::template MakeGlobalTileDistribution_A()); + auto g_win = + make_tile_window_linear(g_window_, + Policy::template MakeGlobalTileDistribution_G(), + sequence<0, 1, 1>{}); + auto d_win = + make_tile_window_linear(d_window_, + Policy::template MakeGlobalTileDistribution_D(), + sequence<0, 1, 1>{}); + auto o_win = make_tile_window_linear( + o_window_, Policy::template MakeGlobalTileDistribution_O()); + + using g_thread_type = decltype(load_tile(g_win)); + using d_thread_type = decltype(load_tile(d_win)); + + using WarpGemm0 = decltype(Policy::template GetWarpGemm0()); + using WarpGemm1 = decltype(Policy::template GetWarpGemm1()); + auto warp_gemm_0 = WarpGemm0{}; + auto warp_gemm_1 = WarpGemm1{}; + + // issues_warps_lanes + auto a_sst_win0 = + make_tile_window(make_tensor_view( + smem_0, Policy::template MakeLdsStoreDesc_A()), + Policy::template MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + + auto a_sst_win1 = + make_tile_window(make_tensor_view( + smem_1, Policy::template MakeLdsStoreDesc_A()), + Policy::template MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + // m*k + auto a_sld_win0 = [&]() { + using WG = WarpGemm0; + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_outer_dstr_enc, typename WG::AWarpDstrEncoding{}); + return make_tile_window_linear( + make_tensor_view( + smem_0, Policy::template MakeLdsLoadDesc_A()), + Policy::template MakeLdsLoadDesc_A().get_lengths(), + {0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + // m*k + auto a_sld_win1 = [&]() { + using WG = WarpGemm0; + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_outer_dstr_enc, typename WG::AWarpDstrEncoding{}); + return make_tile_window_linear( + make_tensor_view( + smem_1, Policy::template MakeLdsLoadDesc_A()), + Policy::template MakeLdsLoadDesc_A().get_lengths(), + 
{0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + auto bridge_sst_win = [&]() { + return make_tile_window( + make_tensor_view( + reinterpret_cast(smem), + Policy::template MakeBridgeLdsStoreDesc()), + Policy::template MakeBridgeLdsStoreDesc().get_lengths(), + {0, 0}); + }(); + + auto bridge_sld_win = [&]() { + return make_tile_window_linear( + make_tensor_view( + reinterpret_cast(smem), + Policy::template MakeBridgeLdsLoadDesc()), + Policy::template MakeBridgeLdsLoadDesc().get_lengths(), + {0, 0}, + Policy::template MakeYTileDistribution()); + }(); + + // also OK with C array, 2 register buffer + statically_indexed_array gs; + + constexpr auto issues_a = number{}; + constexpr auto issues_g = number{}; + // constexpr auto issues_d = number{}; + // constexpr auto issues_o = number{}; + constexpr auto issues_gemm0 = + number{}; + constexpr auto issues_gemm1 = + number{}; + // constexpr auto issues_sld_a = number{}; + + const index_t num_blocks_k0 = + (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0; + const index_t num_blocks_n1 = + (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1; + + using a_thread_type = decltype(load_tile(a_sld_win0)); + statically_indexed_array as; + + auto gld_a = [&]>( + auto& a_store_, auto i_access, PreNop = {}) + { + async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); + }; + auto move_a = [&]() { + move_tile_window(a_win, {number<0>{}, number{}}); + }; + auto sld_a = [&](auto& a_, auto& win_, auto i_access) { + load_tile_raw(a_, win_, i_access); + }; + + auto gld_g = [&]>( + auto& g_, auto i_access, PreNop = {}) + { + if constexpr(IsGateOnly) + { + // TODO: hack! + if constexpr(i_access.value == 0) + { + g_win.bottom_tensor_view_ = g_view; + } + else if constexpr(i_access.value == issues_g / 2) + { + g_win.bottom_tensor_view_ = u_view; + } + } + load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); + }; + auto move_g = [&]() { + move_tile_window(g_win, {number<0>{}, number{}, number<0>{}}); + }; + statically_indexed_array ds; + + auto gld_d = [&]>( + auto& d_, auto i_access, PreNop = {}) + { + load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); + }; + auto move_d = [&]() { + // d move along gemm-n + move_tile_window(d_win, {number{}, number<0>{}}); + }; + + auto atomic_add_o = [&]>( + auto& o_, auto i_access, PreNop = {}) + { + update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); + }; + + auto acc_0 = Policy::template MakeCBlockTile_Gemm0(); + auto acc_1s = generate_tuple( + [&](auto) { return Policy::template MakeCBlockTile_Gemm1(); }, number<2>{}); + + // clang-format off + auto gemm_0 = [&]> + (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) { + using WarpGemm = remove_cvref_t; + + constexpr auto repeat_sub = WarpGemm::get_num_of_access(); + constexpr auto repeat_m = BlockShape::Repeat_M0; + // constexpr auto repeat_n = BlockShape::Repeat_N0; + constexpr auto repeat_k = BlockShape::Repeat_K0; + // loop order n->m->k + constexpr auto i_sub = i_access % repeat_sub; + constexpr auto i_k = (i_access / repeat_sub) % repeat_k; + constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m; + constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + constexpr 
auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + AWarpTensor w_a; + w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + BWarpTensor w_b; + w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + CWarpTensor w_c; + w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + warp_gemm_0(w_c, w_a, w_b, number{}, PostNop{}); + + t_c.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + w_c.get_thread_buffer()); + }; + // clang-format on + + // clang-format off + auto gemm_1 = [&]> + (auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) { + using WarpGemm = remove_cvref_t; + + constexpr auto repeat_sub = WarpGemm::get_num_of_access(); + constexpr auto repeat_m = BlockShape::Repeat_M0; + // constexpr auto repeat_n = BlockShape::Repeat_N0; + constexpr auto repeat_k = BlockShape::Repeat_K0; + // loop order n->m->k + constexpr auto i_sub = i_access % repeat_sub; + constexpr auto i_k = (i_access / repeat_sub) % repeat_k; + constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m; + constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + AWarpTensor w_a; + w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + BWarpTensor w_b; + w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + CWarpTensor w_c; + w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + warp_gemm_1(w_c, w_a, w_b, number{}, PostNop{}); + + t_c.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + w_c.get_thread_buffer()); + 
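+ // at this point w_c has been folded back into the block-distributed + // accumulator; one warp-gemm micro-step of gemm-1 is complete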
}; + // clang-format on + _Pragma("clang diagnostic pop"); + + // this gemm pipeline is designed with the assumption that the buffer-load/ds_read issues can + // be hidden under mfma. In other words, the number of mfma issues is >= the number of memory + // issues. This is true if we pre-shuffle the B matrix and the A matrix is relatively small: we + // prefer pairing multiple mfma with one buffer-load of the B matrix, to get max throughput out + // of buffer_load. And by pre-shuffling we always pack to dwordx4 loads, each of which already + // extends over multiple mfma, but that is already consumed inside the warpgemm-impl. So how + // many extra mfma (that can reuse the B matrix) we get is indeed only affected by the M repeat. + auto pipeline_gemm0 = [&]() { + constexpr index_t total_loops = issues_gemm0; + constexpr auto sr = Policy::template GetSequencer_0(); + static_assert(sr.size() == total_loops); + + constexpr auto c_sld_a_0 = MAKE_SC(); + constexpr auto c_gld_a_0 = MAKE_SC(); + constexpr auto c_gld_b_0 = MAKE_SC(); + // compute buffer 0 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_0(acc_0, as[I0], gs[I0], i_issue); + constexpr index_t slot = sr.at(i_issue); + + if constexpr(slot & SLD_A) + sld_a(as[I1], a_sld_win1, number{}); + if constexpr(slot & GLD_A) + gld_a(a_sst_win0, number{}); + if constexpr(slot & GLD_B) + gld_g(gs[I0], number{}); + }); + move_g(); + move_a(); + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); + + constexpr auto c_sld_a_1 = MAKE_SC(); + constexpr auto c_gld_a_1 = MAKE_SC(); + constexpr auto c_gld_b_1 = MAKE_SC(); + + // compute buffer 1 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_0(acc_0, as[I1], gs[I1], i_issue); + constexpr index_t slot = sr.at(i_issue); + + if constexpr(slot & SLD_A) + sld_a(as[I0], a_sld_win0, number{}); + if constexpr(slot & GLD_A) + gld_a(a_sst_win1, number{}); + if constexpr(slot & GLD_B) + gld_g(gs[I1], number{}); + }); + move_g(); + move_a(); + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); + }; + + auto pipeline_gemm0_tail = [&]() { + constexpr index_t total_loops = issues_gemm0; + constexpr auto sr = Policy::template GetSequencer_0(); + static_assert(sr.size() == total_loops); + + constexpr auto c_gld_b_0 = MAKE_SC(); + + // compute buffer 0 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_0(acc_0, as[I0], gs[I0], i_issue); + constexpr index_t slot = sr.at(i_issue); + + if constexpr(slot & GLD_B) + gld_g(gs[I1], number{}); + }); + + block_sync_load_raw(issues_g); + sld_a(as[I1], a_sld_win1, NEG1); + + // compute buffer 1 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + constexpr auto last_nop = [&]() { + if constexpr(i_issue == (total_loops - 1)) + return TRUE; + else + return FALSE; + }(); + gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop + }); + }; + + auto y = Policy::template MakeYBlockTile(); + + auto pipeline_bridge = [&]() { + // cast to Y data + auto y_pre = cast_tile(acc_0); + store_tile(bridge_sst_win, y_pre); + clear_tile(acc_1s(I0)); + // wave_barrier(); + load_tile(y, bridge_sld_win); + clear_tile(acc_1s(I1)); + }; + + // note, the gemm-1 hot loop covers idx 1 to N-2 (of 0, 1, 2 ... N-1); the first and + // last iterations are peeled into pipeline_gemm1_head/pipeline_gemm1_tail below + auto pipeline_gemm1 = [&]() { + constexpr index_t total_loops = issues_gemm1; + constexpr auto sr = Policy::template GetSequencer_1(); + static_assert(sr.size() == total_loops); + + constexpr auto c_gld_b_0 = MAKE_SC(); + constexpr auto c_gst_o_0 = MAKE_SC(); + constexpr auto c_gld_b_1 = MAKE_SC(); + constexpr auto c_gst_o_1 = MAKE_SC(); + + // compute buffer 0 + static_for<0, total_loops, 1>{}([&](auto i_issue) {
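+ // one warp-gemm micro-step per issue slot; the sequencer mask decides + // whether a B-tile load (GLD_B) or an output atomic store (GST_O) is + // issued alongside this step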
gemm_1(acc_1s[I1], y, ds[I1], i_issue); + constexpr index_t slot = sr.at(i_issue); + if constexpr(slot & GLD_B) + gld_d(ds[I0], number{}); + + if constexpr(slot & GST_O) + { + auto out = cast_tile(acc_1s[I0]); + atomic_add_o(out, number{}); + } + }); + move_d(); + // move_o(); + + // compute buffer 1 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_1(acc_1s[I0], y, ds[I0], i_issue); + constexpr index_t slot = sr.at(i_issue); + if constexpr(slot & GLD_B) + gld_d(ds[I1], number{}); + + if constexpr(slot & GST_O) + { + auto out = cast_tile(acc_1s[I1]); + atomic_add_o(out, number{}); + } + }); + move_d(); + }; + + auto pipeline_gemm1_head = [&]() { + constexpr index_t total_loops = issues_gemm1; + constexpr auto sr = Policy::template GetSequencer_1(); + static_assert(sr.size() == total_loops); + + constexpr auto c_gld_b_0 = MAKE_SC(); + + // compute buffer 0 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_1(acc_1s[I0], y, ds[I0], i_issue); + constexpr index_t slot = sr.at(i_issue); + if constexpr(slot & GLD_B) + gld_d(ds[I1], number{}); + }); + move_d(); + }; + auto pipeline_gemm1_tail = [&]() { + constexpr index_t total_loops = issues_gemm1; + constexpr auto sr = Policy::template GetSequencer_1(); + static_assert(sr.size() == total_loops); + + constexpr auto c_gst_o_0 = MAKE_SC(); + + // compute buffer 1 + static_for<0, total_loops, 1>{}([&](auto i_issue) { + gemm_1(acc_1s[I1], y, ds[I1], i_issue); + + constexpr index_t slot = sr.at(i_issue); + if constexpr(slot & GST_O) + { + auto out = cast_tile(acc_1s[I0]); + atomic_add_o(out, number{}); + } + }); + { + auto out = cast_tile(acc_1s[I1]); + atomic_add_o(out, NEG1); + } + }; + + // start of pipeline + // clang-format off + gld_a(a_sst_win0, NEG1, TRUE); + gld_g(gs[I0], NEG1, TRUE); + move_a(); + move_g(); + clear_tile(acc_0); + + // preload for next round + gld_a(a_sst_win1, NEG1); + gld_g(gs[I1], NEG1); + + // make sure a,g loaded + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); + + // we manually unroll double buffer inside hot loop + const index_t iters_0 = (num_blocks_k0 - 2) / 2; + index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0; + while(i_0++ < iters_0) + { + pipeline_gemm0(); + } + pipeline_gemm0_tail(); + + pipeline_bridge(); + + const index_t iters_1 = (num_blocks_n1 - 2) / 2; + index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1; + pipeline_gemm1_head(); + while(i_1++ < iters_1) + { + pipeline_gemm1(); + } + pipeline_gemm1_tail(); + // clang-format on + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp new file mode 100644 index 0000000000..fea30f0297 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -0,0 +1,831 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct FusedMoeGemmPipelineFlatmmPolicy +{ + CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() + { + // TODO: always 1 dword + return 1; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A() + { + // using async + constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords(); + constexpr index_t data_bytes = sizeof(typename Problem::ADataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::GDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::DDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O() + { + if constexpr(Problem::Traits::OAtomic == 1) + { + // pack fp16/bf16 atomic + static_assert(sizeof(typename Problem::ODataType) == 2); + return 2; + } + else if constexpr(Problem::Traits::OAtomic == 2) + { + // fp32 atomic + return 1; + } + else + { + return 16 / sizeof(typename Problem::ODataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack() + { + // TODO: this is for 3d layout + return 16 / sizeof(remove_cvref_t); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A() + { + return GetSmemKPack(); + } + + // used for bridge LDS shuffle + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y() + { + // TODO: this should match mfma layout + return 16 / sizeof(typename Problem::YDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() + { + constexpr auto a_sld_desc = MakeLdsLoadDesc_A(); + constexpr auto a_sst_desc = MakeLdsStoreDesc_A(); + static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size()); + return a_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() + { + constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc(); + constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc(); + static_assert(bridge_sld_desc.get_element_space_size() == + bridge_sst_desc.get_element_space_size()); + return bridge_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t a_lds = GetSmemSize_A(); + constexpr index_t bridge_lds = GetSmemSize_Bridge(); + return max(a_lds, bridge_lds); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() < K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + 
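+ // e.g. with hypothetical numbers KPerBlock=256, fp16 data and 4-byte copies: + // K_vec=2, K_rem=128 > warp_size(64), so K is split across waves: K_lan=64, K_wav=2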
static_assert(K_wav <= NumWarps, "do not support threads repeating along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + // optimized version for async copy, not the same as the simple MxK dist (pay attention!) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() <= K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + static_assert(K_wav <= NumWarps, "do not support threads repeating along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + // NOTE: no swap, but hard to avoid LDS bank conflict + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + // NOTE: swapped for LDS load bank conflict free + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + // Note M_wave(num waves) is the fastest dim, different from the simple 2d + // distribution + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence, + sequence>, + tuple, sequence<3>>, + tuple, sequence<0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A() + { + constexpr index_t Block_M_ = Problem::BlockShape::Block_M0; + constexpr index_t Block_K_ = Problem::BlockShape::Block_K0; + constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps; + constexpr index_t Alignment_ = GetAlignment_A(); + return MakeGlobalTileDistribution_SimpleMxK_Async(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ?
1 : 2; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + // number{}.rrr(); + // number{}.eee(); + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + // using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK > NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + 
make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A() + { + // A async->LDS + // Note that, this descriptor is only to construct the layout inside LDS + // in real Gemm pipeline, ds_read may not follow this pattern + // (may follow that in tile_distribution) + // below code is almost the same as SmemStore dist, with difference: + // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc + // 2). return discriptor is in NxK 2d layout + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK >= NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + 
make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // KVector; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc() + { + constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0; + constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0; + constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0; + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + + constexpr index_t KPack = kABKPerLane; + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number{}), // n + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number<1>{}), // n + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto desc = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0() + { + using S_ = typename Problem::BlockShape; + // A is vgpr, B is agpr. But since we transposed, so also need swap this + // TODO: this is ugly + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0() + { + // this function return seq<...> used to identify gld/sld/valu... 
inside mfma sequence + // the purpose is to hide those instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1() + { + // this function returns a seq<...> used to identify gld/sld/valu... inside the mfma sequence + // the purpose is to hide those instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1() + { + using S_ = typename Problem::BlockShape; + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0() + { + using S_ = remove_cvref_t; + using WarpGemm =
remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // this is used as A matrix for 2nd gemm + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + + // TODO: all waves are along different N, but share the same M + constexpr auto y_outer_dstr_enc = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{}); + constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode); + return y_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile() + { + constexpr auto y_block_dstr = MakeYTileDistribution(); + auto y_block_tensor = + make_static_distributed_tensor(y_block_dstr); + return y_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetUK_0() + { + using S_ = typename Problem::BlockShape; + if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return Flatmm_32x512x128_1x4x1_16x16x32_BF16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return Flatmm_32x512x128_1x4x1_16x16x32_FP16{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetUK_1() + { + using S_ = typename Problem::BlockShape; + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v && + S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v && + S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{}; + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp 
b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp new file mode 100644 index 0000000000..a6f71eafac --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" + +namespace ck_tile { + +/* +This pipeline deals with a gemm (actually 2 gemms) where one operand is very small (the tokens) +and one is very big (the weights). We design the pipeline such that all waves are along the +gemm-N dim (gemm-M has only 1 wave): + + <----- gemm-N ------> + +----+----+----+----+ + | w0 | w1 | w2 | w3 | gemm-m + +----+----+----+----+ +*/ +template +struct FusedMoeGemmPipeline_FlatmmUk +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape + + using ADataType = typename Problem::ADataType; + using GDataType = typename Problem::GDataType; + using DDataType = typename Problem::DDataType; + using AccDataType = typename Problem::AccDataType; + using ODataType = typename Problem::ODataType; + using AScaleDataType = typename Problem::AScaleDataType; + using GScaleDataType = typename Problem::GScaleDataType; + using DScaleDataType = typename Problem::DScaleDataType; + using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType; + using TopkWeightDataType = typename Problem::TopkWeightDataType; + using IndexDataType = typename Problem::IndexDataType; + using YDataType = typename Problem::YDataType; + + using Traits = typename Problem::Traits; + + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + + static constexpr index_t kAlignmentA = Policy::template GetAlignment_A(); + static constexpr index_t kAlignmentG = Policy::template GetAlignment_G(); + static constexpr index_t kAlignmentD = Policy::template GetAlignment_D(); + static constexpr index_t kAlignmentO = Policy::template GetAlignment_O(); + + static constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + static constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + static constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + static constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + return 2; + } + }(); + + static constexpr const char* name = "flatmm_uk"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_0 = Policy::template GetUK_0().GetSmemSize(); + constexpr index_t smem_1 = Policy::template GetUK_1().GetSmemSize(); + constexpr index_t smem_bridge = + BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); + return max(smem_0, max(smem_1, smem_bridge)); + } + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetACoord() + { + constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A(); + const auto a_coord = a_dist.calculate_index(); + return a_coord; + 
} + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetOCoord() + { + constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution(); + const auto o_coord = o_dist.calculate_index(); + return o_coord; + } + + CK_TILE_DEVICE constexpr auto GetNumRowCoords_A() + { + constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; + constexpr index_t MLans = BlockShape::BlockSize / KLans; + constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; + + return MRepeat; + } + + // TODO: properly support scatter/gather + CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset) + { + constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; + constexpr index_t MLans = BlockShape::BlockSize / KLans; + constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; + + auto base_coord = threadIdx.x / KLans + base_offset; + + array coords; + static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; }); + + return coords; + } + + template + CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr) + { + constexpr index_t n_size = coords.size(); + + array row_ids; + static_for<0, n_size, 1>{}([&](auto i) { + row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans; + }); + + return row_ids; + } + + template + CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, + const TopkWeightDataType* sorted_weight_ptr) + { + constexpr index_t n_size = coords.size(); + + array w; + static_for<0, n_size, 1>{}([&](auto i) { + w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans; + }); + + return w; + } + + // TODO: this row id is before the shuffle atomic; need to use the acc distribution + CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset) + { + constexpr index_t MLanes = BlockShape::Warp_M1; + constexpr index_t Repeat_M = BlockShape::Repeat_M1; + + auto base_coord = threadIdx.x % MLanes + base_offset; + + array coords; + static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; }); + + return coords; + } + + template + CK_TILE_DEVICE auto operator()(const Karg& kargs, + CK_TILE_LDS_ADDR void* smem, + index_t sorted_tile_id, + index_t intermediate_tile_id) + { + constexpr index_t hidden_radio_0 = IsGateOnly ? 
1 : 2; + ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; + ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0; + + index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W + index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W + index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1; + index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1; + + const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( + reinterpret_cast(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); + index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; + index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; + + // nr*kr*w + index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane( + intermediate_tile_id * + BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W) + + index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane( + intermediate_tile_id * + BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W) + + auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0); + auto row_ids_a = GetRowID( + row_coords_a, reinterpret_cast(kargs.sorted_token_ids_ptr)); + auto a_coords = generate_tuple( + [&](auto i) { + return row_ids_a[i] * kargs.stride_token + + threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; + }, + number{}); + auto a_res = + make_wave_buffer_resource(reinterpret_cast(kargs.a_ptr), + kargs.num_tokens * kargs.stride_token * sizeof(ADataType)); + + auto g_win = [&]() { + const GDataType* g_ptr = reinterpret_cast(kargs.g_ptr) + + static_cast(expert_id) * expert_stride_0 + + interm_idx_nr0 * kr_0 * BlockShape::Block_W0; + auto g_view_ = make_naive_tensor_view( + g_ptr, + make_tuple(nr_0, kr_0, number{}), + make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), + number{}, + number<1>{}); + + auto g_window_ = make_tile_window_linear_raw( + g_view_, + make_tuple(number{}, + number{}, + number{}), + {0, 0, 0}, + Policy::template MakeGlobalTileDistribution_G(), + sequence<0, 1, 1>{}); + return g_window_; + }(); + + auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; + auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); }, + number{}); + + const auto d_win = [&]() { + const DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + + static_cast(expert_id) * expert_stride_1 + + interm_idx_kr1 * BlockShape::Block_W1; + // note interm_idx_nr0 is along the gemm-k dim of 2nd gemm + + const auto d_view_ = make_naive_tensor_view( + d_ptr, + make_tuple(nr_1, kr_1, BlockShape::Block_W1), + make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1), + number{}, + number<1>{}); + + const auto d_window_ = make_tile_window_linear_raw( + d_view_, + make_tuple(number{}, + number{}, + number{}), + {0, 0, 0}, + Policy::template MakeGlobalTileDistribution_D(), + sequence<0, 1, 1>{}); + return d_window_; + }(); + auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; + + // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255 + // block-k=512, block-n=128 + // wg |<----- W_ ----->| + // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue + // y p y y p p y + // 1 2 0(imm) + auto d_coords = [&]() { + constexpr index_t Nr_ = 2; + constexpr index_t Nw_ = 4; + constexpr index_t Kr0_ = 4; + constexpr index_t Kr1_ = 4; + constexpr index_t Kl_ = 4; + constexpr index_t Nl_ = 16; + constexpr index_t Kv_ = 8; + 
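// [editor's note] worked numbers for the constants above, for orientation: one
// waveflatten unit holds W_ = Kl_ * Nl_ * Kv_ = 4 * 16 * 8 = 512 elements (each
// lane reads a Kv_ = 8 contiguous chunk per issue), and each thread keeps
// num_offsets_ = Nr_ * Kr0_ = 2 * 4 = 8 base offsets. In base_os_ below,
// threadIdx.x % 64 selects the lane's chunk inside one unit, while
// threadIdx.x / 64 (the wave id) steps by shared_intermediate_size_1 * Nl_ to a
// different N slice.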
constexpr index_t W_ = Kl_ * Nl_ * Kv_; + constexpr index_t num_offsets_ = Nr_ * Kr0_; + index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) * + shared_intermediate_size_1 * + Nl_; // Kr0_ * Kr1_ * W_; + return generate_tuple( + [&](auto i) { + constexpr auto i_nr_ = number{}; + constexpr auto i_kr0_ = number{}; + + return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ + + base_os_; + }, + number{}); + }(); + + auto o_coords = generate_tuple( + [&](auto i) { + return row_ids_a[i] * kargs.stride_token + + threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO; + }, + number{}); + + auto o_flags = + generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); }, + number{}); + + auto bridge_sst_win = [&]() { + constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc(); + constexpr auto dist_ = Policy::template GetUK_0().MakeCBlockDist(); + return make_tile_window_linear(make_tensor_view( + reinterpret_cast(smem), desc_), + desc_.get_lengths(), + {0, 0}, + dist_); + }(); + auto o_res = + make_wave_buffer_resource(reinterpret_cast(kargs.o_ptr), + kargs.num_tokens * kargs.stride_token * sizeof(ODataType)); + + auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0); + auto w_scale = GetWeightScale( + row_coords_o, reinterpret_cast(kargs.sorted_weight_ptr)); + + auto uk_0 = Policy::template GetUK_0(); + auto acc_0 = uk_0(a_res, + a_coords, + g_res, + g_coords, + smem, + kargs.hidden_size, + BlockShape::Block_K0, // tile offset for B matrix each unroll + BlockShape::Block_Kr0 * + BlockShape::Block_W0); // tile offset for B matrix each unroll + + sweep_tile( + acc_0, + [&](auto idx0, auto idx1) { + fp32x2_t v_{acc_0(idx0), acc_0(idx1)}; + typename Problem::GateActivation{}(v_, v_); + acc_0(idx0) = v_.x; + acc_0(idx1) = v_.y; + }, + sequence<1, 2>{}); + + auto y_pre = cast_tile(acc_0); + + block_sync_lds(); + + store_tile(bridge_sst_win, y_pre); + block_sync_lds(); + + auto uk_1 = Policy::template GetUK_1(); + uk_1(d_res, + d_coords, + o_res, + o_coords, + o_flags, + smem, + kargs.hidden_size, // total n number + w_scale, + BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N + BlockShape::Block_N1); // along N + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp new file mode 100644 index 0000000000..6089c2558f --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// TODO: allow the 2 gemms to have different types +template +struct FusedMoeGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using GDataType = remove_cvref_t; + using DDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AScaleDataType = remove_cvref_t; + using GScaleDataType = remove_cvref_t; + using DScaleDataType = remove_cvref_t; + using YSmoothScaleDataType = remove_cvref_t; + using TopkWeightDataType = remove_cvref_t; + using IndexDataType = remove_cvref_t; + + // the input for the next gemm should have the same type as A + using YDataType = ADataType; + + using GateActivation = remove_cvref_t; + using BlockShape = remove_cvref_t; + using Traits = remove_cvref_t; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp new file mode 100644 index 0000000000..d7127b098c --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +enum class FusedMoeGemmWeightPermuteEnum +{ + // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6 + // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6 + no_permute = 0, + b_nr_kr_kw_nw_kv = 1, // 0,1,3,4,2,5 + b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, +}; + +template +struct FusedMoeGemmTraits +{ + // Gate+Up or Gate only + static constexpr bool IsGateOnly = IsGateOnly_; + static constexpr bool UseSmoothQuant = UseSmoothQuant_; + static constexpr index_t OAtomic = OAtomic_; + static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_; + static constexpr bool PadHiddenSize = PadHiddenSize_; + static constexpr bool PadIntermediateSize = PadIntermediateSize_; +}; + +// Note: this needs to be a bit mask +enum class FusedMoeGemmPipelineSequencerEnum +{ + SLD_A = 1 << 0, // shared load a + SLD_B = 1 << 1, + GLD_A = 1 << 2, // global load a + GLD_B = 1 << 3, + SST_A = 1 << 4, // shared store a + SST_B = 1 << 5, + GST_O = 1 << 6, // global store out +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 7ca4a697a7..89ea82c5bd 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -10,114 +10,134 @@ namespace ck_tile { // fp16 -using WarpGemmMfmaF16F16F32M32N32K8 = - WarpGemmImpl>; -using WarpGemmMfmaF16F16F32M16N16K16 = - WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfmaF16F16F32M32N32K16 = - WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfmaF16F16F32M16N16K32 = - WarpGemmImpl>; +using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, + 2>>; -using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl< - WarpGemmAtrributeMfmaIterateK_SwizzleA>; +using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, + 2>>; -using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< - WarpGemmAtrributeMfmaIterateK_SwizzleA>; +using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, + 1>>; -using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = 
WarpGemmImpl, + 2>>; -using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = + WarpGemmImpl>>; + +using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = + WarpGemmImpl>>; using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; // bf16 -using WarpGemmMfmaBf16Bf16F32M32N32K8 = - WarpGemmImpl>; -using WarpGemmMfmaBf16Bf16F32M16N16K16 = - WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfmaBf16Bf16F32M32N32K16 = - WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, + 2>>; -using WarpGemmMfmaBf16Bf16F32M16N16K32 = - WarpGemmImpl>; +using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, + 2>>; -using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl< - WarpGemmAtrributeMfmaIterateK_SwizzleA>; +using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, + 1>>; -using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< - WarpGemmAtrributeMfmaIterateK_SwizzleA>; +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = + WarpGemmImpl, + 2>>; -using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = + WarpGemmImpl>>; -using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = + WarpGemmImpl>>; using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; // fp8 -using WarpGemmMfma_f32_32x32x16_fp8_fp8 = - WarpGemmImpl>; -using WarpGemmMfma_f32_32x32x16_fp8_bf8 = - WarpGemmImpl>; +using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_32x32x16_bf8_fp8 = - WarpGemmImpl>; +using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_32x32x16_bf8_bf8 = - WarpGemmImpl>; +using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = + WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = + WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = + WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl< - WarpGemmAtrributeMfmaTransposedCDistribution>; +using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = + WarpGemmImpl>>; template using 
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base, 2, swizzle_factor>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index d80e5198e6..0a8d2dfbe3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -25,6 +25,8 @@ struct WarpGemmAtrributeMfma static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK; + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -51,10 +53,13 @@ struct WarpGemmAtrributeMfma sequence<0, 2>>; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { - Impl{}(c_vec, a_vec, b_vec); + Impl{}(c_vec, a_vec, b_vec, bool_constant{}); } // c_vec = a_vec * b_vec @@ -85,6 +90,8 @@ struct WarpGemmAtrributeMfmaIterateK static constexpr index_t kN = Impl::kN; static constexpr index_t kK = Impl::kK * kKIter; + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -111,8 +118,11 @@ struct WarpGemmAtrributeMfmaIterateK sequence<0, 2>>; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { using buf_a = thread_buffer; using buf_b = thread_buffer; @@ -122,10 +132,33 @@ struct WarpGemmAtrributeMfmaIterateK reinterpret_cast(a_vec) .template get_as()[iKIter], reinterpret_cast(b_vec) - .template get_as()[iKIter]); + .template get_as()[iKIter], + bool_constant{}); }); } + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iKIter < kKIter); + + // static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter], + bool_constant{}); + //}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { @@ -168,6 +201,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -194,11 +229,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution sequence<0, 2>>; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { // swap A and B - Impl{}(c_vec, b_vec, a_vec); + Impl{}(c_vec, b_vec, a_vec, bool_constant{}); } // c_vec = a_vec * b_vec @@ -226,6 +264,8 @@ struct 
WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -255,12 +295,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB sequence<2, 2>, sequence<0, 2>>; + template // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { // swap A and B - Impl{}(c_vec, b_vec, a_vec); + Impl{}(c_vec, b_vec, a_vec, bool_constant{}); } // c_vec = a_vec * b_vec @@ -291,6 +334,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK * kKIter; + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -316,9 +361,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution sequence<2, 2>, sequence<0, 2>>; + template // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { using buf_a = thread_buffer; using buf_b = thread_buffer; @@ -328,10 +376,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution reinterpret_cast(b_vec) .template get_as()[iKIter], reinterpret_cast(a_vec) - .template get_as()[iKIter]); + .template get_as()[iKIter], + bool_constant{}); }); } + template + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iKIter < kKIter); + // swap A and B, value and type + // static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter], + bool_constant{}); + //}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { @@ -377,6 +449,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t SFactor = SFactor_; // group how many CM1 together + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -429,8 +503,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB sequence<0, 2>>; #endif // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { using buf_a = thread_buffer; using buf_b = thread_buffer; @@ -440,10 +517,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB reinterpret_cast(b_vec) .template get_as()[iKIter], reinterpret_cast(a_vec) - .template get_as()[iKIter]); + .template get_as()[iKIter], + 
bool_constant{}); }); } + + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iKIter < kKIter); + // swap A and B, value and type + // static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(b_vec) + .template get_as()[iKIter], + reinterpret_cast(a_vec) + .template get_as()[iKIter], + bool_constant{}); + //}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { @@ -488,6 +588,8 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA static constexpr index_t kK = Impl::kK * kKIter; static constexpr index_t SFactor = SFactor_; // group how many CM1 together + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; } + using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { using buf_a = thread_buffer; using buf_b = thread_buffer; @@ -529,10 +634,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA reinterpret_cast(a_vec) .template get_as()[iKIter], reinterpret_cast(b_vec) - .template get_as()[iKIter]); + .template get_as()[iKIter], + bool_constant{}); }); } + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_a = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iKIter < kKIter); + + // static_for<0, kKIter, 1>{}([&](auto iKIter) { + Impl{}(c_vec, + reinterpret_cast(a_vec) + .template get_as()[iKIter], + reinterpret_cast(b_vec) + .template get_as()[iKIter], + bool_constant{}); + //}); + } + // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index bb59a72982..0aba1f5355 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,12 +7,68 @@ namespace ck_tile { +// TODO: refactor warp-gemm +// currently there is a discrepancy for vav/vva if we need transpose C/D +// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum +// because we swap the A/B pointers in the _impl code (but this info is not known here)
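// [editor's note] illustrative expansion, not part of the patch: with
// Ctrl == WGAttrCtlEnum::Raw_avv (defined just below) and post_nop_ == false,
// a call like DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl) reduces to
//
//   asm volatile("v_mfma_f32_32x32x8f16 %0, %1, %2, %3\n"
//                : "+a"(c_vec)                         // C/D accumulated in agpr
//                : "v"(a_vec), "v"(b_vec), "a"(c_vec)  // A/B read from vgpr
//                :);
//
// i.e. the constraint letters "v"/"a" choose vgpr or agpr per operand, which is
// exactly what WGAttrCtlEnum encodes; post_nop_ == true additionally appends an
// "s_nop 3" after the mfma.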
+enum class WGAttrCtlEnum +{ + Default_ = 0, + Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr + Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr + Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr + Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr + Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr + // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr +}; + +#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \ + if constexpr(post_nop_) \ + { \ + asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \ + "s_nop 3" \ + : dmod_(c_vec) \ + : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \ + :); \ + } \ + else \ + { \ + asm volatile(mfma_ " %0, %1, %2, %3\n" \ + : dmod_(c_vec) \ + : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \ + :); \ + } + +#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \ + if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \ + { \ + DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \ + } \ + else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \ + { \ + DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \ + } + // FP16 +template struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 { - using ADataType = fp16_t; - using BDataType = fp16_t; - using CDataType = float; + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; using AVecType = ext_vector_t; using BVecType = ext_vector_t; @@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 static constexpr index_t kCM1PerLane = 4; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const 
BVecType& b_vec, + bool_constant = {}) const { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl) + else + { +#if defined(__gfx9__) - c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else - ignore = c_vec; - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; #endif + } } // c_vec = a_vec * b_vec @@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; return CVecType{0.f}; #endif } }; // Bf16 +template struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 { - using ADataType = bf16_t; - using BDataType = bf16_t; - using CDataType = float; + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; using AVecType = ext_vector_t; using BVecType = ext_vector_t; @@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 static constexpr index_t kCM1PerLane = 4; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl) + else + { +#if defined(__gfx90a__) || defined(__gfx94__) - c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) - static_for<0, 2, 1>{}([&](auto k) { - c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( - reinterpret_cast&>(a_vec) - .template get_as>()[number{}], - reinterpret_cast&>(b_vec) - .template get_as>()[number{}], - c_vec, - 0, - 0, - 0); - }); + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else - ignore = c_vec; - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; #endif + } } // c_vec = a_vec * b_vec @@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 }); return c_vec; #else - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; return CVecType{0.f}; #endif } }; +template struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 { - using ADataType = bf16_t; - using BDataType = bf16_t; - using CDataType = float; + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; using AVecType = ext_vector_t; using BVecType = ext_vector_t; @@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 static constexpr index_t kCM1PerLane = 4; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl) + else + { +#if defined(__gfx90a__) || defined(__gfx94__) - c_vec = 
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); + c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) - static_for<0, 2, 1>{}([&](auto k) { - c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16( - reinterpret_cast&>(a_vec) - .template get_as>()[number{}], - reinterpret_cast&>(b_vec) - .template get_as>()[number{}], - c_vec, - 0, - 0, - 0); - }); + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); #else - ignore = c_vec; - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; #endif + } } // c_vec = a_vec * b_vec @@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 }); return c_vec; #else - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; return CVecType{0.f}; #endif } }; // FP8 -template +template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { - using ADataType = AType_; - using BDataType = BType_; - using CDataType = float; + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; using AVecType = ext_vector_t; using BVecType = ext_vector_t; @@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base static constexpr index_t kCM1PerLane = 4; // c_vec += a_vec * b_vec - CK_TILE_DEVICE void - operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const { + if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v") + } + } + else if constexpr(Ctrl == 
WGAttrCtlEnum::Raw_vva) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v") + } + } + else + { #if defined(__gfx94__) - if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); - else if constexpr(std::is_same_v && std::is_same_v) - c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( - bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); #elif defined(__gfx908__) || defined(__gfx90a__) - static_for<0, 8, 1>{}([&](auto k) { - float a_f32 = - type_convert(reinterpret_cast&>(a_vec) - .template get_as()[number{}]); - float b_f32 = - type_convert(reinterpret_cast&>(b_vec) - .template get_as()[number{}]); - - c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); - }); + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = + type_convert(reinterpret_cast&>(a_vec) + .template get_as()[number{}]); + float b_f32 = + type_convert(reinterpret_cast&>(b_vec) + .template get_as()[number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); #else - ignore = c_vec; - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; #endif + } } // c_vec = a_vec * b_vec @@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base }); return c_vec; #else - ignore = a_vec; - ignore = b_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; return CVecType{0.f}; #endif } }; +template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + +template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + +template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + +template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = - WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + 
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; + +// int8 +template +struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int32_t; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl) + else + { +#if defined(__gfx94__) + c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#elif defined(__gfx908__) || defined(__gfx90a__) + static_for<0, 8, 1>{}([&](auto k) { + float a_f32 = + type_convert(reinterpret_cast&>(a_vec) + .template get_as()[number{}]); + float b_f32 = + type_convert(reinterpret_cast&>(b_vec) + .template get_as()[number{}]); + + c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + CVecType c_vec{0}; + operator()(c_vec, a_vec, b_vec); + return c_vec; + } +}; + +#undef DISPATCH_MFMA_ } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 4183d9cb95..99cd5d787e 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher; // clang-format off // fp16 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; // bf16 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; }; +template<> struct WarpGemmMfmaDispatcher { 
using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; // fp8 -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; -template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // clang-format on } // namespace impl diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index eb9dbf127d..182d023a00 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -31,11 +31,21 @@ struct WarpGemmImpl using BWarpTensor = static_distributed_tensor; using CWarpTensor = static_distributed_tensor; - CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { - using AVec = ext_vector_t; - using BVec = ext_vector_t; - using CVec = ext_vector_t; + return WarpGemmAttribute_::get_num_of_access(); + } + + template + CK_TILE_DEVICE void + operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant = {}) const + { + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; @@ -44,18 +54,49 @@ struct WarpGemmImpl auto c_vec = c.get_thread_buffer().template get_as()[I0]; // c_vec 
+= a_vec * b_vec - WarpGemmAttribute{}(c_vec, a_vec, b_vec); + WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant{}); c.get_thread_buffer().template set_as(I0, c_vec); } - CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const + template + CK_TILE_DEVICE void operator()(CTensor& c, + const ATensor& a, + const BTensor& b, + number, + bool_constant = {}) const { - CWarpTensor c; + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; - using AVec = ext_vector_t; - using BVec = ext_vector_t; - using CVec = ext_vector_t; + const auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; + + // c_vec += a_vec * b_vec + WarpGemmAttribute{}(c_vec, a_vec, b_vec, number{}, bool_constant{}); + + c.get_thread_buffer().template set_as(I0, c_vec); + } + + template + CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const + { + using CTensor = CWarpTensor; + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + CTensor c; + + using AVec = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; diff --git a/include/ck_tile/ops/moe_sorting.hpp b/include/ck_tile/ops/moe_sorting.hpp deleted file mode 100644 index b74607f061..0000000000 --- a/include/ck_tile/ops/moe_sorting.hpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" -#include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp"
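Editor's note — the per-access interface added to WarpGemmImpl above (get_num_of_access() plus the operator() overload taking a number<i_access> and a bool_constant<post_nop>) is what lets a pipeline interleave its own instructions between individual mfma issues instead of emitting them back-to-back. A minimal caller sketch under those assumptions; gemm_interleaved and interleaved_work are illustrative names, not part of the patch:

// minimal usage sketch, assuming a ck_tile WarpGemm type from this patch with
// kKIter > 1; interleaved_work is a hypothetical callback standing in for the
// loads a sequencer would schedule between mfmas
template <typename WarpGemm, typename CTensor, typename ATensor, typename BTensor, typename F>
CK_TILE_DEVICE void gemm_interleaved(CTensor& c, const ATensor& a, const BTensor& b, F interleaved_work)
{
    using namespace ck_tile;
    constexpr index_t num_access = WarpGemm::get_num_of_access();
    static_for<0, num_access, 1>{}([&](auto i_access) {
        // issue exactly one of the kKIter sub-mfmas for this k-iteration;
        // request the trailing s_nop only on the last issue
        constexpr bool post_nop = (i_access.value == num_access - 1);
        WarpGemm{}(c, a, b, i_access, bool_constant<post_nop>{});
        interleaved_work(i_access); // e.g. a global/shared load hidden under the mfma
    });
}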