gemm fp8 e4m3 #185

Open · wants to merge 31 commits into main
Conversation

@AndreSlavescu (Contributor) commented Aug 31, 2024:

Summary

Implemented an FP8 GEMM using the E4M3 representation.

Issue #65

Testing Done

Tested square matrices of varying sizes (64, 256, 512, 1024, 2048) and non-square matrices of varying sizes, comparing against torch matmul with appropriate casting for the backward pass (torch.matmul doesn't support the fp8_e4m3 dtype in backward).

FP8 GEMM only works on SM_89 and newer GPUs.
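
For illustration, a minimal sketch of the comparison described above (hypothetical names: m, k, n, atol, rtol stand for the parametrized test values, and LigerFP8GemmSplitKFunction is assumed to be in scope from the PR's module; this is not the PR's actual test code):

import torch

# inputs are generated in fp16 and cast down, since torch.randn cannot produce fp8 directly
a = torch.randn((m, k), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn((k, n), device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)

out_liger = LigerFP8GemmSplitKFunction.apply(a, b)                # FP8 split-k kernel (fp16 output)
out_ref = torch.matmul(a.to(torch.float16), b.to(torch.float16))  # reference upcasts, since torch.matmul
                                                                  # has no fp8_e4m3 matmul support
assert torch.allclose(out_liger, out_ref, atol=atol, rtol=rtol)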

  • Hardware Type: RTX 4090
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Collaborator:
Why are there multiple lines for each configuration? Shouldn't it be one line per case?

# cast to FP8
# structure:
# | 1 bit sign | 4 bit exponent | 3 bit mantissa |
a, b = a.to(torch.float8_e4m3fn), b.to(torch.float8_e4m3fn)
Collaborator:

Directly casting to FP8 is not a good option, since it causes a huge performance drop when the input is not FP8 to begin with. A preferred approach is dynamic (or static) FP8 quantization, so you can extract a scale factor (in fp32 precision) and perform a scaled matmul.

Note: ultimately we want to fuse the scaling into the kernel as well, so we can reduce the overhead of quantization and dequantization.
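
For illustration, a minimal sketch of per-tensor dynamic quantization along those lines (a hypothetical helper, not code from this PR; a and b are assumed to be existing bf16/fp32 inputs):

import torch

def dynamic_quantize_fp8(x, dtype=torch.float8_e4m3fn):
    # per-tensor scale: map the input's absolute max onto the FP8 dynamic range
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax                                 # quantization multiplier
    x_fp8 = (x * scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_fp8, scale.reciprocal().float()                 # fp32 dequantization scale

a_fp8, a_inv_s = dynamic_quantize_fp8(a)
b_fp8, b_inv_s = dynamic_quantize_fp8(b)

# scaled matmul: multiply on the fp8 data, then de-scale with the fp32 factors
# (upcast here only for the reference math; the kernel would consume the fp8 tensors directly)
c = torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32)) * (a_inv_s * b_inv_s)

Fusing the final multiplication by a_inv_s * b_inv_s into the kernel's epilogue is what the note about fusing the scaling part refers to.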

grid = (total_programs_mn, total_programs_k)

c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
gemm_split_k_kernel_forward[grid](
Collaborator:

Could you add some comments explaining the reason for using the split-k implementation, e.g. in which scenarios it is preferred?
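
For context, a sketch under assumed tile sizes (not the PR's exact code): split-k launches several programs per output tile, each reducing its own slice of K and atomically accumulating into c, which is why c is pre-zeroed above. It tends to pay off when M and N are small relative to K, where a plain tiling would leave most SMs idle.

import triton

# hypothetical tile sizes; the real values come from the kernel's autotune configs
block_m, block_n, block_k, split_k = 64, 64, 64, 4

total_programs_mn = triton.cdiv(m, block_m) * triton.cdiv(n, block_n)  # one program per (m, n) output tile
total_programs_k = split_k                                             # split_k partial reductions over K
grid = (total_programs_mn, total_programs_k)
# each program on grid axis 1 strides through K by block_k * split_k and
# atomically adds its partial tile into the pre-zeroed fp16 output c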

return LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8)

def fwd_torch():
    return torch.matmul(a_float, b_float)
@qingquansong (Collaborator) commented Sep 4, 2024:

Comparing the speed/memory of the fp8 kernel against torch matmul in fp32 is not quite a fair comparison. A better comparison would be against torch._scaled_mm with fp8 matmul, such as the example here: https://gist.github.com/malfet/7874d96b99670c3da83cbb779ab770c6
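
For reference, a rough sketch of that comparison in the spirit of the linked gist (hedged: torch._scaled_mm's exact signature and return value have changed across PyTorch releases, the second operand must be column-major, and dynamic_quantize_fp8 is the hypothetical helper sketched in the earlier comment):

a_fp8, a_inv_s = dynamic_quantize_fp8(a_bf16)        # fp8 data + fp32 per-tensor dequant scales
b_fp8, b_inv_s = dynamic_quantize_fp8(b_bf16)

out = torch._scaled_mm(
    a_fp8,
    b_fp8.t().contiguous().t(),                      # make the second operand column-major
    scale_a=a_inv_s,
    scale_b=b_inv_s,
    out_dtype=torch.bfloat16,
)
# note: some older PyTorch releases return an (output, amax) tuple rather than a single tensor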

]
)
def bench_memory_gemm_split_k_fp8(m, k, n, provider, dtype, device="cuda"):
    a_fp8 = torch.randn((m, k), device=device, dtype=dtype).to(torch.float8_e4m3fn)
Collaborator:

Ditto. Let's create bf16 inputs and compare the speed/memory of torch._scaled_mm vs. the fp8 kernel, and then compare the joint time of quant + dequant + matmul (with the fp8 scale factor). Thanks!
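
A rough sketch of that joint timing (hypothetical: it reuses the dynamic_quantize_fp8 helper sketched above and triton.testing.do_bench, which the benchmark suite presumably already relies on):

import triton

a_bf16 = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
b_bf16 = torch.randn((k, n), device="cuda", dtype=torch.bfloat16)

def fwd_fp8_with_quant():
    # joint cost: dynamic quantization + fp8 split-k matmul + de-scaling
    a_fp8, a_s = dynamic_quantize_fp8(a_bf16)
    b_fp8, b_s = dynamic_quantize_fp8(b_bf16)
    return LigerFP8GemmSplitKFunction.apply(a_fp8, b_fp8) * (a_s * b_s)

ms_fp8 = triton.testing.do_bench(fwd_fp8_with_quant)   # time the torch._scaled_mm path the same way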

@qingquansong (Collaborator):
Thanks for the efforts! I've provided some comments; please take a look, let me know if there are any questions, and we can discuss more here or on Discord.

merge main into matmulfp8
@@ -0,0 +1,302 @@
# adapted from: https://github.com/pytorch-labs/applied-ai/blob/main/kernels/triton/inference/fp8/splitk_gemm_fp8.py
Collaborator:

should we adapt the BSD3 license header from the original repo? @ByronHsu

Collaborator:

we can check with legal internally

src/liger_kernel/ops/utils.py (review comments resolved)
@ByronHsu (Collaborator) commented Sep 7, 2024:

@qingquansong can you take another look?

@qingquansong (Collaborator):

Hey @AndreSlavescu @ByronHsu, considering this is a bit different from the original feature request, maybe we can update the PR description to clarify that, and also correct the test cases to compare against torch fp8 matmul (or at least torch bf16 matmul rather than fp32) before checking in? Let me know your thoughts, happy to discuss more.

benchmark/benchmark_gemm_split_k_fp8_e4m3.py (outdated; review comments resolved)
test/transformers/test_gemm_split_k_fp8_e4m3.py (outdated; review comments resolved)


I tried to run this using 8 x H100 and the test failed because of memory constraints:

FAILED test_gemm.py::test_gemm_split_k[dtype0-0.2-0.2-1024-1024-1024] - triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 327680, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.
FAILED test_gemm.py::test_gemm_split_k[dtype0-0.2-0.2-1024-2048-1024] - triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 327680, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.

The test should dynamically run only the configurations that the machine's hardware can handle.
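
One hedged way to get that behavior (a sketch, not what the PR currently does): catch Triton's OutOfResources error in the test and skip the configuration instead of failing.

import pytest
import triton

def test_gemm_split_k(m, k, n, dtype, atol, rtol):
    try:
        # _run_gemm_split_k_case is a hypothetical wrapper around the existing test body
        _run_gemm_split_k_case(m, k, n, dtype, atol, rtol)
    except triton.runtime.errors.OutOfResources as exc:
        pytest.skip(f"insufficient shared memory on this GPU: {exc}")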

Comment on lines +71 to +83
acc = tl.zeros((block_m, block_n), dtype=tl.float32)

for k_ in range(0, grid_k, step=2):
    k_remaining = k - k_ * (block_k * split_k)

    mask_a = offs_k[None, :] < k_remaining
    mask_b = offs_k[:, None] < k_remaining

    a = tl.load(a_ptrs, mask=mask_a, other=0.0)
    b = tl.load(b_ptrs, mask=mask_b, other=0.0)

    # fp8 input dot product (supported types: [fp8e4nv, fp8e5, fp8e4b15])
    acc = tl.dot(a, b, acc)


You'd have to specify fp16 out_dtype for tl.dot, as it's fp32 by default. This will make it work, but make sure it does not impact numerical stability.

Suggested change

Original:
acc = tl.zeros((block_m, block_n), dtype=tl.float32)
for k_ in range(0, grid_k, step=2):
    k_remaining = k - k_ * (block_k * split_k)
    mask_a = offs_k[None, :] < k_remaining
    mask_b = offs_k[:, None] < k_remaining
    a = tl.load(a_ptrs, mask=mask_a, other=0.0)
    b = tl.load(b_ptrs, mask=mask_b, other=0.0)
    # fp8 input dot product (supported types: [fp8e4nv, fp8e5, fp8e4b15])
    acc = tl.dot(a, b, acc)

Suggested:
acc = tl.zeros((block_m, block_n), dtype=tl.float16)
for k_ in range(0, grid_k, step=2):
    k_remaining = k - k_ * (block_k * split_k)
    mask_a = offs_k[None, :] < k_remaining
    mask_b = offs_k[:, None] < k_remaining
    a = tl.load(a_ptrs, mask=mask_a, other=0.0)
    b = tl.load(b_ptrs, mask=mask_b, other=0.0)
    # fp8 input dot product (supported types: [fp8e4nv, fp8e5, fp8e4b15])
    acc = tl.dot(a, b, acc, out_dtype=tl.float16)

Contributor (PR author):

I tried this previously, and I don't think it's supported. They also don't list it as a param in the documentation, so my guess is that it's designed to be left unmodified.

@qingquansong self-requested a review on September 13, 2024 at 20:39
@qingquansong dismissed their stale review on September 25, 2024 at 00:59: unblock experimental feature
