Skip to content

Commit

Permalink
Add PJRT plugin attributes for StableHLO version (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrakitaTT authored Dec 10, 2024
1 parent 32418e6 commit 5fd6cc8
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 22 deletions.
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
jaxlib==0.4.31
jax
flax
jaxlib==0.4.36
jax==0.4.36
flax==0.10.2
cmake
ninja
clang-format
Expand Down
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(TTPJRTCommon
"api_impl.cc"
"platform.cc"
"module_builder.cc"
"plugin_attributes.cc"
)
add_dependencies(TTPJRTCommon tt-mlir loguru)
target_include_directories(TTPJRTCommon PUBLIC
Expand Down
21 changes: 15 additions & 6 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@

#include "common/api_impl.h"

// c++ standard library includes
#include <cassert>
#include <cstring>
#include <iostream>
#include <optional>
#include <sstream>
#include <utility>

// tt-xla includes
#include "common/module_builder.h"
#include "common/plugin_attributes.h"
#include "common/status.h"

namespace tt::pjrt {
Expand Down Expand Up @@ -1082,12 +1085,7 @@ void BindMonomorphicApi(PJRT_Api *api) {
return nullptr;
};

api->PJRT_Plugin_Attributes =
+[](PJRT_Plugin_Attributes_Args *args) -> PJRT_Error * {
DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes");
args->num_attributes = 0;
return nullptr;
};
api->PJRT_Plugin_Attributes = InitializePluginAttributes;

// Bind by object types.
BufferInstance::BindApi(api);
Expand All @@ -1099,4 +1097,15 @@ void BindMonomorphicApi(PJRT_Api *api) {
LoadedExecutableInstance::BindApi(api);
}

// Handler bound to `PJRT_Api::PJRT_Plugin_Attributes`. Exposes the plugin's
// attribute list to the caller through `args` and always reports success.
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args) {
  DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes");

  // Built once on the first call and kept for the rest of the process, so the
  // pointer handed out below stays valid after this function returns.
  static const PJRTPluginAttributes plugin_attributes;

  args->num_attributes = plugin_attributes.getNumAttributes();
  args->attributes = plugin_attributes.getAttributes();

  // nullptr is the "no error" result.
  return nullptr;
}

} // namespace tt::pjrt
13 changes: 11 additions & 2 deletions src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef TT_XLA_SRC_COMMON_API_IMPL_H_
#define TT_XLA_SRC_COMMON_API_IMPL_H_

// c++ standard library includes
#include <atomic>
#include <iostream>
#include <memory>
Expand All @@ -22,10 +23,15 @@
#include <thread>
#include <vector>

// PJRT C API includes
#include "xla/pjrt/c/pjrt_c_api.h"

// tt-mlir includes
#include "tt/runtime/runtime.h"

// tt-xla includes
#include "common/module_builder.h"
#include "common/platform.h"
#include "tt/runtime/runtime.h"
#include "xla/pjrt/c/pjrt_c_api.h"

namespace tt::pjrt {

Expand Down Expand Up @@ -397,6 +403,9 @@ class ClientInstance {
// Binds all monomorphic API members and top-level API struct setup.
void BindMonomorphicApi(PJRT_Api *api);

// Implementation of the `PJRT_Plugin_Attributes` API entry point. Lazily
// creates the plugin attributes on first call, points `args->attributes` at
// the attribute array, sets `args->num_attributes` to its length and returns
// nullptr. The array lives in a function-local static inside the plugin.
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args);

// Fully binds the PJRT_Api struct for all types. Polymorphic types must be
// specified by template parameters.
template <typename PlatformTy, typename ClientInstanceTy>
Expand Down
11 changes: 1 addition & 10 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) {
mlir::OwningOpRef<mlir::ModuleOp> vhlo_module =
mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(code.data(), code.size()),
// IR may be invalid because some fields may be using DenseElements
// instead of DenseArray. We rectify that below and verify after.
mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true});

if (!vhlo_module) {
Expand All @@ -119,14 +117,7 @@ void ModuleBuilder::convertFromVHLOToSHLO(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName());

// Converting VHLO to latest version to facilitate easier conversion to
// StableHLO.
mlir::stablehlo::VhloToVersionPassOptions vhlo_options;
vhlo_options.targetVersionOption =
mlir::vhlo::Version::getCurrentVersion().toString();
vhlo_to_shlo_pm.addPass(
mlir::stablehlo::createVhloToVersionPass(vhlo_options));
vhlo_to_shlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());
mlir::stablehlo::createStablehloDeserializePipeline(vhlo_to_shlo_pm);

if (mlir::failed(vhlo_to_shlo_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module");
Expand Down
39 changes: 39 additions & 0 deletions src/common/plugin_attributes.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "common/plugin_attributes.h"

namespace tt::pjrt {

StableHLOVersionAttribute::StableHLOVersionAttribute(
std::string_view version_name, mlir::vhlo::Version version_id)
: m_version_name(version_name) {
m_version_id[0] = version_id.getMajor();
m_version_id[1] = version_id.getMinor();
m_version_id[2] = version_id.getPatch();
}

// Builds the PJRT_NamedValue describing this attribute as an int64 list of
// `c_version_id_size` elements.
//
// The returned struct does not own any data: `name` points at the characters
// viewed by `m_version_name` and `int64_array_value` points at `m_version_id`
// inside this object, so the result is only usable while this object stays
// alive at the same address.
PJRT_NamedValue StableHLOVersionAttribute::toNamedValue() const {
  // Value-initialize so that any member not explicitly assigned below is
  // zeroed instead of being left indeterminate in the returned copy.
  PJRT_NamedValue named_value{};
  named_value.struct_size = PJRT_NamedValue_STRUCT_SIZE;
  named_value.extension_start = nullptr;

  // The name is passed as pointer + length; no null terminator is relied on.
  named_value.name = m_version_name.data();
  named_value.name_size = m_version_name.size();

  named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List;
  named_value.int64_array_value = m_version_id;
  named_value.value_size = c_version_id_size;

  return named_value;
}

// Creates the StableHLO version attributes (current and minimum version as
// reported by `mlir::vhlo::Version`) and publishes them, in that order, in the
// attribute list returned by `getAttributes`.
PJRTPluginAttributes::PJRTPluginAttributes()
    : m_stablehlo_current_version("stablehlo_current_version",
                                  mlir::vhlo::Version::getCurrentVersion()),
      m_stablehlo_minimum_version("stablehlo_minimum_version",
                                  mlir::vhlo::Version::getMinimumVersion()) {
  const StableHLOVersionAttribute *const version_attributes[] = {
      &m_stablehlo_current_version, &m_stablehlo_minimum_version};

  for (const StableHLOVersionAttribute *attribute : version_attributes) {
    m_attributes.push_back(attribute->toNamedValue());
  }
}

} // namespace tt::pjrt
66 changes: 66 additions & 0 deletions src/common/plugin_attributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_
#define TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_

// c++ standard library includes
#include <string_view>
#include <vector>

// stablehlo includes
#include "stablehlo/dialect/Version.h"

// PJRT C API includes
#include "xla/pjrt/c/pjrt_c_api.h"

namespace tt::pjrt {

// Models a StableHLO version attribute: a name paired with a
// {major, minor, patch} version triple that can be exported as a
// PJRT_NamedValue.
class StableHLOVersionAttribute {
public:
// Stores `version_name` and the major, minor and patch components of
// `version_id`.
// NOTE: `version_name` is kept as a std::string_view, which does not own its
// characters. The caller must keep the underlying character data alive for as
// long as this object, or any PJRT_NamedValue produced from it, is in use.
StableHLOVersionAttribute(std::string_view version_name,
mlir::vhlo::Version version_id);

// Returns this attribute as a PJRT_NamedValue of type
// PJRT_NamedValue_kInt64List. Nothing is copied into the result: its `name`
// points at the characters viewed by `m_version_name` and its
// `int64_array_value` points at `m_version_id` inside this object, so the
// result must not outlive this object.
PJRT_NamedValue toNamedValue() const;

private:
// Number of version components stored: major, minor, patch.
static constexpr std::size_t c_version_id_size = 3;

// Non-owning view of the attribute name.
std::string_view m_version_name;

// Version components in the order major, minor, patch.
std::int64_t m_version_id[c_version_id_size];
};

// Container for PJRT plugin attributes that frameworks can read.
//
// The PJRT_NamedValue entries exposed by `getAttributes` hold raw pointers
// into the version attribute members of this same object, so an instance must
// stay at a fixed address for as long as those entries are in use.
class PJRTPluginAttributes {
public:
  PJRTPluginAttributes();

  // Not copyable or movable: a copied or moved `m_attributes` would still
  // point into the version attributes of the source object and dangle once
  // that object is destroyed.
  PJRTPluginAttributes(const PJRTPluginAttributes &) = delete;
  PJRTPluginAttributes &operator=(const PJRTPluginAttributes &) = delete;
  PJRTPluginAttributes(PJRTPluginAttributes &&) = delete;
  PJRTPluginAttributes &operator=(PJRTPluginAttributes &&) = delete;

  // Returns a pointer to the first attribute; valid while this object lives.
  const PJRT_NamedValue *getAttributes() const { return m_attributes.data(); }

  // Returns the number of attributes reachable through `getAttributes`.
  std::size_t getNumAttributes() const { return m_attributes.size(); }

private:
  // Attribute list handed to the framework; filled in by the constructor.
  std::vector<PJRT_NamedValue> m_attributes;

  // Attribute for the current StableHLO version of the plugin.
  // If a PJRT plugin has an attribute for `stablehlo_current_version` then JAX
  // will precisely downgrade the IR to the plugin's version. Without the
  // attribute JAX downgrades the IR to a version 12 weeks old, meaning newer
  // features can't be used or integrated for several months. Similarly, a
  // plugin older than 12 weeks will have more stability if it lets JAX know
  // its precise version, so that JAX downgrades by more than 12 weeks. Note
  // that support beyond 12 weeks isn't guaranteed by JAX but historically has
  // been fairly stable.
  StableHLOVersionAttribute m_stablehlo_current_version;

  // Attribute for the minimum supported StableHLO version of the plugin.
  // Requires frameworks to upgrade the IR to at least this version, and to not
  // downgrade the IR below this version.
  StableHLOVersionAttribute m_stablehlo_minimum_version;
};

} // namespace tt::pjrt

#endif // TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_

0 comments on commit 5fd6cc8

Please sign in to comment.