gemm fp8 e4m3 #185
base: main
Conversation
Why are there multiple lines for each configuration? Shouldn't it be one line per case?
# cast to FP8
# structure:
# | 1 bit sign | 4 bit exponent | 3 bit mantissa |
a, b = a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn)
Directly casting to FP8 is not a good option, since it will cause a huge performance drop when the input is not FP8 originally. A preferred way is to do dynamic (or static) FP8 quantization, so you can extract a scale factor (in fp32 precision) to conduct a scaled matmul.
Note that ultimately we want to fuse the scaling part into the kernel as well, so we can reduce the overhead of quantization and dequantization.
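For illustration, a minimal sketch of dynamic per-tensor FP8 quantization (the helper name dynamic_quant_fp8 and the clamping details are my assumptions, not code from this PR):

import torch

# largest representable magnitude for float8_e4m3fn (~448)
F8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def dynamic_quant_fp8(x: torch.Tensor):
    # per-tensor scale so the largest |value| maps onto the e4m3 range
    scale = x.abs().amax().clamp(min=1e-12).float() / F8_E4M3_MAX
    x_fp8 = (x.float() / scale).clamp(-F8_E4M3_MAX, F8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale  # keep the scale in fp32 for the scaled matmul

The two scales can then be multiplied back into the matmul output (or passed to a scaled matmul) to dequantize the result.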
grid = (total_programs_mn, total_programs_k)

c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
gemm_split_k_kernel_forward[grid](
Could you add some comments explaining the reason for using a split-k implementation, e.g. in which scenarios it's preferred?
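One possible comment to add (my wording, describing the usual motivation for split-k rather than anything taken from this PR):

# Split-K GEMM: each (m, n) output tile is computed by `split_k` programs,
# each reducing a different slice of the K dimension, and the partial results
# are then combined. This is preferred when M and N are small relative to K
# (skinny matmuls), because a plain (M, N) tiling would not launch enough
# programs to keep all SMs busy.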
    return LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8)

def fwd_torch():
    return torch.matmul(a_float, b_float)
Comparing the speed/memory of the fp8 kernel against torch matmul on fp32 is not quite a fair comparison. A better comparison would be against torch._scaled_mm with fp8 matmul, such as the example here: https://gist.github.com/malfet/7874d96b99670c3da83cbb779ab770c6
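A rough sketch of that baseline (the exact torch._scaled_mm signature differs across PyTorch versions, and recent versions expect the second operand to be column-major, so treat the details below as assumptions):

import torch

m, k, n = 1024, 1024, 1024
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

f8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = a.abs().amax().float() / f8_max
scale_b = b.abs().amax().float() / f8_max
a_fp8 = (a.float() / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b.float() / scale_b).to(torch.float8_e4m3fn).t()  # (k, n), column-major

# fp8 matmul with fp32 scale factors and bf16 output
out = torch._scaled_mm(a_fp8, b_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)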
    ]
)
def bench_memory_gemm_split_k_fp8(m, k, n, provider, dtype, device="cuda"):
    a_fp8 = torch.randn((m, k), device=device, dtype=dtype).to(torch.float8_e4m3fn)
Ditto. Let's try to create bf16 inputs and compare the speed/memory of torch._scaled_mm vs. the fp8 kernel, and then compare the joint time of quant + dequant + matmul (with the fp8 scale factor). Thanks!
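Something along these lines could time the joint path (a sketch: a_bf16/b_bf16 are assumed bf16 inputs, LigerFP8GemmSplitKFunction is the function added in this PR, and the rescaling here happens outside the kernel):

import torch
import triton

def quant_matmul_dequant(a_bf16, b_bf16):
    f8_max = torch.finfo(torch.float8_e4m3fn).max
    scale_a = a_bf16.abs().amax().float() / f8_max
    scale_b = b_bf16.abs().amax().float() / f8_max
    a_fp8 = (a_bf16.float() / scale_a).to(torch.float8_e4m3fn)
    b_fp8 = (b_bf16.float() / scale_b).to(torch.float8_e4m3fn)
    c = LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8)  # fp8 kernel under test
    return c.float() * (scale_a * scale_b)              # dequantize the result

ms_fp8 = triton.testing.do_bench(lambda: quant_matmul_dequant(a_bf16, b_bf16))
ms_bf16 = triton.testing.do_bench(lambda: torch.matmul(a_bf16, b_bf16))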
Thanks for the efforts! I've provided some comments; please take a look, let me know if there are any questions, and we can discuss more here or on Discord.
merge main into matmulfp8
@@ -0,0 +1,302 @@
# adapted from: https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/fp8/splitk_gemm_fp8.py |
Should we adapt the BSD-3 license header from the original repo? @ByronHsu
we can check with legal internally
@qingquansong can you take another look?
Hey @AndreSlavescu @ByronHsu, considering this is a bit different from the original feature request, maybe we can update the PR description to clarify this, and also correct the test cases to compare with torch fp8 matmul (or at least torch bf16 matmul rather than fp32) before checking in? Let me know your thoughts, and I'm happy to discuss more.
I tried to run this using 8 x H100 and the test failed because of memory constraints:
FAILED test_gemm.py::test_gemm_split_k[dtype0-0.2-0.2-1024-1024-1024] - triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 327680, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
FAILED test_gemm.py::test_gemm_split_k[dtype0-0.2-0.2-1024-2048-1024] - triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 327680, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
The test cases should dynamically run only the configurations that the machine's hardware can handle.
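One way to do that (a sketch; the Triton driver API for querying device properties differs between Triton versions, so the calls below are assumptions):

import torch
import triton
import pytest

def device_max_shared_mem() -> int:
    props = triton.runtime.driver.active.utils.get_device_properties(
        torch.cuda.current_device()
    )
    return props["max_shared_mem"]

def skip_if_insufficient_smem(block_m, block_n, block_k, num_stages, elem_bytes=1):
    # rough estimate: pipelined A and B tiles resident in shared memory
    required = num_stages * elem_bytes * (block_m * block_k + block_k * block_n)
    if required > device_max_shared_mem():
        pytest.skip(
            f"config needs {required} B of shared memory, "
            f"device offers {device_max_shared_mem()} B"
        )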
acc = tl.zeros((block_m, block_n), dtype=tl.float32)

for k_ in range(0, grid_k, step=2):
    k_remaining = k - k_ * (block_k * split_k)

    mask_a = offs_k[None, :] < k_remaining
    mask_b = offs_k[:, None] < k_remaining

    a = tl.load(a_ptrs, mask=mask_a, other=0.0)
    b = tl.load(b_ptrs, mask=mask_b, other=0.0)

    # fp8 input dot product (supported types: [fp8e4nv, fp8e5, fp8e4b15])
    acc = tl.dot(a, b, acc)
You'd have to specify an fp16 out_dtype for tl.dot, since it's fp32 by default. This will make it work, but make sure it does not impact numerical stability.
Suggested change:

-acc = tl.zeros((block_m, block_n), dtype=tl.float32)
+acc = tl.zeros((block_m, block_n), dtype=tl.float16)
 for k_ in range(0, grid_k, step=2):
     k_remaining = k - k_ * (block_k * split_k)
     mask_a = offs_k[None, :] < k_remaining
     mask_b = offs_k[:, None] < k_remaining
     a = tl.load(a_ptrs, mask=mask_a, other=0.0)
     b = tl.load(b_ptrs, mask=mask_b, other=0.0)
     # fp8 input dot product (supported types: [fp8e4nv, fp8e5, fp8e4b15])
-    acc = tl.dot(a, b, acc)
+    acc = tl.dot(a, b, acc, out_dtype=tl.float16)
I tried this previously, and I don't think it's supported. They also don't list it as a param in the documentation, so my guess is that it's designed to be left unmodified.
Summary
Implemented FP8 GEMM with the E4M3 representation for FP8.
Issue #65
Testing Done
Tested square matrices of varying sizes (64, 256, 512, 1024, 2048) and non-square matrices of varying sizes, comparing against torch matmul with appropriate casting for the backward pass (torch.matmul doesn't support the fp8_e4m3 dtype for backward).
FP8 GEMM will only work on SM_89+.
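A small guard along these lines could keep the tests off unsupported hardware (a sketch; the test name is illustrative):

import torch
import pytest

def supports_fp8() -> bool:
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9)  # SM_89 (Ada) and newer

@pytest.mark.skipif(not torch.cuda.is_available() or not supports_fp8(),
                    reason="FP8 GEMM requires SM_89 or newer")
def test_gemm_split_k_fp8_smoke():
    ...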
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence