-
Notifications
You must be signed in to change notification settings - Fork 188
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass of Float8LinearNoCompile #1445
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1445
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 7c47b03 with merge base ec64182 (): NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
offs = tl.arange(0, BLOCK_SIZE) | ||
|
||
# get amax | ||
amax = tl.zeros([1], dtype=tl.float64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if I'm reading this code right, the amax
value will be calculated per a single cell in the grid - I would have expected the same amax to be used for all cells in the grid, there should be a synchronization step somewhere
for i in range(0, n_elements, BLOCK_SIZE): | ||
block_offs = (i * BLOCK_SIZE) + offs | ||
block_mask = block_offs < n_elements | ||
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would have expected a single tl.load
per grid cell, if I'm understanding this code correctly it's both launching a grid and also iterating through every element of the tensor from each cell in the grid. I could be missing something
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I think I see the problem, I didn't update the grid dimensions when I changed the kernel implementation strategy. In this simple strategy there should only be 1 thread block / cell, which finds the global amax via loading one block at a time into SRAM in a loop as seen here.
My original strategy was to divide the tensor into some number of blocks, compute a local amax for each block in parallel, writing those to a shared buffer of some sort, then have a second step which computes the global amax from those local amaxes. However, I decided I wanted to try this simple approach first to just get the code functionally correct to start.
Unfortunately a SEV caused me to lose access to my devgpu but I will verify the fix tomorrow then make it more parallelized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vkuzo I went back to the drawing board and rewrote this using the original strategy I had in mind (described in my previous comment) and it works now. It's a bit more complicated and requires breaking up the conversion process into separate kernels, but it works and uses a higher degree of parallelism so should be reasonably performant.
However, I am wondering if it would be more efficient for each block in the final _to_fp8()
kernel to locally compute the global scale from the previously computed block_amaxes
, rather than have the separate kernel _scale()
do it before hand and share that scale value everywhere. My reasoning is that all instances of the _to_fp8()
kernel will be waiting on the _scale()
kernel to finish before they can launch, and the memory access latency overhead of the _scale()
kernel moving data back and forth between HBM <-> SRAM will be slower than each instance of _to_fp8()
just computing scale locally.
If time allows I'd like to benchmark/profile both approaches.
5002a80
to
867538e
Compare
Adding @drisspg for review as well, curious to get your thoughts on this initial implementation of triton kernels for fp8 conversion |
torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
Outdated
Show resolved
Hide resolved
torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One big thing here is that all your kernels assume contiguous inputs, either assert that or update to add for striding.
Thanks, added an assertion that input tensor is contiguous, as well as a unit test to validate this behavior. |
Side note here another fun way to do it: https://github.com/drisspg/transformer_nuggets/blob/a4c66bbeebaa479ad8b6ed82d7efbafa41b17260/transformer_nuggets/fp8/scaled_quant.py#L139 |
Summary
These changes include:
Test plan:
Next steps
torch.mm
(which wraps the cuBLAS float8 matmul routine) w/ Float8Tensors to perform the matmul needed for the linear layer