Add PJRT attribute for StableHLO version to tt-xla plugin #97
One other minor recommendation would be to use …. That way, if we ever need a patch fix or something else for compatibility, we can manage it centrally and tt-xla won't need to worry about it. Also note that this is a no-op for non-VHLO modules. Lastly, the following code comment is stale and can be deleted: we've done away with this patch fix, and IR is verified after parsing now (note the verify set to …): tt-xla/src/common/module_builder.cc, line 102 in d770c94
Thank you so much for your recommendations @GleasonK, this is awesome input! I'll try to implement this tomorrow and will tag you on the PR. Please feel free to open other issues with any future recommendations or proposals you might have; they are highly appreciated! 🙏
I'd also be curious to hear about these setup issues. If you think it is something on our end that we should improve or document better, please let me know.
I wasn't able to trivially figure out what … Edit: yeah, a little digging would have unblocked me; user error, as expected. This is mentioned in: … I was on that page and tried to Ctrl+F "TTMLIR_TOOLCHAIN_DIR", but that keyword isn't on the page, only …
Ohh, I see, we are mentioning …
Hello! I was going to try building and contributing, but I hit some setup issues and ran out of time, so I figured I'd at least give some code pointers for what I intended to add. It should be fairly trivial.
If a PJRT plugin has an attribute for `stablehlo_current_version`, then JAX will precisely downgrade the IR to the plugin's version. Without the attribute, JAX falls back to a 12-week (12w) downgrade, meaning newer features can't be used or integrated for several months. Similarly, a plugin older than 12w will be more stable if it lets JAX know its precise version, so that JAX downgrades further than 12w. Note that support beyond 12w isn't guaranteed by JAX, but it has historically been fairly stable.

To add attribute support, we'll need an implementation of `PJRT_Plugin_Attributes`: tt-xla/src/common/api_impl.cc, line 1098 in d770c94
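A minimal sketch of what such an implementation needs to do, assuming simplified stand-ins for the PJRT C API types. The real `PJRT_NamedValue` and `PJRT_Plugin_Attributes_Args` in `xla/pjrt/c/pjrt_c_api.h` carry additional fields, so `NamedValue`, `PluginAttributesArgs`, and `PluginAttributes` below are hypothetical names, and the version triple is a placeholder. The point it shows is that the attribute table and everything it points to live in static storage, so the framework can keep reading them after the call returns.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Simplified stand-ins for the PJRT C API types (hypothetical; the real
// definitions in xla/pjrt/c/pjrt_c_api.h have extra fields such as
// struct_size and extension_start).
enum class NamedValueType { kString, kInt64, kInt64List };

struct NamedValue {
  const char* name;
  size_t name_size;
  NamedValueType type;
  const int64_t* int64_array_value;  // valid when type == kInt64List
  size_t value_size;                 // number of list elements
};

struct PluginAttributesArgs {
  const NamedValue* attributes = nullptr;  // out: points at plugin-owned storage
  size_t num_attributes = 0;               // out
};

// Placeholder version triple (major, minor, patch); a real plugin would
// report the StableHLO version it was built against instead.
constexpr int64_t kStablehloCurrentVersion[] = {1, 8, 0};
constexpr char kStablehloCurrentVersionName[] = "stablehlo_current_version";

// Hypothetical stand-in for tt-xla's PJRT_Plugin_Attributes entry point.
// The table is a function-local static, so the returned pointers remain
// valid for the lifetime of the process and nothing is allocated per call.
void PluginAttributes(PluginAttributesArgs* args) {
  assert(args != nullptr);
  static const NamedValue kAttributes[] = {
      {kStablehloCurrentVersionName,
       sizeof(kStablehloCurrentVersionName) - 1,  // length without the NUL
       NamedValueType::kInt64List,
       kStablehloCurrentVersion,
       sizeof(kStablehloCurrentVersion) / sizeof(kStablehloCurrentVersion[0])},
  };
  args->attributes = kAttributes;
  args->num_attributes = sizeof(kAttributes) / sizeof(kAttributes[0]);
}
```

Calling `PluginAttributes` repeatedly hands back the same table pointer each time, which is what makes it safe for a framework to hold on to the values.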
A template for specifying these attrs can be found in:
https://github.com/openxla/xla/blob/4220451d1582817f80dac19a4760374b666cbce5/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc#L353
An implementation where the attributes are held statically, so they can be safely accessed by frameworks, is here; the `<0>`/`<1>` template argument is used to ensure that each string has unique storage: https://github.com/openxla/xla/blob/4220451d1582817f80dac19a4760374b666cbce5/xla/pjrt/c/pjrt_c_api_helpers.cc#L637
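The `<0>`/`<1>` trick can be illustrated in isolation. This is a hedged sketch rather than the XLA code: `VersionAttribute`, `PluginAttributeTable`, and the second attribute name `stablehlo_minimum_version` are assumptions for illustration, and the version triples are placeholders. Each distinct template argument instantiates a separate function, so each attribute gets its own function-local statics instead of sharing (and overwriting) a single slot.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for PJRT_NamedValue (hypothetical field subset).
struct NamedValue {
  const char* name;
  size_t name_size;
  const int64_t* int64_array_value;
  size_t value_size;
};

// VersionAttribute<0> and VersionAttribute<1> are different functions, so
// each owns its own statics. The statics are initialized on the first call
// only; later calls with the same N return the first attribute unchanged.
template <int N>
const NamedValue& VersionAttribute(const std::string& name,
                                   std::vector<int64_t> version) {
  static const std::string stored_name = name;
  static const std::vector<int64_t> stored_version = std::move(version);
  static const NamedValue value{stored_name.c_str(), stored_name.size(),
                                stored_version.data(), stored_version.size()};
  return value;
}

// Hypothetical attribute table; copies of NamedValue still point into the
// per-instantiation statics above, so the pointers stay valid.
const std::vector<NamedValue>& PluginAttributeTable() {
  static const std::vector<NamedValue> table = {
      VersionAttribute<0>("stablehlo_current_version", {1, 8, 0}),
      VersionAttribute<1>("stablehlo_minimum_version", {0, 9, 0}),
  };
  return table;
}
```

Because function-local statics are initialized only once, each value of `N` should back exactly one attribute; reusing `VersionAttribute<0>` for a second name would silently return the first one.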
The `stablehlo_current_version` implementation is currently commented out because of a bug, but once JAX makes a new release (this week or next), that is the implementation that should be used. So perhaps hold off on adding this until there is a >=0.4.36 release on https://github.com/jax-ml/jax/releases