Skip to content

Commit

Permalink
Refactor module builder and fix issue 41 (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrakitaTT authored Nov 28, 2024
1 parent 36c8492 commit 39812db
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 103 deletions.
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ StablehloReferenceProcess
StablehloReferenceOps
StablehloPasses
ChloOps
Version
VhloOps
VhloTypes
StablehloOps
Expand Down
22 changes: 13 additions & 9 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
// https://llvm.org/LICENSE.txt

#include "common/api_impl.h"
#include "common/module_builder.h"
#include "common/status.h"

#include <cassert>
#include <cstring>
#include <iostream>
#include <optional>
#include <sstream>
#include <utility>

#include "common/module_builder.h"
#include "common/status.h"

namespace tt::pjrt {

std::pair<tt::target::DataType, size_t>
Expand Down Expand Up @@ -673,22 +675,24 @@ tt_pjrt_status ClientInstance::PopulateDevices() {
PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
LoadedExecutableInstance **out_executable) {
DLOG_F(LOG_DEBUG, "ClientInstance::Compile");
std::string_view format(program->format, program->format_size);

std::string_view code(program->code, program->code_size);
std::string_view format(program->format, program->format_size);

if (!context_.has_value()) {
context_.emplace();
tt_pjrt_status status = module_builder_->buildModule(code, format);
if (!tt_pjrt_status_is_ok(status)) {
return MakeError(status);
}
module_builder_->BuildModule(code, format, *context_);

auto executable = std::make_unique<LoadedExecutableInstance>(
*this,
new ExecutableImage(module_builder_->GetBinary(),
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->get_num_inputs(),
module_builder_->get_num_outputs()),
module_builder_->getNumInputs(),
module_builder_->getNumOutputs()),
addressable_devices_);
*out_executable = executable.release();

return nullptr;
}

Expand Down
7 changes: 3 additions & 4 deletions src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#define IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#ifndef TT_XLA_SRC_COMMON_API_IMPL_H_
#define TT_XLA_SRC_COMMON_API_IMPL_H_

#include <atomic>
#include <iostream>
Expand Down Expand Up @@ -378,7 +378,6 @@ class ClientInstance {
std::vector<DeviceInstance *> addressable_devices_;

std::unique_ptr<ModuleBuilder> module_builder_;
std::optional<mlir::MLIRContext> context_;

// Synchronization.
// We keep one global execution timeline across all devices. The management
Expand Down Expand Up @@ -432,4 +431,4 @@ static void BindApi(PJRT_Api *api) {

} // namespace tt::pjrt

#endif // IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#endif // TT_XLA_SRC_COMMON_API_IMPL_H_
212 changes: 142 additions & 70 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,48 @@
//

#include "common/module_builder.h"
#include "status.h"

// c++ standard library includes
#include <cstdlib>
#include <iostream>

// loguru includes
#include "loguru/loguru.hpp"

// llvm mlir includes
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/MLIRContext.h"

#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo
#include "stablehlo/dialect/Register.h" // from @stablehlo
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/RegisterAll.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// stablehlo includes
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"

// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO
#include "tt/runtime/runtime.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/RegisterAll.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"

#include "loguru/loguru.hpp"
#include "tt/runtime/runtime.h"
namespace tt::pjrt {

void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,
mlir::MLIRContext &context) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::BuildModule");
ModuleBuilder::ModuleBuilder()
: m_status(tt_pjrt_status::kSuccess), m_num_inputs(0), m_num_outputs(0) {
m_context = std::make_unique<mlir::MLIRContext>();

int log_level = loguru::g_stderr_verbosity;
// Register all the required dialects.
// Register all the required dialects and passes.
mlir::DialectRegistry registry;

registry.insert<mlir::arith::ArithDialect>();
Expand All @@ -62,73 +55,152 @@ void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,

mlir::tt::registerAllDialects(registry);
mlir::stablehlo::registerAllDialects(registry);

mlir::func::registerAllExtensions(registry);
mlir::tt::registerAllExtensions(registry);

context.appendDialectRegistry(registry);
mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

m_context->appendDialectRegistry(registry);
}

tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
const std::string_view &format) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule");

mlir::OwningOpRef<mlir::ModuleOp> mlir_module = createVHLOModule(code);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
}

convertFromVHLOToSHLO(mlir_module);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
}

convertFromSHLOToTTIR(mlir_module);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
}

convertFromTTIRToTTNN(mlir_module);
if (!tt_pjrt_status_is_ok(m_status)) {
return m_status;
}

createFlatbufferBinary(mlir_module);

return m_status;
}

mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
mlir::OwningOpRef<mlir::ModuleOp>
ModuleBuilder::createVHLOModule(const std::string_view &code) {
mlir::OwningOpRef<mlir::ModuleOp> vhlo_module =
mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(code.data(), code.size()),
// IR may be invalid because some fields may be using DenseElements
// instead of DenseArray. We rectify that below and verify after.
mlir::ParserConfig{&context, /*verifyAfterParse=*/true});
DLOG_F(LOG_DEBUG, "VHLO Module");
if (log_level > 0)
mlir_module->dump();
mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true});

mlir::PassManager vhlo_pm(mlir_module.get()->getName());
vhlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());
// Run the pass manager.
if (mlir::failed(vhlo_pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run VHLO->SHLO pipeline.");
if (!vhlo_module) {
DLOG_F(ERROR, "Failed to create VHLO module from the input program code");
m_status = tt_pjrt_status::kInternal;
return nullptr;
}
DLOG_F(LOG_DEBUG, "SHLO Module");
if (log_level > 0)
mlir_module->dump();

mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();
DLOG_F(LOG_DEBUG, "VHLO Module:");
print_module(vhlo_module);

return vhlo_module;
}

void ModuleBuilder::convertFromVHLOToSHLO(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager vhlo_to_shlo_pm(mlir_module.get()->getName());

// Converting VHLO to latest version to facilitate easier conversion to
// StableHLO.
mlir::stablehlo::VhloToVersionPassOptions vhlo_options;
vhlo_options.targetVersionOption =
mlir::vhlo::Version::getCurrentVersion().toString();
vhlo_to_shlo_pm.addPass(
mlir::stablehlo::createVhloToVersionPass(vhlo_options));
vhlo_to_shlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());

if (mlir::failed(vhlo_to_shlo_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module");
m_status = tt_pjrt_status::kInternal;
return;
}

DLOG_F(LOG_DEBUG, "SHLO Module:");
print_module(mlir_module);
}

void ModuleBuilder::convertFromSHLOToTTIR(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
// Implicit nesting required to call the stablehlo.composite --> func.call
// conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);
mlir::PassManager shlo_to_ttir_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);

mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
shlo_options.legalizeCompositeToCallEnabled = true;
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
// Run the pass manager.
if (mlir::failed(shlo_pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run SHLO->TTIR pipeline.");
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_to_ttir_pm, shlo_options);

if (mlir::failed(shlo_to_ttir_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from SHLO to TTIR module");
m_status = tt_pjrt_status::kInternal;
return;
}
DLOG_F(LOG_DEBUG, "TTIR Module");
if (log_level > 0)
mlir_module->dump();

mlir::PassManager pm(mlir_module.get()->getName());
DLOG_F(LOG_DEBUG, "TTIR Module:");
print_module(mlir_module);
}

void ModuleBuilder::convertFromTTIRToTTNN(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager ttir_to_ttnn_pm(mlir_module.get()->getName());

mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(ttir_to_ttnn_pm, options);

// Run the pass manager.
if (mlir::failed(pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run TTIR->TTNN pipeline.");
if (mlir::failed(ttir_to_ttnn_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from TTIR to TTNN module");
m_status = tt_pjrt_status::kInternal;
return;
}
DLOG_F(LOG_DEBUG, "TTNN Module");
if (log_level > 0)
mlir_module->dump();

binary_ptr_ = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
DLOG_F(LOG_DEBUG, "TTNN Module:");
print_module(mlir_module);
}

void ModuleBuilder::createFlatbufferBinary(
const mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
m_flatbuffer_binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());

if (m_flatbuffer_binary == nullptr) {
DLOG_F(ERROR, "Failed to generate flatbuffer binary");
m_status = tt_pjrt_status::kInternal;
return;
}

tt::runtime::Binary runtime_binary_handle(m_flatbuffer_binary);
m_num_inputs = runtime_binary_handle.getProgramInputs(0).size();
m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size();
}

if (binary_ptr_ == nullptr) {
throw std::runtime_error("Failed to generate flatbuffer binary.");
void ModuleBuilder::print_module(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (loguru::g_stderr_verbosity < LOG_DEBUG) {
return;
}

binary_ = std::make_unique<tt::runtime::Binary>(binary_ptr_);
num_outputs_ = binary_->getProgramOutputs(0).size();
num_inputs_ = binary_->getProgramInputs(0).size();
return;
mlir_module->dump();
}

} // namespace tt::pjrt
Loading

0 comments on commit 39812db

Please sign in to comment.