-
-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the fp8-quantized GeMM for dense linear layers #18
Open
sfc-gh-reyazda
wants to merge
9
commits into
pp-staging
Choose a base branch
from
fp8-fused-gemm
base: pp-staging
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+44
−20
Open
Changes from 3 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4ce7c85
Add the fp8-quantized GeMM for dense linear layers
sfc-gh-reyazda 7a6384b
add enable_fused_kernel
sfc-gh-reyazda 1df2b4c
set the fused-parameters to true for easier testing
sfc-gh-reyazda d7b5f13
add fp16 quantized-gemm
sfc-gh-reyazda c54d3a8
fix dtype for the gemm output
sfc-gh-reyazda d47996d
remove autotuning and add the best configs
sfc-gh-reyazda fda9309
fix conversion
sfc-gh-reyazda 7f951d0
fix ckpt loading
sfc-gh-reyazda df0547e
remove fp8 kernel
sfc-gh-reyazda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from vllm.model_executor.layers.fused_fp8.fused_fp8_gemm import matmul_fp8 | ||
|
||
__all__ = [ | ||
"matmul_fp8", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
def get_autotune_config(): | ||
return [ | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, | ||
num_warps=8), | ||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, | ||
num_warps=2), | ||
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, | ||
num_warps=2), | ||
# Good config for fp8 inputs. | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, | ||
num_warps=8), | ||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, | ||
num_warps=8), | ||
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4), | ||
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, | ||
num_warps=4) | ||
] | ||
|
||
@triton.autotune( | ||
configs=get_autotune_config(), | ||
key=['M', 'N', 'K'], | ||
) | ||
@triton.jit | ||
def matmul_kernel_fp8( | ||
a_ptr, b_ptr, c_ptr, scale_ptr, | ||
M, N, K, | ||
stride_am, stride_ak, | ||
stride_bk, stride_bn, | ||
stride_cm, stride_cn, | ||
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # | ||
GROUP_SIZE_M: tl.constexpr, # | ||
quantization_group_size: tl.constexpr | ||
): | ||
pid = tl.program_id(axis=0) | ||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) | ||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) | ||
num_pid_in_group = GROUP_SIZE_M * num_pid_n | ||
group_id = pid // num_pid_in_group | ||
first_pid_m = group_id * GROUP_SIZE_M | ||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) | ||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) | ||
pid_n = (pid % num_pid_in_group) // group_size_m | ||
|
||
|
||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M | ||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N | ||
offs_k = tl.arange(0, BLOCK_SIZE_K) | ||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) | ||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) | ||
|
||
b_ptrs_offset = offs_bn[None, :] * (stride_bn // quantization_group_size) | ||
|
||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) | ||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): | ||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) | ||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) | ||
scale = tl.load(scale_ptr + b_ptrs_offset) | ||
# Dequantize weight (fp8 -> bf16) | ||
b = ((b & 0x80) << 8) | ((b & 0x7f) << 4) | ||
b = (b + 0x3C00).to(tl.uint16) | ||
b = (b.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) | ||
|
||
accumulator = tl.dot(a, b, accumulator) | ||
|
||
a_ptrs += BLOCK_SIZE_K * stride_ak | ||
b_ptrs += BLOCK_SIZE_K * stride_bk | ||
|
||
c = accumulator.to(tl.float16) | ||
|
||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | ||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] | ||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) | ||
tl.store(c_ptrs, c, mask=c_mask) | ||
|
||
def matmul_fp8(a, b, scale, quantization_group_size): | ||
assert a.shape[1] == b.shape[0], "Incompatible dimensions" | ||
assert a.is_contiguous(), "Matrix A must be contiguous" | ||
M, K = a.shape | ||
K, N = b.shape | ||
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | ||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) | ||
matmul_kernel_fp8[grid]( | ||
a, b, c, scale, # | ||
M, N, K, # | ||
a.stride(0), a.stride(1), # | ||
b.stride(0), b.stride(1), # | ||
c.stride(0), c.stride(1), # | ||
quantization_group_size=quantization_group_size | ||
) | ||
return c |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
orig
should beorig_shape
?