[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass of Float8LinearNoCompile #1445

Merged: 3 commits merged into pytorch:main on Dec 19, 2024

Conversation

@danielvegamyhre (Contributor) commented Dec 19, 2024

Summary

These changes include:

  • Triton kernels for high-precision to fp8 conversion for the forward pass
  • Wrapper function which orchestrates the conversion kernels and returns the result as a Float8Tensor
  • Unit test comparing the production conversion path with this new path (a simplified sketch of that reference path follows the test plan below)

Test plan:

  • Unit tests for these kernels are passing
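
For context on what that comparison checks, the production path amounts to dynamic per-tensor scaling: compute the global amax, derive a scale, then scale, clamp, and cast. A simplified eager sketch of that reference follows (not torchao's actual implementation; the function name and the epsilon value are assumptions):

import torch

def eager_dynamic_fp8_cast(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # simplified reference for dynamic per-tensor float8 scaling:
    # scale = fp8_max / amax(x), then scale, clamp to the fp8 range, and cast
    fp8_max = torch.finfo(float8_dtype).max
    amax = x.abs().max().to(torch.float64)
    scale = (fp8_max / torch.clamp(amax, min=1e-12)).to(torch.float32)
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_fp8, scale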

Next steps

  • Wrap the conversion in a torch.autograd.Function subclass whose backward() passes gradients through unchanged, to make conversion to Float8Tensor differentiable (see the sketch after this list)
  • Use torch.mm (which wraps the cuBLAS float8 matmul routine) with Float8Tensors to perform the matmul needed for the linear layer
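
A rough sketch of what that pass-through autograd wrapper could look like (hypothetical class name; the forward uses the same simple eager cast as the sketch above as a stand-in for the Triton kernels, and a real version would return a Float8Tensor rather than dequantizing):

import torch

class ToFloat8StraightThrough(torch.autograd.Function):
    # hypothetical sketch: cast to float8 in forward, pass the incoming gradient
    # through unchanged in backward so the conversion is differentiable
    @staticmethod
    def forward(ctx, hp_tensor: torch.Tensor):
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = fp8_max / hp_tensor.abs().float().max().clamp(min=1e-12)
        fp8_data = (hp_tensor.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        # stand-in for wrapping (fp8_data, scale) in a Float8Tensor: dequantize
        # back to the original dtype so the sketch stays self-contained
        return (fp8_data.to(torch.float32) / scale).to(hp_tensor.dtype)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # straight-through: the cast is treated as identity for gradients
        return grad_output

# usage: y = ToFloat8StraightThrough.apply(x) is differentiable w.r.t. x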

@danielvegamyhre added the label 'topic: not user facing' (use this tag if you don't want this PR to show up in release notes) on Dec 19, 2024
pytorch-bot (bot) commented Dec 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1445

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 7c47b03 with merge base ec64182, one new CI job failure was reported.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the label 'CLA Signed' (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Dec 19, 2024
@danielvegamyhre changed the title from '[WIP] [float8nocompile]: Triton kernel for conversion to float8 dtypes' to '[WIP] [float8nocompile]: Triton kernel for conversion to float8 dtypes (forward)' on Dec 19, 2024
@danielvegamyhre changed the title from '[WIP] [float8nocompile]: Triton kernel for conversion to float8 dtypes (forward)' to '[WIP] [float8nocompile]: Triton kernel for conversion to float8 dtypes for forward pass' on Dec 19, 2024
offs = tl.arange(0, BLOCK_SIZE)

# get amax
amax = tl.zeros([1], dtype=tl.float64)
Contributor commented:
If I'm reading this code right, the amax value will be calculated independently per cell in the grid. I would have expected the same amax to be used for all cells in the grid; there should be a synchronization step somewhere.

for i in range(0, n_elements, BLOCK_SIZE):
    block_offs = (i * BLOCK_SIZE) + offs
    block_mask = block_offs < n_elements
    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
Contributor commented:
I would have expected a single tl.load per grid cell. If I'm understanding this code correctly, it's both launching a grid and also iterating through every element of the tensor from each cell in the grid. I could be missing something.

@danielvegamyhre (Author) commented Dec 19, 2024:

Oh, I think I see the problem: I didn't update the grid dimensions when I changed the kernel implementation strategy. In this simple strategy there should only be 1 thread block / cell, which finds the global amax by loading one block at a time into SRAM in a loop, as seen here.

My original strategy was to divide the tensor into some number of blocks, compute a local amax for each block in parallel, write those to a shared buffer, and then have a second step compute the global amax from those local amaxes. However, I decided to try this simpler approach first, just to get the code functionally correct.

Unfortunately a SEV caused me to lose access to my devgpu, but I will verify the fix tomorrow and then make it more parallelized.
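
For illustration, a minimal sketch of that single-thread-block strategy, launched with a grid of one program (kernel and argument names are hypothetical, not the PR's actual code):

import triton
import triton.language as tl

@triton.jit
def _global_amax_single_block(input_ptr, amax_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # the whole reduction runs in one program instance, so launch with grid=(1,)
    offs = tl.arange(0, BLOCK_SIZE)
    amax = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(0, n_elements, BLOCK_SIZE):
        block_offs = i + offs  # i already advances by BLOCK_SIZE each iteration
        mask = block_offs < n_elements
        vals = tl.load(input_ptr + block_offs, mask=mask, other=0.0).to(tl.float32)
        amax = tl.maximum(amax, tl.abs(vals))
    tl.store(amax_ptr, tl.max(amax, axis=0))

# host side: _global_amax_single_block[(1,)](x, amax_buf, x.numel(), BLOCK_SIZE=1024)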

@danielvegamyhre (Author) commented Dec 19, 2024:

@vkuzo I went back to the drawing board and rewrote this using the original strategy I had in mind (described in my previous comment), and it works now. It's a bit more complicated and requires breaking the conversion process up into separate kernels, but it uses a higher degree of parallelism, so it should be reasonably performant.

However, I am wondering whether it would be more efficient for each block in the final _to_fp8() kernel to locally compute the global scale from the previously computed block_amaxes, rather than have a separate _scale() kernel do it beforehand and share that scale value everywhere. My reasoning is that every instance of the _to_fp8() kernel has to wait for the _scale() kernel to finish before it can launch, and the memory access latency of the _scale() kernel moving data back and forth between HBM and SRAM may be slower than each instance of _to_fp8() just computing the scale locally.

If time allows, I'd like to benchmark/profile both approaches.
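
A rough sketch of the kind of multi-kernel pipeline described above (the first kernel's name, the host wrapper, and all signatures are assumptions for illustration, not the PR's actual code):

import torch
import triton
import triton.language as tl

@triton.jit
def _block_amax(input_ptr, block_amaxes_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # step 1: each program computes a local amax for its block, in parallel
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    vals = tl.load(input_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    tl.store(block_amaxes_ptr + pid, tl.max(tl.abs(vals), axis=0))

@triton.jit
def _scale(block_amaxes_ptr, scale_ptr, n_blocks, fp8_max, BLOCK_SIZE: tl.constexpr):
    # step 2: a single program reduces the per-block amaxes to one global scale
    amax = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(0, n_blocks, BLOCK_SIZE):
        offs = i + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_blocks
        amax = tl.maximum(amax, tl.load(block_amaxes_ptr + offs, mask=mask, other=0.0))
    global_amax = tl.max(amax, axis=0)
    tl.store(scale_ptr, fp8_max / tl.maximum(global_amax, 1e-12))  # 1e-12 guards against amax == 0

@triton.jit
def _to_fp8(input_ptr, scale_ptr, output_ptr, n_elements, fp8_max, BLOCK_SIZE: tl.constexpr):
    # step 3: every program reads the shared scale, then scales, clamps, and casts its block
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    scale = tl.load(scale_ptr)
    vals = tl.load(input_ptr + offs, mask=mask).to(tl.float32)
    scaled = tl.minimum(tl.maximum(vals * scale, -fp8_max), fp8_max)
    tl.store(output_ptr + offs, scaled.to(output_ptr.dtype.element_ty), mask=mask)

def hp_to_fp8(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn, block_size: int = 1024):
    # hypothetical host-side wrapper orchestrating the three kernels
    assert x.is_contiguous(), "kernels assume contiguous input"
    n = x.numel()
    n_blocks = triton.cdiv(n, block_size)
    block_amaxes = torch.zeros(n_blocks, dtype=torch.float32, device=x.device)
    scale = torch.empty(1, dtype=torch.float32, device=x.device)
    out = torch.empty_like(x, dtype=float8_dtype)
    fp8_max = torch.finfo(float8_dtype).max
    _block_amax[(n_blocks,)](x, block_amaxes, n, BLOCK_SIZE=block_size)
    _scale[(1,)](block_amaxes, scale, n_blocks, fp8_max, BLOCK_SIZE=block_size)
    _to_fp8[(n_blocks,)](x, scale, out, n, fp8_max, BLOCK_SIZE=block_size)
    return out, scale

The alternative raised above would drop the _scale launch and have each _to_fp8 program loop over block_amaxes itself, trading redundant computation for removing the serialization point.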

@danielvegamyhre changed the title from '[WIP] [float8nocompile]: Triton kernel for conversion to float8 dtypes for forward pass' to '[float8nocompile]: Triton kernel for conversion to float8 dtypes for forward pass' on Dec 19, 2024
@danielvegamyhre changed the title from '[float8nocompile]: Triton kernel for conversion to float8 dtypes for forward pass' to '[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass' on Dec 19, 2024
@danielvegamyhre changed the title from '[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass' to '[float8nocompile]: Triton kernels for conversion to float8 dtypes for forward pass of Float8LinearNoCompile' on Dec 19, 2024
@danielvegamyhre (Author) commented:
Adding @drisspg for review as well; curious to get your thoughts on this initial implementation of Triton kernels for fp8 conversion.

@drisspg (Contributor) left a comment:

One big thing here is that all your kernels assume contiguous inputs; either assert that or update them to handle striding.

@danielvegamyhre (Author) replied:
> One big thing here is that all your kernels assume contiguous inputs; either assert that or update them to handle striding.

Thanks, I added an assertion that the input tensor is contiguous, as well as a unit test to validate this behavior.
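
A minimal sketch of that kind of guard test, assuming the hypothetical hp_to_fp8 wrapper from the pipeline sketch above (the test name and shapes are illustrative):

import pytest
import torch

def test_rejects_non_contiguous_input():
    # a transposed view is an easy way to get a non-contiguous tensor
    x = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16).t()
    assert not x.is_contiguous()
    with pytest.raises(AssertionError):
        hp_to_fp8(x)  # hypothetical wrapper defined in the sketch above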

@drisspg (Contributor) commented Dec 19, 2024:

Side note: here's another fun way to do it: https://github.com/drisspg/transformer_nuggets/blob/a4c66bbeebaa479ad8b6ed82d7efbafa41b17260/transformer_nuggets/fp8/scaled_quant.py#L139

@danielvegamyhre danielvegamyhre merged commit e474839 into pytorch:main Dec 19, 2024
17 of 18 checks passed