-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PJRT plugin attributes for StableHLO version (#106)
- Loading branch information
Showing
7 changed files
with
136 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "common/plugin_attributes.h" | ||
|
||
namespace tt::pjrt { | ||
|
||
StableHLOVersionAttribute::StableHLOVersionAttribute( | ||
std::string_view version_name, mlir::vhlo::Version version_id) | ||
: m_version_name(version_name) { | ||
m_version_id[0] = version_id.getMajor(); | ||
m_version_id[1] = version_id.getMinor(); | ||
m_version_id[2] = version_id.getPatch(); | ||
} | ||
|
||
PJRT_NamedValue StableHLOVersionAttribute::toNamedValue() const { | ||
PJRT_NamedValue named_value; | ||
named_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; | ||
named_value.extension_start = nullptr; | ||
named_value.name = m_version_name.data(); | ||
named_value.name_size = m_version_name.size(); | ||
named_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List; | ||
named_value.int64_array_value = m_version_id; | ||
named_value.value_size = c_version_id_size; | ||
|
||
return named_value; | ||
} | ||
|
||
PJRTPluginAttributes::PJRTPluginAttributes() | ||
: m_stablehlo_current_version("stablehlo_current_version", | ||
mlir::vhlo::Version::getCurrentVersion()), | ||
m_stablehlo_minimum_version("stablehlo_minimum_version", | ||
mlir::vhlo::Version::getMinimumVersion()) { | ||
m_attributes.emplace_back(m_stablehlo_current_version.toNamedValue()); | ||
m_attributes.emplace_back(m_stablehlo_minimum_version.toNamedValue()); | ||
} | ||
|
||
} // namespace tt::pjrt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_ | ||
#define TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_ | ||
|
||
// c++ standard library includes | ||
#include <string_view> | ||
#include <vector> | ||
|
||
// stablehlo includes | ||
#include "stablehlo/dialect/Version.h" | ||
|
||
// PJRT C API includes | ||
#include "xla/pjrt/c/pjrt_c_api.h" | ||
|
||
namespace tt::pjrt { | ||
|
||
// Models StableHLO version attributes. | ||
class StableHLOVersionAttribute { | ||
public: | ||
StableHLOVersionAttribute(std::string_view version_name, | ||
mlir::vhlo::Version version_id); | ||
|
||
PJRT_NamedValue toNamedValue() const; | ||
|
||
private: | ||
static constexpr std::size_t c_version_id_size = 3; | ||
|
||
std::string_view m_version_name; | ||
|
||
std::int64_t m_version_id[c_version_id_size]; | ||
}; | ||
|
||
// Container for PJRT plugin attributes that frameworks can read. | ||
class PJRTPluginAttributes { | ||
public: | ||
PJRTPluginAttributes(); | ||
|
||
const PJRT_NamedValue *getAttributes() const { return m_attributes.data(); } | ||
|
||
std::size_t getNumAttributes() const { return m_attributes.size(); } | ||
|
||
private: | ||
std::vector<PJRT_NamedValue> m_attributes; | ||
|
||
// Attribute for the current StableHLO version of the plugin. | ||
// If a PJRT plugin has an attribute for `stablehlo_current_version` then JAX | ||
// will precisely downgrade the IR to the plugin's version. Without the | ||
// attribute JAX uses 12 weeks IR downgrade, meaning newer features can't be | ||
// used or integrated for several months. Similarly an older than 12w plugin | ||
// will have more stability if it lets JAX know its precise version so it | ||
// downgrades more than 12w. Note that support >12w isn't guaranteed by JAX | ||
// but historically has been fairly stable. | ||
StableHLOVersionAttribute m_stablehlo_current_version; | ||
|
||
// Attribute for the minimum supported StableHLO version of the plugin. | ||
// Requires frameworks to upgrade the IR to at least this version, and to not | ||
// downgrade the IR below this version. | ||
StableHLOVersionAttribute m_stablehlo_minimum_version; | ||
}; | ||
|
||
} // namespace tt::pjrt | ||
|
||
#endif // TT_XLA_SRC_COMMON_PLUGIN_ATTRIBUTES_H_ |