Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PJRT plugin attributes for StableHLO version #106

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
jaxlib==0.4.31
jax
flax
jaxlib==0.4.36
jax==0.4.36
flax==0.10.2
cmake
ninja
clang-format
Expand Down
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(TTPJRTCommon
"api_impl.cc"
"platform.cc"
"module_builder.cc"
"plugin_attributes.cc"
)
add_dependencies(TTPJRTCommon tt-mlir loguru)
target_include_directories(TTPJRTCommon PUBLIC
Expand Down
21 changes: 15 additions & 6 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@

#include "common/api_impl.h"

// c++ standard library includes
#include <cassert>
#include <cstring>
#include <iostream>
#include <optional>
#include <sstream>
#include <utility>

// tt-xla includes
#include "common/module_builder.h"
#include "common/plugin_attributes.h"
#include "common/status.h"

namespace tt::pjrt {
Expand Down Expand Up @@ -1095,12 +1098,7 @@ void BindMonomorphicApi(PJRT_Api *api) {
return nullptr;
};

api->PJRT_Plugin_Attributes =
+[](PJRT_Plugin_Attributes_Args *args) -> PJRT_Error * {
DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes");
args->num_attributes = 0;
return nullptr;
};
api->PJRT_Plugin_Attributes = InitializePluginAttributes;
mrakitaTT marked this conversation as resolved.
Show resolved Hide resolved

// Bind by object types.
BufferInstance::BindApi(api);
Expand All @@ -1112,4 +1110,15 @@ void BindMonomorphicApi(PJRT_Api *api) {
LoadedExecutableInstance::BindApi(api);
}

// PJRT_Plugin_Attributes entry point: reports the plugin's attribute list to
// the calling framework.
//
// On return `args->attributes` points at the attribute array and
// `args->num_attributes` holds its length. The array is owned by a
// function-local static that is constructed on the first call and is const
// afterwards, so the returned pointer remains usable across calls (until
// static destruction at program exit). Always returns nullptr (no error).
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args) {
  DLOG_F(LOG_DEBUG, "PJRT_Plugin_Attributes");

  // A plain function-local static provides the same lazy, one-time
  // construction as a static unique_ptr would, without the extra heap
  // allocation and pointer indirection.
  static const PJRTPluginAttributes s_plugin_attributes;
  args->attributes = s_plugin_attributes.getAttributes();
  args->num_attributes = s_plugin_attributes.getNumAttributes();

  return nullptr;
}

} // namespace tt::pjrt
13 changes: 11 additions & 2 deletions src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef TT_XLA_SRC_COMMON_API_IMPL_H_
#define TT_XLA_SRC_COMMON_API_IMPL_H_

// c++ standard library includes
#include <atomic>
#include <iostream>
#include <memory>
Expand All @@ -22,10 +23,15 @@
#include <thread>
#include <vector>

// PJRT C API includes
#include "xla/pjrt/c/pjrt_c_api.h"

// tt-mlir includes
#include "tt/runtime/runtime.h"

// tt-xla includes
#include "common/module_builder.h"
#include "common/platform.h"
#include "tt/runtime/runtime.h"
#include "xla/pjrt/c/pjrt_c_api.h"

namespace tt::pjrt {

Expand Down Expand Up @@ -397,6 +403,9 @@ class ClientInstance {
// Binds all monomorphic API members and top-level API struct setup.
void BindMonomorphicApi(PJRT_Api *api);

// PJRT_Plugin_Attributes implementation. Lazily creates the plugin attribute
// container on first use, then fills `args->attributes` and
// `args->num_attributes` from it. Returns nullptr (no error).
PJRT_Error *InitializePluginAttributes(PJRT_Plugin_Attributes_Args *args);

// Fully binds the PJRT_Api struct for all types. Polymorphic types must be
// specified by template parameters.
template <typename PlatformTy, typename ClientInstanceTy>
Expand Down
11 changes: 1 addition & 10 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ ModuleBuilder::createVHLOModule(const std::string_view &code) {
mlir::OwningOpRef<mlir::ModuleOp> vhlo_module =
mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(code.data(), code.size()),
// IR may be invalid because some fields may be using DenseElements
// instead of DenseArray. We rectify that below and verify after.
mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true});

if (!vhlo_module) {
Expand All @@ -119,14 +117,7 @@ void ModuleBuilder::convertFromVHLOToSHLO(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName());

// Converting VHLO to latest version to facilitate easier conversion to
// StableHLO.
mlir::stablehlo::VhloToVersionPassOptions vhlo_options;
vhlo_options.targetVersionOption =
mlir::vhlo::Version::getCurrentVersion().toString();
vhlo_to_shlo_pm.addPass(
mlir::stablehlo::createVhloToVersionPass(vhlo_options));
vhlo_to_shlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());
mlir::stablehlo::createStablehloDeserializePipeline(vhlo_to_shlo_pm);

if (mlir::failed(vhlo_to_shlo_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module");
Expand Down
39 changes: 39 additions & 0 deletions src/common/plugin_attributes.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "common/plugin_attributes.h"

namespace tt::pjrt {

StableHLOVersionAttribute::StableHLOVersionAttribute(
std::string_view version_name, mlir::vhlo::Version version_id)
: m_version_name(version_name) {
m_version_id[0] = version_id.getMajor();
m_version_id[1] = version_id.getMinor();
m_version_id[2] = version_id.getPatch();
}

// Builds the PJRT C API view of this attribute: an int64-list named value
// whose name is the stored attribute name and whose payload is the
// {major, minor, patch} triple.
//
// The returned struct does not own anything: `name` points at the characters
// referenced by m_version_name and `int64_array_value` points at this
// object's m_version_id array, so the result must not outlive this object.
PJRT_NamedValue StableHLOVersionAttribute::toNamedValue() const {
  // Value-initialize so that every byte handed across the C ABI (padding and
  // the unused part of the value union included) is deterministic.
  PJRT_NamedValue named_value{};
  named_value.struct_size = PJRT_NamedValue_STRUCT_SIZE;
  named_value.extension_start = nullptr;

  // Name is passed as pointer + length; it is not required to be
  // NUL-terminated.
  named_value.name = m_version_name.data();
  named_value.name_size = m_version_name.size();

  // Payload: the version components as an int64 list.
  named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List;
  named_value.int64_array_value = m_version_id;
  named_value.value_size = c_version_id_size;

  return named_value;
}

// Queries the current and minimum StableHLO versions, then publishes each of
// them as a PJRT named value: current version first, minimum version second.
PJRTPluginAttributes::PJRTPluginAttributes()
    : m_stablehlo_current_version("stablehlo_current_version",
                                  mlir::vhlo::Version::getCurrentVersion()),
      m_stablehlo_minimum_version("stablehlo_minimum_version",
                                  mlir::vhlo::Version::getMinimumVersion()) {
  for (const StableHLOVersionAttribute *version_attribute :
       {&m_stablehlo_current_version, &m_stablehlo_minimum_version}) {
    m_attributes.push_back(version_attribute->toNamedValue());
  }
}

} // namespace tt::pjrt
66 changes: 66 additions & 0 deletions src/common/plugin_attributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_
#define TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_

// c++ standard library includes
#include <string_view>
#include <vector>

// stablehlo includes
#include "stablehlo/dialect/Version.h"

// PJRT C API includes
#include "xla/pjrt/c/pjrt_c_api.h"

namespace tt::pjrt {

// Models StableHLO version attributes.
class StableHLOVersionAttribute {
public:
StableHLOVersionAttribute(std::string_view version_name,
mlir::vhlo::Version version_id);

PJRT_NamedValue toNamedValue() const;

private:
static constexpr std::size_t c_version_id_size = 3;

std::string_view m_version_name;

std::int64_t m_version_id[c_version_id_size];
};

// Container for PJRT plugin attributes that frameworks can read.
class PJRTPluginAttributes {
public:
PJRTPluginAttributes();

const PJRT_NamedValue *getAttributes() const { return m_attributes.data(); }

std::size_t getNumAttributes() const { return m_attributes.size(); }

private:
std::vector<PJRT_NamedValue> m_attributes;

// Attribute for the current StableHLO version of the plugin.
// If a PJRT plugin has an attribute for `stablehlo_current_version` then JAX
// will precisely downgrade the IR to the plugin's version. Without the
// attribute JAX uses 12 weeks IR downgrade, meaning newer features can't be
// used or integrated for several months. Similarly an older than 12w plugin
// will have more stability if it lets JAX know its precise version so it
// downgrades more than 12w. Note that support >12w isn't guaranteed by JAX
// but historically has been fairly stable.
StableHLOVersionAttribute m_stablehlo_current_version;

// Attribute for the minimum supported StableHLO version of the plugin.
// Requires frameworks to upgrade the IR to at least this version, and to not
// downgrade the IR below this version.
StableHLOVersionAttribute m_stablehlo_minimum_version;
};

} // namespace tt::pjrt

#endif // TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_
Loading