Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor module builder and fix issue 41 #87

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ StablehloReferenceProcess
StablehloReferenceOps
StablehloPasses
ChloOps
Version
VhloOps
VhloTypes
StablehloOps
Expand Down
22 changes: 13 additions & 9 deletions src/common/api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
// https://llvm.org/LICENSE.txt

#include "common/api_impl.h"
#include "common/module_builder.h"
#include "common/status.h"

#include <cassert>
#include <cstring>
#include <iostream>
#include <optional>
#include <sstream>
#include <utility>

#include "common/module_builder.h"
#include "common/status.h"

namespace tt::pjrt {

std::pair<tt::target::DataType, size_t>
Expand Down Expand Up @@ -673,22 +675,24 @@ tt_pjrt_status ClientInstance::PopulateDevices() {
PJRT_Error *ClientInstance::Compile(const PJRT_Program *program,
LoadedExecutableInstance **out_executable) {
DLOG_F(LOG_DEBUG, "ClientInstance::Compile");
std::string_view format(program->format, program->format_size);

std::string_view code(program->code, program->code_size);
std::string_view format(program->format, program->format_size);

if (!context_.has_value()) {
context_.emplace();
tt_pjrt_status status = module_builder_->buildModule(code, format);
if (!tt_pjrt_status_is_ok(status)) {
return MakeError(status);
}
module_builder_->BuildModule(code, format, *context_);

auto executable = std::make_unique<LoadedExecutableInstance>(
*this,
new ExecutableImage(module_builder_->GetBinary(),
new ExecutableImage(module_builder_->getBinary(),
std::string(program->code, program->code_size),
module_builder_->get_num_inputs(),
module_builder_->get_num_outputs()),
module_builder_->getNumInputs(),
module_builder_->getNumOutputs()),
addressable_devices_);
*out_executable = executable.release();

return nullptr;
}

Expand Down
7 changes: 3 additions & 4 deletions src/common/api_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// https://llvm.org/LICENSE.txt

#ifndef IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#define IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#ifndef TT_XLA_SRC_COMMON_API_IMPL_H_
#define TT_XLA_SRC_COMMON_API_IMPL_H_

#include <atomic>
#include <iostream>
Expand Down Expand Up @@ -378,7 +378,6 @@ class ClientInstance {
std::vector<DeviceInstance *> addressable_devices_;

std::unique_ptr<ModuleBuilder> module_builder_;
std::optional<mlir::MLIRContext> context_;

// Synchronization.
// We keep one global execution timeline across all devices. The management
Expand Down Expand Up @@ -432,4 +431,4 @@ static void BindApi(PJRT_Api *api) {

} // namespace tt::pjrt

#endif // IREE_PJRT_PLUGIN_PJRT_COMMON_API_IMPL_H_
#endif // TT_XLA_SRC_COMMON_API_IMPL_H_
212 changes: 142 additions & 70 deletions src/common/module_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,55 +4,48 @@
//

#include "common/module_builder.h"
#include "status.h"

// c++ standard library includes
#include <cstdlib>
#include <iostream>

// loguru includes
#include "loguru/loguru.hpp"

// llvm mlir includes
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/IR/MLIRContext.h"

#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo
#include "stablehlo/dialect/Register.h" // from @stablehlo
#include "stablehlo/dialect/Serialization.h" // from @stablehlo
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/RegisterAll.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// stablehlo includes
#include "stablehlo/dialect/Register.h"
#include "stablehlo/dialect/Version.h"
#include "stablehlo/transforms/Passes.h"

// tt-mlir includes
#define TTMLIR_ENABLE_STABLEHLO
mrakitaTT marked this conversation as resolved.
Show resolved Hide resolved
#include "tt/runtime/runtime.h"
#include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h"
#include "ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/RegisterAll.h"
#include "ttmlir/Target/TTNN/TTNNToFlatbuffer.h"

#include "loguru/loguru.hpp"
#include "tt/runtime/runtime.h"
namespace tt::pjrt {

void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,
mlir::MLIRContext &context) {
DLOG_F(LOG_DEBUG, "ModuleBuilder::BuildModule");
ModuleBuilder::ModuleBuilder()
: m_status(tt_pjrt_status::kSuccess), m_num_inputs(0), m_num_outputs(0) {
m_context = std::make_unique<mlir::MLIRContext>();

int log_level = loguru::g_stderr_verbosity;
// Register all the required dialects.
// Register all the required dialects and passes.
mlir::DialectRegistry registry;

registry.insert<mlir::arith::ArithDialect>();
Expand All @@ -62,73 +55,152 @@ void ModuleBuilder::BuildModule(std::string_view code, std::string_view format,

mlir::tt::registerAllDialects(registry);
mlir::stablehlo::registerAllDialects(registry);

mlir::func::registerAllExtensions(registry);
mlir::tt::registerAllExtensions(registry);

context.appendDialectRegistry(registry);
mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

m_context->appendDialectRegistry(registry);
}

// Compiles the given program code down to a flatbuffer binary by running the
// conversion stages in order: VHLO module creation, VHLO -> SHLO,
// SHLO -> TTIR, TTIR -> TTNN and finally flatbuffer generation.
//
// Every stage reports failure through `m_status`. The first failing stage
// stops the build and its status is returned; otherwise the status left by the
// flatbuffer generation stage is returned.
//
// `code` is the textual program handed to the MLIR parser. `format` is
// accepted for interface compatibility but is not consulted here.
tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
                                          const std::string_view &format) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule");

  // The builder instance is reused across compilations, and the stages only
  // ever write a failure into `m_status`. Reset it here so that a failure
  // recorded by a previous build does not make this build report an error.
  m_status = tt_pjrt_status::kSuccess;

  mlir::OwningOpRef<mlir::ModuleOp> mlir_module = createVHLOModule(code);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
  }

  convertFromVHLOToSHLO(mlir_module);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
  }

  convertFromSHLOToTTIR(mlir_module);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
  }

  convertFromTTIRToTTNN(mlir_module);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
  }

  // Last stage: no early return needed, its outcome is the final status.
  createFlatbufferBinary(mlir_module);

  return m_status;
}

mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
mlir::OwningOpRef<mlir::ModuleOp>
ModuleBuilder::createVHLOModule(const std::string_view &code) {
mlir::OwningOpRef<mlir::ModuleOp> vhlo_module =
mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(code.data(), code.size()),
// IR may be invalid because some fields may be using DenseElements
// instead of DenseArray. We rectify that below and verify after.
mlir::ParserConfig{&context, /*verifyAfterParse=*/true});
DLOG_F(LOG_DEBUG, "VHLO Module");
if (log_level > 0)
mlir_module->dump();
mlir::ParserConfig{m_context.get(), /*verifyAfterParse=*/true});

mlir::PassManager vhlo_pm(mlir_module.get()->getName());
vhlo_pm.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());
// Run the pass manager.
if (mlir::failed(vhlo_pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run VHLO->SHLO pipeline.");
if (!vhlo_module) {
DLOG_F(ERROR, "Failed to create VHLO module from the input program code");
m_status = tt_pjrt_status::kInternal;
return nullptr;
}
DLOG_F(LOG_DEBUG, "SHLO Module");
if (log_level > 0)
mlir_module->dump();

mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();
DLOG_F(LOG_DEBUG, "VHLO Module:");
print_module(vhlo_module);
mrakitaTT marked this conversation as resolved.
Show resolved Hide resolved

return vhlo_module;
}

// Runs the VHLO -> StableHLO pass pipeline on `mlir_module`.
//
// On success the resulting module is printed via `print_module`. If the
// pipeline fails, the error is logged, `m_status` is set to kInternal and
// nothing is printed.
void ModuleBuilder::convertFromVHLOToSHLO(
    mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
  mlir::PassManager pass_manager(mlir_module.get()->getName());

  // First bring the VHLO ops to the current VHLO version, which makes the
  // legalization to StableHLO that follows easier.
  mlir::stablehlo::VhloToVersionPassOptions version_options;
  version_options.targetVersionOption =
      mlir::vhlo::Version::getCurrentVersion().toString();
  pass_manager.addPass(
      mlir::stablehlo::createVhloToVersionPass(version_options));
  pass_manager.addPass(mlir::stablehlo::createVhloLegalizeToStablehloPass());

  const bool conversion_failed =
      mlir::failed(pass_manager.run(mlir_module.get()));
  if (!conversion_failed) {
    DLOG_F(LOG_DEBUG, "SHLO Module:");
    print_module(mlir_module);
    return;
  }

  DLOG_F(ERROR, "Failed to convert from VHLO to SHLO module");
  m_status = tt_pjrt_status::kInternal;
}

void ModuleBuilder::convertFromSHLOToTTIR(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
// Implicit nesting required to call the stablehlo.composite --> func.call
// conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);
mlir::PassManager shlo_to_ttir_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);

mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
shlo_options.legalizeCompositeToCallEnabled = true;
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
// Run the pass manager.
if (mlir::failed(shlo_pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run SHLO->TTIR pipeline.");
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_to_ttir_pm, shlo_options);

if (mlir::failed(shlo_to_ttir_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from SHLO to TTIR module");
m_status = tt_pjrt_status::kInternal;
return;
}
DLOG_F(LOG_DEBUG, "TTIR Module");
if (log_level > 0)
mlir_module->dump();

mlir::PassManager pm(mlir_module.get()->getName());
DLOG_F(LOG_DEBUG, "TTIR Module:");
print_module(mlir_module);
}

void ModuleBuilder::convertFromTTIRToTTNN(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
mlir::PassManager ttir_to_ttnn_pm(mlir_module.get()->getName());

mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(ttir_to_ttnn_pm, options);

// Run the pass manager.
if (mlir::failed(pm.run(mlir_module.get()))) {
throw std::runtime_error("Failed to run TTIR->TTNN pipeline.");
if (mlir::failed(ttir_to_ttnn_pm.run(mlir_module.get()))) {
DLOG_F(ERROR, "Failed to convert from TTIR to TTNN module");
m_status = tt_pjrt_status::kInternal;
return;
}
DLOG_F(LOG_DEBUG, "TTNN Module");
if (log_level > 0)
mlir_module->dump();

binary_ptr_ = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
DLOG_F(LOG_DEBUG, "TTNN Module:");
print_module(mlir_module);
}

// Serializes `mlir_module` into a flatbuffer binary stored in
// `m_flatbuffer_binary`, then records how many inputs and outputs program 0
// of that binary has.
//
// If serialization yields a null binary, the error is logged, `m_status` is
// set to kInternal and the input/output counts are left untouched.
void ModuleBuilder::createFlatbufferBinary(
    const mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
  m_flatbuffer_binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());

  if (m_flatbuffer_binary != nullptr) {
    // Wrap the raw binary in a runtime handle to query the program signature.
    tt::runtime::Binary binary_handle(m_flatbuffer_binary);
    m_num_inputs = binary_handle.getProgramInputs(0).size();
    m_num_outputs = binary_handle.getProgramOutputs(0).size();
    return;
  }

  DLOG_F(ERROR, "Failed to generate flatbuffer binary");
  m_status = tt_pjrt_status::kInternal;
}

if (binary_ptr_ == nullptr) {
throw std::runtime_error("Failed to generate flatbuffer binary.");
void ModuleBuilder::print_module(
mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
if (loguru::g_stderr_verbosity < LOG_DEBUG) {
return;
}

binary_ = std::make_unique<tt::runtime::Binary>(binary_ptr_);
num_outputs_ = binary_->getProgramOutputs(0).size();
num_inputs_ = binary_->getProgramInputs(0).size();
return;
mlir_module->dump();
}

} // namespace tt::pjrt
Loading
Loading