Add PJRT attribute for StableHLO version to tt-xla plugin #97

Closed
GleasonK opened this issue Dec 2, 2024 · 5 comments · Fixed by #106
Labels
community issue was filed by a community member (not TT)


@GleasonK

GleasonK commented Dec 2, 2024

Hello! I was going to try building and contributing, but I hit some setup issues and ran out of time, so I figured I'd at least give some code pointers to what I intended to add. It should be fairly trivial.

If a PJRT plugin exposes a stablehlo_current_version attribute, JAX will downgrade the IR to exactly the plugin's StableHLO version. Without the attribute, JAX falls back to a 12-week (12w) IR downgrade, which means newer features can't be used or integrated for several months. Similarly, a plugin older than 12 weeks gains stability by telling JAX its precise version, so that JAX downgrades further back than 12 weeks. Note that support beyond 12 weeks isn't guaranteed by JAX, but it has historically been fairly stable.
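To make the versioning behaviour concrete, here is a small C++ sketch of the choice a framework makes. It is not JAX's actual logic: Version and ChooseTargetVersion are hypothetical names, the 12-week fallback is passed in rather than computed, and clamping the result to the framework's own current version is an assumption (IR can be downgraded, not upgraded).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a StableHLO version triple (major.minor.patch).
struct Version {
  std::int64_t major;
  std::int64_t minor;
  std::int64_t patch;
};

// Lexicographic order, so std::min picks the older of two versions.
inline bool operator<(const Version& a, const Version& b) {
  if (a.major != b.major) return a.major < b.major;
  if (a.minor != b.minor) return a.minor < b.minor;
  return a.patch < b.patch;
}

inline bool operator==(const Version& a, const Version& b) {
  return a.major == b.major && a.minor == b.minor && a.patch == b.patch;
}

// Sketch of the choice described above (hypothetical, not JAX's code):
//  - plugin reports stablehlo_current_version -> target exactly that version,
//    even if it is older than the 12-week window;
//  - no attribute -> use the version from ~12 weeks ago.
// Assumption: the result is clamped to the framework's own current version,
// since IR can be downgraded but never upgraded past what the framework knows.
inline Version ChooseTargetVersion(const std::optional<Version>& plugin_version,
                                   const Version& twelve_week_fallback,
                                   const Version& framework_current) {
  assert(!(framework_current < twelve_week_fallback));  // fallback is in the past
  const Version wanted = plugin_version.value_or(twelve_week_fallback);
  return std::min(wanted, framework_current);
}
```

With placeholder numbers (framework at 1.8.0, 12-week fallback at 1.5.0), a plugin reporting 1.7.3 gets exactly 1.7.3, a plugin reporting 1.2.0 gets 1.2.0 (further back than the 12-week window), and a plugin without the attribute gets 1.5.0.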

To add attribute support, we'll need an implementation of PJRT_Plugin_Attributes:

```cpp
api->PJRT_Plugin_Attributes =
```

A template for specifying these attributes can be found in:
https://github.com/openxla/xla/blob/4220451d1582817f80dac19a4760374b666cbce5/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc#L353

The implementation, where the attributes are held statically so that frameworks can safely access them, is here; the <0>/<1> template arguments ensure that each string gets its own unique storage:
https://github.com/openxla/xla/blob/4220451d1582817f80dac19a4760374b666cbce5/xla/pjrt/c/pjrt_c_api_helpers.cc#L637
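A minimal, self-contained sketch of that static-storage pattern follows. NamedValue, PluginAttributesArgs and Error are hypothetical, trimmed-down stand-ins for PJRT_NamedValue, PJRT_Plugin_Attributes_Args and PJRT_Error (the real structs in pjrt_c_api.h carry extra fields such as struct_size), and the function names are made up. It assumes the version is reported as a three-element int64 list (major, minor, patch), and the second attribute name, stablehlo_minimum_version, is an assumption used only to show why <0>/<1> are needed; check the linked XLA helper for the exact names and encoding. Version numbers are placeholders.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string_view>

// Hypothetical, trimmed-down stand-ins for the PJRT C API types; the real
// structs in pjrt_c_api.h have more fields (struct_size, extension_start, ...).
struct Error;  // opaque, like PJRT_Error; nullptr means success
enum class NamedValueType { kInt64List };

struct NamedValue {
  const char* name;
  std::size_t name_size;
  NamedValueType type;
  const std::int64_t* int64_array_value;
  std::size_t value_size;
};

struct PluginAttributesArgs {
  const NamedValue* attributes;  // out: must outlive every caller
  std::size_t num_attributes;    // out
};

// Each distinct I is a separate function with its own function-local static.
// Without I, both calls below would share one `triple` (initialized only on the
// first call), and the second attribute would report the first one's numbers.
template <int I>
NamedValue VersionAttr(std::string_view name, std::int64_t major,
                       std::int64_t minor, std::int64_t patch) {
  static std::int64_t triple[3] = {major, minor, patch};
  return NamedValue{name.data(), name.size(), NamedValueType::kInt64List,
                    triple, 3};
}

// Built once on first use and never freed, so frameworks can keep the pointers.
// stablehlo_current_version is from the issue; the second name is an
// assumption, and all version numbers are placeholders.
const NamedValue* GetPluginAttributes(std::size_t* count) {
  static const NamedValue kAttrs[] = {
      VersionAttr<0>("stablehlo_current_version", 1, 8, 0),
      VersionAttr<1>("stablehlo_minimum_version", 0, 9, 0),
  };
  *count = sizeof(kAttrs) / sizeof(kAttrs[0]);
  return kAttrs;
}

// Shape of a PJRT_Plugin_Attributes-style entry point: fill the out-params
// with the static table and return nullptr (no error).
Error* PluginAttributesImpl(PluginAttributesArgs* args) {
  assert(args != nullptr);
  args->attributes = GetPluginAttributes(&args->num_attributes);
  return nullptr;
}
```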

The stablehlo_current_version implementation is currently commented out because of a bug, but once JAX makes a new release (this week or next), that is the implementation that should be used. So perhaps hold off on adding this until there is a >= 0.4.36 release on https://github.com/jax-ml/jax/releases

@github-actions github-actions bot added the community issue was filed by a community member (not TT) label Dec 2, 2024
@GleasonK
Author

GleasonK commented Dec 2, 2024

One other minor recommendation would be to use createStablehloDeserializePipeline in convertFromVHLOToSHLO, similar to:
https://github.com/openxla/xla/blob/e11e342f38ec6d9dbbd04408b179f68f939ebed7/xla/pjrt/mlir_to_hlo.cc#L271

That way, if we ever need a patch fix or some other compatibility change, we can manage it centrally and tt-xla won't need to worry about it. Also note that this is a no-op for non-VHLO modules.

Lastly, the following code comment is stale and can be deleted: we've done away with this patch fix, and the IR is now verified after parsing (note that verify is set to true on the lines after it):

```cpp
// IR may be invalid because some fields may be using DenseElements
```

@mrakitaTT
Contributor

Thank you so much for your recommendations @GleasonK, this is awesome input! I'll try to implement this tomorrow and I'll tag you on the PR.

Please feel free to open other issues with any future recommendation or proposal that you might have, it is highly appreciated! 🙏

> I was going to try building and contributing but hit some setup issues and ran out of time

I'd also be curious to hear about these setup issues. If you think it's something on our end that we should improve or document better, please let me know.

@GleasonK
Author

GleasonK commented Dec 2, 2024

I wasn't able to easily figure out what TTMLIR_TOOLCHAIN_DIR should be set to from the instructions in tt-xla/README.md, which blocked my tt-mlir build. I could probably dig around and figure it out, but I didn't have the time today :)

Edit: yeah, a little digging would have unblocked me; user error, as expected. This is mentioned in:
https://docs.tenstorrent.com/tt-mlir/build.html

I was on that page and tried Ctrl+F for "TTMLIR_TOOLCHAIN_DIR", but that keyword isn't on the page; only /opt/ttmlir-toolchain is mentioned, and I missed it on my quick glance.

@mrakitaTT
Contributor

Ohh, I see. We mention TTMLIR_TOOLCHAIN_DIR before the tt-mlir environment steps where it is introduced, and the sentence is unfortunately worded so that you'd expect it to be something you need to do before continuing with the setup. Will fix, thank you for the feedback @GleasonK!

@mrakitaTT
Contributor

@GleasonK The JAX 0.4.36 release has landed, so I've sent PR #106, which should address your suggestions. I hope I've covered everything properly; if you find time to take a look at the PR, please let me know. Thank you again for suggesting these changes!
