From 39812db5e0d8de45c7898170b44448a72256313e Mon Sep 17 00:00:00 2001 From: Marko Rakita Date: Thu, 28 Nov 2024 09:58:02 +0100 Subject: [PATCH] Refactor module builder and fix issue 41 (#87) --- src/common/CMakeLists.txt | 1 + src/common/api_impl.cc | 22 ++-- src/common/api_impl.h | 7 +- src/common/module_builder.cc | 212 +++++++++++++++++++++++------------ src/common/module_builder.h | 67 ++++++++--- tests/TTIR/test_basic_ops.py | 1 - tests/TTIR/test_mnist.py | 2 - 7 files changed, 209 insertions(+), 103 deletions(-) diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 34ac6a2..525132a 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -49,6 +49,7 @@ StablehloReferenceProcess StablehloReferenceOps StablehloPasses ChloOps +Version VhloOps VhloTypes StablehloOps diff --git a/src/common/api_impl.cc b/src/common/api_impl.cc index 8684839..f4b66b7 100644 --- a/src/common/api_impl.cc +++ b/src/common/api_impl.cc @@ -9,8 +9,7 @@ // https://llvm.org/LICENSE.txt #include "common/api_impl.h" -#include "common/module_builder.h" -#include "common/status.h" + #include #include #include @@ -18,6 +17,9 @@ #include #include +#include "common/module_builder.h" +#include "common/status.h" + namespace tt::pjrt { std::pair @@ -673,22 +675,24 @@ tt_pjrt_status ClientInstance::PopulateDevices() { PJRT_Error *ClientInstance::Compile(const PJRT_Program *program, LoadedExecutableInstance **out_executable) { DLOG_F(LOG_DEBUG, "ClientInstance::Compile"); - std::string_view format(program->format, program->format_size); + std::string_view code(program->code, program->code_size); + std::string_view format(program->format, program->format_size); - if (!context_.has_value()) { - context_.emplace(); + tt_pjrt_status status = module_builder_->buildModule(code, format); + if (!tt_pjrt_status_is_ok(status)) { + return MakeError(status); } - module_builder_->BuildModule(code, format, *context_); auto executable = std::make_unique( *this, - new ExecutableImage(module_builder_->GetBinary(), + new ExecutableImage(module_builder_->getBinary(), std::string(program->code, program->code_size), - module_builder_->get_num_inputs(), - module_builder_->get_num_outputs()), + module_builder_->getNumInputs(), + module_builder_->getNumOutputs()), addressable_devices_); *out_executable = executable.release(); + return nullptr; } diff --git a/src/common/api_impl.h b/src/common/api_impl.h index d27a2bd..988fdda 100644 --- a/src/common/api_impl.h +++ b/src/common/api_impl.h @@ -8,8 +8,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // https://llvm.org/LICENSE.txt -#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_ -#define IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_ +#ifndef TT_XLA_SRC_COMMON_API_IMPL_H_ +#define TT_XLA_SRC_COMMON_API_IMPL_H_ #include #include @@ -378,7 +378,6 @@ class ClientInstance { std::vector addressable_devices_; std::unique_ptr module_builder_; - std::optional context_; // Synchronization. // We keep one global execution timeline across all devices. The management @@ -432,4 +431,4 @@ static void BindApi(PJRT_Api *api) { } // namespace tt::pjrt -#endif // IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_ +#endif // TT_XLA_SRC_COMMON_API_IMPL_H_ diff --git a/src/common/module_builder.cc b/src/common/module_builder.cc index 31f9b28..d099951 100644 --- a/src/common/module_builder.cc +++ b/src/common/module_builder.cc @@ -4,55 +4,48 @@ // #include "common/module_builder.h" -#include "status.h" +// c++ standard library includes #include #include +// loguru includes +#include "loguru/loguru.hpp" + +// llvm mlir includes #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" -#include "mlir/IR/MLIRContext.h" - -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "stablehlo/dialect/Serialization.h" // from @stablehlo -#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "stablehlo/transforms/Passes.h" // from @stablehlo -#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" -#include "ttmlir/RegisterAll.h" - +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +// stablehlo includes +#include "stablehlo/dialect/Register.h" +#include "stablehlo/dialect/Version.h" +#include "stablehlo/transforms/Passes.h" + +// tt-mlir includes #define TTMLIR_ENABLE_STABLEHLO +#include "tt/runtime/runtime.h" #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" +#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h" #include "ttmlir/Dialect/TTIR/Transforms/Passes.h" #include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" #include "ttmlir/Dialect/TTNN/Transforms/Passes.h" +#include "ttmlir/RegisterAll.h" #include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h" -#include "loguru/loguru.hpp" -#include "tt/runtime/runtime.h" namespace tt::pjrt { -void ModuleBuilder::BuildModule(std::string_view code, std::string_view format, - mlir::MLIRContext &context) { - DLOG_F(LOG_DEBUG, "ModuleBuilder::BuildModule"); +ModuleBuilder::ModuleBuilder() + : m_status(tt_pjrt_status::kSuccess), m_num_inputs(0), m_num_outputs(0) { + m_context = std::make_unique(); - int log_level = loguru::g_stderr_verbosity; - // Register all the required dialects. + // Register all the required dialects and passes. mlir::DialectRegistry registry; registry.insert(); @@ -62,73 +55,152 @@ void ModuleBuilder::BuildModule(std::string_view code, std::string_view format, mlir::tt::registerAllDialects(registry); mlir::stablehlo::registerAllDialects(registry); + mlir::func::registerAllExtensions(registry); mlir::tt::registerAllExtensions(registry); - context.appendDialectRegistry(registry); + mlir::tt::ttir::registerPasses(); + mlir::tt::ttnn::registerPasses(); + + m_context->appendDialectRegistry(registry); +} + +tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code, + const std::string_view &format) { + DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule"); + + mlir::OwningOpRef mlir_module = createVHLOModule(code); + if (!tt_pjrt_status_is_ok(m_status)) { + return m_status; + } + + convertFromVHLOToSHLO(mlir_module); + if (!tt_pjrt_status_is_ok(m_status)) { + return m_status; + } + + convertFromSHLOToTTIR(mlir_module); + if (!tt_pjrt_status_is_ok(m_status)) { + return m_status; + } + + convertFromTTIRToTTNN(mlir_module); + if (!tt_pjrt_status_is_ok(m_status)) { + return m_status; + } + + createFlatbufferBinary(mlir_module); + + return m_status; +} - mlir::OwningOpRef mlir_module = +mlir::OwningOpRef +ModuleBuilder::createVHLOModule(const std::string_view &code) { + mlir::OwningOpRef vhlo_module = mlir::parseSourceString( llvm::StringRef(code.data(), code.size()), // IR may be invalid because some fields may be using DenseElements // instead of DenseArray. We rectify that below and verify after. - mlir::ParserConfig{&context, /*verifyAfterParse=*/true}); - DLOG_F(LOG_DEBUG, "VHLO Module"); - if (log_level > 0) - mlir_module->dump(); + mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true}); - mlir::PassManager vhlo_pm(mlir_module.get()->getName()); - vhlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass()); - // Run the pass manager. - if (mlir::failed(vhlo_pm.run(mlir_module.get()))) { - throw std::runtime_error("Failed to run VHLO->SHLO pipeline."); + if (!vhlo_module) { + DLOG_F(ERROR, "Failed to create VHLO module from the input program code"); + m_status = tt_pjrt_status::kInternal; + return nullptr; } - DLOG_F(LOG_DEBUG, "SHLO Module"); - if (log_level > 0) - mlir_module->dump(); - mlir::tt::ttir::registerPasses(); - mlir::tt::ttnn::registerPasses(); + DLOG_F(LOG_DEBUG, "VHLO Module:"); + print_module(vhlo_module); + + return vhlo_module; +} + +void ModuleBuilder::convertFromVHLOToSHLO( + mlir::OwningOpRef &mlir_module) { + mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName()); + + // Converting VHLO to latest version to facilitate easier conversion to + // StableHLO. + mlir::stablehlo::VhloToVersionPassOptions vhlo_options; + vhlo_options.targetVersionOption = + mlir::vhlo::Version::getCurrentVersion().toString(); + vhlo_to_shlo_pm.addPass( + mlir::stablehlo::createVhloToVersionPass(vhlo_options)); + vhlo_to_shlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass()); + + if (mlir::failed(vhlo_to_shlo_pm.run(mlir_module.get()))) { + DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module"); + m_status = tt_pjrt_status::kInternal; + return; + } + + DLOG_F(LOG_DEBUG, "SHLO Module:"); + print_module(mlir_module); +} +void ModuleBuilder::convertFromSHLOToTTIR( + mlir::OwningOpRef &mlir_module) { // Implicit nesting required to call the stablehlo.composite --> func.call // conversion. - mlir::PassManager shlo_pm(mlir_module.get()->getName(), - mlir::PassManager::Nesting::Implicit); + mlir::PassManager shlo_to_ttir_pm(mlir_module.get()->getName(), + mlir::PassManager::Nesting::Implicit); + mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options; shlo_options.arithDialectConversionsEnabled = true; shlo_options.removeDeadValuesEnabled = true; shlo_options.legalizeCompositeToCallEnabled = true; - mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options); - // Run the pass manager. - if (mlir::failed(shlo_pm.run(mlir_module.get()))) { - throw std::runtime_error("Failed to run SHLO->TTIR pipeline."); + mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_to_ttir_pm, shlo_options); + + if (mlir::failed(shlo_to_ttir_pm.run(mlir_module.get()))) { + DLOG_F(ERROR, "Failed to convert from SHLO to TTIR module"); + m_status = tt_pjrt_status::kInternal; + return; } - DLOG_F(LOG_DEBUG, "TTIR Module"); - if (log_level > 0) - mlir_module->dump(); - mlir::PassManager pm(mlir_module.get()->getName()); + DLOG_F(LOG_DEBUG, "TTIR Module:"); + print_module(mlir_module); +} + +void ModuleBuilder::convertFromTTIRToTTNN( + mlir::OwningOpRef &mlir_module) { + mlir::PassManager ttir_to_ttnn_pm(mlir_module.get()->getName()); + mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options; - mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options); + mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(ttir_to_ttnn_pm, options); // Run the pass manager. - if (mlir::failed(pm.run(mlir_module.get()))) { - throw std::runtime_error("Failed to run TTIR->TTNN pipeline."); + if (mlir::failed(ttir_to_ttnn_pm.run(mlir_module.get()))) { + DLOG_F(ERROR, "Failed to convert from TTIR to TTNN module"); + m_status = tt_pjrt_status::kInternal; + return; } - DLOG_F(LOG_DEBUG, "TTNN Module"); - if (log_level > 0) - mlir_module->dump(); - binary_ptr_ = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); + DLOG_F(LOG_DEBUG, "TTNN Module:"); + print_module(mlir_module); +} + +void ModuleBuilder::createFlatbufferBinary( + const mlir::OwningOpRef &mlir_module) { + m_flatbuffer_binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); + + if (m_flatbuffer_binary == nullptr) { + DLOG_F(ERROR, "Failed to generate flatbuffer binary"); + m_status = tt_pjrt_status::kInternal; + return; + } + + tt::runtime::Binary runtime_binary_handle(m_flatbuffer_binary); + m_num_inputs = runtime_binary_handle.getProgramInputs(0).size(); + m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size(); +} - if (binary_ptr_ == nullptr) { - throw std::runtime_error("Failed to generate flatbuffer binary."); +void ModuleBuilder::print_module( + mlir::OwningOpRef &mlir_module) { + if (loguru::g_stderr_verbosity < LOG_DEBUG) { + return; } - binary_ = std::make_unique(binary_ptr_); - num_outputs_ = binary_->getProgramOutputs(0).size(); - num_inputs_ = binary_->getProgramInputs(0).size(); - return; + mlir_module->dump(); } } // namespace tt::pjrt diff --git a/src/common/module_builder.h b/src/common/module_builder.h index 5b70568..70c753a 100644 --- a/src/common/module_builder.h +++ b/src/common/module_builder.h @@ -3,40 +3,73 @@ // SPDX-License-Identifier: Apache-2.0 // -#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_MODULE_BUILDER_H_ -#define IREE_PJRT_PLUGIN_PJRT_COMMON_MODULE_BUILDER_H_ +#ifndef TT_XLA_SRC_COMMON_MODULE_BUILDER_H_ +#define TT_XLA_SRC_COMMON_MODULE_BUILDER_H_ +// c++ standard library includes #include #include +// llvm mlir includes #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "tt/runtime/runtime.h" + +// tt-xla includes +#include "status.h" namespace tt::pjrt { class ModuleBuilder { public: - ModuleBuilder() = default; - ~ModuleBuilder() = default; + ModuleBuilder(); + + tt_pjrt_status buildModule(const std::string_view &code, + const std::string_view &format); - size_t get_num_inputs() { return num_inputs_; }; - size_t get_num_outputs() { return num_outputs_; }; - unsigned int get_code_size() { return code_size_; }; + std::shared_ptr getBinary() const { return m_flatbuffer_binary; } - std::shared_ptr GetBinary() { return binary_ptr_; } + size_t getNumInputs() const { return m_num_inputs; }; - void BuildModule(std::string_view code, std::string_view format, - mlir::MLIRContext &context); + size_t getNumOutputs() const { return m_num_outputs; }; private: - size_t num_inputs_ = 0; - size_t num_outputs_ = 0; - unsigned int code_size_ = 0; - std::shared_ptr binary_ptr_; - std::unique_ptr binary_; + // Creates VHLO module from the input program code. + mlir::OwningOpRef + createVHLOModule(const std::string_view &code); + + // Converts VHLO module to StableHLO module. + void convertFromVHLOToSHLO(mlir::OwningOpRef &mlir_module); + + // Converts StableHLO module to TTIR module. + void convertFromSHLOToTTIR(mlir::OwningOpRef &mlir_module); + + // Converts TTIR module to TTNN module. + void convertFromTTIRToTTNN(mlir::OwningOpRef &mlir_module); + + // Creates flatbuffer binary from the built TTNN module. + void + createFlatbufferBinary(const mlir::OwningOpRef &mlir_module); + + // Prints module to console for debug purposes. + static void print_module(mlir::OwningOpRef &mlir_module); + + // MLIR context handle. + std::unique_ptr m_context; + + // Flatbuffer binary handle. + std::shared_ptr m_flatbuffer_binary; + + // Number of binary program inputs. + size_t m_num_inputs; + + // Number of binary program outputs. + size_t m_num_outputs; + + // Holds status of the last builder action. + tt_pjrt_status m_status; }; } // namespace tt::pjrt -#endif // IREE_PJRT_PLUGIN_PJRT_COMMON_MODULE_BUILDER_H_ +#endif // TT_XLA_SRC_COMMON_MODULE_BUILDER_H_ diff --git a/tests/TTIR/test_basic_ops.py b/tests/TTIR/test_basic_ops.py index 58dff4e..37116f5 100644 --- a/tests/TTIR/test_basic_ops.py +++ b/tests/TTIR/test_basic_ops.py @@ -108,7 +108,6 @@ def module_div(a, b): verify_module(module_div, [(3, 3, 3), (3, 3, 3)], required_atol=35e-2) -@pytest.mark.skip("VHLO Legalization failed.") def test_dot_general_op(): def module_dot_general(a, b): return jnp.dot(a, b) diff --git a/tests/TTIR/test_mnist.py b/tests/TTIR/test_mnist.py index 310a721..64f862f 100644 --- a/tests/TTIR/test_mnist.py +++ b/tests/TTIR/test_mnist.py @@ -9,7 +9,6 @@ from infrastructure import verify_module -@pytest.mark.skip("VHLO Legalization failed.") def test_matmul(): def module_matmul(a, b): return jnp.matmul(a, b) @@ -17,7 +16,6 @@ def module_matmul(a, b): verify_module(module_matmul, [(32, 32), (32, 32)], required_atol=3e-2) -@pytest.mark.skip("VHLO Legalization failed.") def test_matmul_with_bias(): def module_matmul(a, b, bias): return jnp.matmul(a, b) + bias