-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CK_TILE] fused-moe first version (#1634)
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <[email protected]>
- Loading branch information
1 parent
645fe81
commit 440e28b
Showing
66 changed files
with
8,067 additions
and
309 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") | ||
# not using add_example_executable() to add this target, since we don't want this to have | ||
# to be included in "make all/install/check" | ||
message("adding ${TILE_EXAPMLE_FUSED_MOE}") | ||
file(GLOB INSTANCE_SRCS instances/*.cpp) | ||
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) | ||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) | ||
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) | ||
|
||
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS) | ||
|
||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations | ||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) | ||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a | ||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta | ||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) | ||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) | ||
|
||
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# fused-moe | ||
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance | ||
![](misc/moe-0.png) | ||
|
||
The benifit of this fused-moe: | ||
* 1.5~2x perf boost compared with current vllm solution | ||
* zero workspace to reduce memory footprint | ||
* much less kernel instance, easy to maintain | ||
|
||
# Implementation and feature support | ||
## moe-sorting | ||
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic) | ||
|
||
## moe-gemm | ||
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer. Naive understanding of fused-moe is from token-by-token view as below picture: | ||
![](misc/moe-1.png) | ||
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below: | ||
![](misc/moe-2.png) | ||
|
||
## optimization | ||
summary of the key design of this fused-moe operator: | ||
* fuse 2 group-gemm + activation + `topk-weight` multiply into single kernel, using atomic for 2nd gemm accumualation | ||
* fuse buffer-zeroing in `moe-sorgin`, user no longer need call extra torch.zero() for the out buffer | ||
* fused scatter-gather for row index(same as vllm) | ||
* pre-shuffle B matric(weight) to maximize memory throughput. input(activation) keep original layout `[batch, hidden]`. | ||
* extrem optimized pipeline using block-inline-asm(we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck | ||
|
||
## | ||
``` | ||
// [indexing implementation-1] | ||
// using M_a as constexpr block_size to partition all tokens into different slices | ||
// each slice map to one expert, and one expert can have multiple slices | ||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5 | ||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] | ||
// tok-0 tok-1 tok-2 tok-3 tok-4 | ||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number) | ||
// | ||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]] | ||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 | ||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] | ||
// | ||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) | ||
// * this could be larger than actual, since actual tokens are on GPU | ||
// | ||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] | ||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| | ||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] | ||
// | ||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr | ||
// | ||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] | ||
// * length is (max_num_tokens_padded + block_size - 1) / block_size | ||
// | ||
// num_tokens_post_padded_ptr : [28] | ||
// num_sorted_tiles_ptr : [7] | ||
// | ||
// * different from vLLM | ||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id | ||
// 2)need sorted_weight_ptr | ||
// 3) use num_sorted_tiles_ptr, already divided by M_a | ||
// | ||
// * below used for indexing | ||
// 1) sorted_token_ids_ptr [max_num_tokens_padded] | ||
// 2) sorted_weight_ptr | ||
// 3) sorted_expert_ids_ptr | ||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) | ||
// | ||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
#include "fused_moesorting.hpp" | ||
#include "fused_moegemm.hpp" | ||
|
||
struct fused_moe_args | ||
{ | ||
const void* a_ptr; // [m, k], input token | ||
const void* a_scale_ptr; // [m, 1], token scale | ||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w]) | ||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w]) | ||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale | ||
const void* d_scale_ptr; // [e, 1, k], down scale | ||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input | ||
void* o_ptr; // [m, k], output token (no need to do zeroing) | ||
|
||
const void* topk_ids_ptr; // [tokens, topk] | ||
const void* topk_weight_ptr; // [tokens, topk] | ||
void* sorted_token_ids_ptr; // [max_num_tokens_padded] | ||
void* sorted_weight_ptr; // [max_num_tokens_padded] | ||
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size] | ||
void* num_sorted_tiles_ptr; // [1] | ||
|
||
ck_tile::index_t block_m; // block_m, used to devide the input | ||
ck_tile::index_t hidden_size; // k | ||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2 | ||
ck_tile::index_t num_tokens; // input number of tokens for current iteration | ||
ck_tile::index_t num_experts; // number of groups | ||
ck_tile::index_t topk; // need this? | ||
|
||
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size | ||
}; | ||
|
||
// This is the public API, will be generated by script | ||
struct fused_moe_traits | ||
{ | ||
std::string prec_i; // input precision | ||
std::string prec_w; // weight precision | ||
std::string prec_o; // output precision | ||
std::string prec_st; // token scale data type | ||
std::string prec_sw; // weight scale data type | ||
std::string prec_sq; // smooth quant scale | ||
std::string prec_kw; // topk-weight data type | ||
int block_m; | ||
int gate_only; | ||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant | ||
}; | ||
|
||
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
|
||
#include "ck_tile/core.hpp" | ||
#include "ck_tile/host/kernel_launch.hpp" | ||
#include "ck_tile/ops/fused_moe.hpp" | ||
#include <string> | ||
|
||
// this is only a convenient structure for creating an example | ||
// this is not part of the host API | ||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW> | ||
struct FusedMoeGemmTypeConfig; | ||
|
||
template <typename ST, typename SW, typename SQ, typename KW> | ||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW> | ||
{ | ||
using ADataType = ck_tile::bf16_t; | ||
using GDataType = ck_tile::bf16_t; | ||
using DDataType = ck_tile::bf16_t; | ||
using AccDataType = float; | ||
using ODataType = ck_tile::bf16_t; | ||
using AScaleDataType = ck_tile::remove_cvref_t<ST>; | ||
using GScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using DScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>; | ||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>; | ||
using IndexDataType = ck_tile::index_t; | ||
}; | ||
|
||
template <typename ST, typename SW, typename SQ, typename KW> | ||
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW> | ||
{ | ||
using ADataType = ck_tile::fp16_t; | ||
using GDataType = ck_tile::fp16_t; | ||
using DDataType = ck_tile::fp16_t; | ||
using AccDataType = float; | ||
using ODataType = ck_tile::fp16_t; | ||
using AScaleDataType = ck_tile::remove_cvref_t<ST>; | ||
using GScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using DScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>; | ||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>; | ||
using IndexDataType = ck_tile::index_t; | ||
}; | ||
|
||
template <typename ST, typename SW, typename SQ, typename KW> | ||
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW> | ||
{ | ||
using ADataType = ck_tile::int8_t; | ||
using GDataType = ck_tile::int8_t; | ||
using DDataType = ck_tile::int8_t; | ||
using AccDataType = int32_t; | ||
using ODataType = ck_tile::bf16_t; | ||
using AScaleDataType = ck_tile::remove_cvref_t<ST>; | ||
using GScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using DScaleDataType = ck_tile::remove_cvref_t<SW>; | ||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>; | ||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>; | ||
using IndexDataType = ck_tile::index_t; | ||
}; | ||
|
||
// runtime args | ||
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs | ||
{ | ||
}; | ||
|
||
// This is the public API, will be generated by script | ||
struct fused_moegemm_traits | ||
{ | ||
std::string prec_i; // input precision | ||
std::string prec_w; // weight precision | ||
std::string prec_o; // output precision | ||
std::string prec_st; // token scale data type | ||
std::string prec_sw; // weight scale data type | ||
std::string prec_sq; // smooth quant scale | ||
std::string prec_kw; // topk-weight data type | ||
int block_m; | ||
int gate_only; | ||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant | ||
}; | ||
|
||
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// SPDX-License-Identifier: MIT | ||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
#pragma once | ||
#include <string> | ||
#include "ck_tile/core.hpp" | ||
#include "ck_tile/host.hpp" | ||
#include "ck_tile/ops/fused_moe.hpp" | ||
|
||
struct fused_moesorting_trait | ||
{ | ||
std::string index_type; | ||
std::string weight_type; // currently always float | ||
}; | ||
|
||
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs | ||
{ | ||
}; | ||
|
||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); |
Oops, something went wrong.