diff --git a/requirements.txt b/requirements.txt index 1ad3779..6a4f8fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ --f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -jaxlib==0.4.31 -jax -flax +jaxlib==0.4.36 +jax==0.4.36 +flax==0.10.2 cmake ninja clang-format diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 525132a..7a0be24 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(TTPJRTCommon "api_impl.cc" "platform.cc" "module_builder.cc" + "plugin_attributes.cc" ) add_dependencies(TTPJRTCommon tt-mlir loguru) target_include_directories(TTPJRTCommon PUBLIC diff --git a/src/common/api_impl.cc b/src/common/api_impl.cc index 2c8f817..fa1bf33 100644 --- a/src/common/api_impl.cc +++ b/src/common/api_impl.cc @@ -10,6 +10,7 @@ #include "common/api_impl.h" +// c++ standard library includes #include #include #include @@ -17,7 +18,9 @@ #include #include +// tt-xla includes #include "common/module_builder.h" +#include "common/plugin_attributes.h" #include "common/status.h" namespace tt::pjrt { @@ -1082,12 +1085,7 @@ void BindMonomorphicApi(PJRT_Api *api) { return nullptr; }; - api->PJRT_Plugin_Attributes = - +[](PJRT_Plugin_Attributes_Args *args) -> PJRT_Error * { - DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes"); - args->num_attributes = 0; - return nullptr; - }; + api->PJRT_Plugin_Attributes = InitializePluginAttributes; // Bind by object types. BufferInstance::BindApi(api); @@ -1099,4 +1097,15 @@ void BindMonomorphicApi(PJRT_Api *api) { LoadedExecutableInstance::BindApi(api); } +PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args) { + DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes"); + + static std::unique_ptr s_plugin_attributes = + std::make_unique(); + args->attributes = s_plugin_attributes->getAttributes(); + args->num_attributes = s_plugin_attributes->getNumAttributes(); + + return nullptr; +} + } // namespace tt::pjrt diff --git a/src/common/api_impl.h b/src/common/api_impl.h index 988fdda..87c3e19 100644 --- a/src/common/api_impl.h +++ b/src/common/api_impl.h @@ -11,6 +11,7 @@ #ifndef TT_XLA_SRC_COMMON_API_IMPL_H_ #define TT_XLA_SRC_COMMON_API_IMPL_H_ +// c++ standard library includes #include #include #include @@ -22,10 +23,15 @@ #include #include +// PJRT C API includes +#include "xla/pjrt/c/pjrt_c_api.h" + +// tt-mlir includes +#include "tt/runtime/runtime.h" + +// tt-xla includes #include "common/module_builder.h" #include "common/platform.h" -#include "tt/runtime/runtime.h" -#include "xla/pjrt/c/pjrt_c_api.h" namespace tt::pjrt { @@ -397,6 +403,9 @@ class ClientInstance { // Binds all monomorphic API members and top-level API struct setup. void BindMonomorphicApi(PJRT_Api *api); +// Initializes and returns PJRT plugin attributes. +PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args); + // Fully binds the PJRT_Api struct for all types. Polymorphic types must be // specified by template parameters. template diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index d099951..61b614b 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -99,8 +99,6 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) { mlir::OwningOpRef vhlo_module = mlir::parseSourceString( llvm::StringRef(code.data(), code.size()), - // IR may be invalid because some fields may be using DenseElements - // instead of DenseArray. We rectify that below and verify after. mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true}); if (!vhlo_module) { @@ -119,14 +117,7 @@ void ModuleBuilder::convertFromVHLOToSHLO( mlir::OwningOpRef &mlir_module) { mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName()); - // Converting VHLO to latest version to facilitate easier conversion to - // StableHLO. - mlir::stablehlo::VhloToVersionPassOptions vhlo_options; - vhlo_options.targetVersionOption = - mlir::vhlo::Version::getCurrentVersion().toString(); - vhlo_to_shlo_pm.addPass( - mlir::stablehlo::createVhloToVersionPass(vhlo_options)); - vhlo_to_shlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass()); + mlir::stablehlo::createStablehloDeserializePipeline(vhlo_to_shlo_pm); if (mlir::failed(vhlo_to_shlo_pm.run(mlir_module.get()))) { DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module"); diff --git a/src/common/plugin_attributes.cc b/src/common/plugin_attributes.cc new file mode 100644 index 0000000..6e5c945 --- /dev/null +++ b/src/common/plugin_attributes.cc @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "common/plugin_attributes.h" + +namespace tt::pjrt { + +StableHLOVersionAttribute::StableHLOVersionAttribute( + std::string_view version_name, mlir::vhlo::Version version_id) + : m_version_name(version_name) { + m_version_id[0] = version_id.getMajor(); + m_version_id[1] = version_id.getMinor(); + m_version_id[2] = version_id.getPatch(); +} + +PJRT_NamedValue StableHLOVersionAttribute::toNamedValue() const { + PJRT_NamedValue named_value; + named_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; + named_value.extension_start = nullptr; + named_value.name = m_version_name.data(); + named_value.name_size = m_version_name.size(); + named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List; + named_value.int64_array_value = m_version_id; + named_value.value_size = c_version_id_size; + + return named_value; +} + +PJRTPluginAttributes::PJRTPluginAttributes() + : m_stablehlo_current_version("stablehlo_current_version", + mlir::vhlo::Version::getCurrentVersion()), + m_stablehlo_minimum_version("stablehlo_minimum_version", + mlir::vhlo::Version::getMinimumVersion()) { + m_attributes.emplace_back(m_stablehlo_current_version.toNamedValue()); + m_attributes.emplace_back(m_stablehlo_minimum_version.toNamedValue()); +} + +} // namespace tt::pjrt diff --git a/src/common/plugin_attributes.h b/src/common/plugin_attributes.h new file mode 100644 index 0000000..419cc00 --- /dev/null +++ b/src/common/plugin_attributes.h @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_ +#define TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_ + +// c++ standard library includes +#include +#include + +// stablehlo includes +#include "stablehlo/dialect/Version.h" + +// PJRT C API includes +#include "xla/pjrt/c/pjrt_c_api.h" + +namespace tt::pjrt { + +// Models StableHLO version attributes. +class StableHLOVersionAttribute { +public: + StableHLOVersionAttribute(std::string_view version_name, + mlir::vhlo::Version version_id); + + PJRT_NamedValue toNamedValue() const; + +private: + static constexpr std::size_t c_version_id_size = 3; + + std::string_view m_version_name; + + std::int64_t m_version_id[c_version_id_size]; +}; + +// Container for PJRT plugin attributes that frameworks can read. +class PJRTPluginAttributes { +public: + PJRTPluginAttributes(); + + const PJRT_NamedValue *getAttributes() const { return m_attributes.data(); } + + std::size_t getNumAttributes() const { return m_attributes.size(); } + +private: + std::vector m_attributes; + + // Attribute for the current StableHLO version of the plugin. + // If a PJRT plugin has an attribute for `stablehlo_current_version` then JAX + // will precisely downgrade the IR to the plugin's version. Without the + // attribute JAX uses 12 weeks IR downgrade, meaning newer features can't be + // used or integrated for several months. Similarly an older than 12w plugin + // will have more stability if it lets JAX know its precise version so it + // downgrades more than 12w. Note that support >12w isn't guaranteed by JAX + // but historically has been fairly stable. + StableHLOVersionAttribute m_stablehlo_current_version; + + // Attribute for the minimum supported StableHLO version of the plugin. + // Requires frameworks to upgrade the IR to at least this version, and to not + // downgrade the IR below this version. + StableHLOVersionAttribute m_stablehlo_minimum_version; +}; + +} // namespace tt::pjrt + +#endif // TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_