Addressed comments
ajakovljevicTT committed Dec 11, 2024
1 parent 2d2aeb4 commit f3f317f
Showing 2 changed files with 60 additions and 34 deletions.
81 changes: 52 additions & 29 deletions src/common/module_builder.cc
@@ -66,39 +66,17 @@ ModuleBuilder::ModuleBuilder()
  m_context->appendDialectRegistry(registry);
}

static bool isScalarType(mlir::Type type) {
  if (mlir::isa<mlir::FloatType>(type) || mlir::isa<mlir::IntegerType>(type)) {
    return true;
  }
  if (auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(type)) {
    return tensorType.getRank() == 0;
  }
  return false;
}

void ModuleBuilder::collectOutputTypes(mlir::ModuleOp &&module) {
  m_is_output_scalar.clear();
  for (auto &op : module.getOps()) {
    if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
      // We care only for return ops of public functions, as those are the
      // ones that will produce results in the flatbuffer.
      if (funcOp.isPublic()) {
        funcOp.walk([&](mlir::Operation *op) {
          if (auto returnOp = mlir::dyn_cast<mlir::func::ReturnOp>(op)) {
            for (auto operand : returnOp->getOperands()) {
              m_is_output_scalar.push_back(isScalarType(operand.getType()));
            }
          }
        });
      }
    }
  }
}

bool ModuleBuilder::isOutputScalar(int index) const {
  assert(index >= 0 &&
         static_cast<size_t>(index) < m_is_output_scalar.size() &&
         "Index of output out of range.");
  return m_is_output_scalar[index];
}

tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
                                          const std::string_view &format) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::buildModule");

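  // Start each build from a clean success status.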
  m_status = tt_pjrt_status::kSuccess;

  mlir::OwningOpRef<mlir::ModuleOp> mlir_module = createVHLOModule(code);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
@@ -109,6 +87,8 @@ tt_pjrt_status ModuleBuilder::buildModule(const std::string_view &code,
    return m_status;
  }

  collectOutputTypes(mlir_module);

  convertFromSHLOToTTIR(mlir_module);
  if (!tt_pjrt_status_is_ok(m_status)) {
    return m_status;
@@ -155,12 +135,49 @@ void ModuleBuilder::convertFromVHLOToSHLO(
    return;
  }

  collectOutputTypes(mlir_module.get());

  DLOG_F(LOG_DEBUG, "SHLO Module:");
  printModule(mlir_module);
}

void ModuleBuilder::collectOutputTypes(
    const mlir::OwningOpRef<mlir::ModuleOp> &module) {
  DLOG_F(LOG_DEBUG, "ModuleBuilder::collectOutputTypes");

  m_is_output_scalar.clear();

  module.get().walk([&](mlir::Operation *op) {
    mlir::func::FuncOp funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
    mlir::ModuleOp moduleOp = mlir::dyn_cast<mlir::ModuleOp>(op);

    // We care only about return ops of public functions, as those are the
    // ones that will produce results in the flatbuffer.
    if (funcOp && funcOp.isPublic()) {
      for (const mlir::Type &returnType :
           funcOp.getFunctionType().getResults()) {
        m_is_output_scalar.push_back(isScalarType(returnType));
      }
    }

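    // The walk also visits the root module op itself; skip it to avoid
    // infinite recursion, and recurse only into genuinely nested modules.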
    if (moduleOp) {
      if (moduleOp == module.get()) {
        return;
      }
      collectOutputTypes(moduleOp);
    }
  });
}

bool ModuleBuilder::isScalarType(mlir::Type type) {
  if (mlir::isa<mlir::FloatType>(type) || mlir::isa<mlir::IntegerType>(type)) {
    return true;
  }
  if (mlir::RankedTensorType tensorType =
          mlir::dyn_cast<mlir::RankedTensorType>(type)) {
    return tensorType.getRank() == 0;
  }
  return false;
}

void ModuleBuilder::convertFromSHLOToTTIR(
    mlir::OwningOpRef<mlir::ModuleOp> &mlir_module) {
  // Implicit nesting required to call the stablehlo.composite --> func.call
@@ -215,6 +232,12 @@ void ModuleBuilder::createFlatbufferBinary(
  tt::runtime::Binary runtime_binary_handle(m_flatbuffer_binary);
  m_num_inputs = runtime_binary_handle.getProgramInputs(0).size();
  m_num_outputs = runtime_binary_handle.getProgramOutputs(0).size();

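  // Sanity check: the scalar flags collected from the StableHLO module must
  // line up one-to-one with the program outputs in the flatbuffer.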
  if (m_num_outputs != m_is_output_scalar.size()) {
    DLOG_F(ERROR, "Number of program outputs in the flatbuffer binary does "
                  "not match the number of collected output types.");
    m_status = tt_pjrt_status::kInternal;
    return;
  }
}

void ModuleBuilder::printModule(
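
For reference, a minimal standalone sketch of the scalar check introduced above (the MLIR builtin-dialect calls are assumed from upstream MLIR; the helper name and main() harness are illustrative, not part of this commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

// Mirrors the isScalarType logic from this diff: plain float/integer types
// and rank-0 ranked tensors count as scalars.
static bool isScalarTypeSketch(mlir::Type type) {
  if (mlir::isa<mlir::FloatType>(type) || mlir::isa<mlir::IntegerType>(type))
    return true;
  if (auto tensorType = mlir::dyn_cast<mlir::RankedTensorType>(type))
    return tensorType.getRank() == 0;
  return false;
}

int main() {
  mlir::MLIRContext context;
  mlir::Type f32 = mlir::FloatType::getF32(&context);
  mlir::Type scalarTensor = mlir::RankedTensorType::get({}, f32);   // tensor<f32>
  mlir::Type rankOneTensor = mlir::RankedTensorType::get({4}, f32); // tensor<4xf32>
  bool a = isScalarTypeSketch(f32);           // true
  bool b = isScalarTypeSketch(scalarTensor);  // true: rank-0 tensor
  bool c = isScalarTypeSketch(rankOneTensor); // false
  return (a && b && !c) ? 0 : 1;
}
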
13 changes: 8 additions & 5 deletions src/common/module_builder.h
@@ -33,7 +33,7 @@ class ModuleBuilder {

  size_t getNumOutputs() const { return m_num_outputs; };

  bool isOutputScalar(int index) const { return m_is_output_scalar[index]; }
  bool isOutputScalar(int index) const;

private:
  // Creates VHLO module from the input program code.
@@ -43,6 +43,10 @@
  // Converts VHLO module to StableHLO module.
  void convertFromVHLOToSHLO(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

  // Fills the m_is_output_scalar array with information on whether each
  // output type is a scalar.
  void collectOutputTypes(const mlir::OwningOpRef<mlir::ModuleOp> &module);

  // Converts StableHLO module to TTIR module.
  void convertFromSHLOToTTIR(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

@@ -53,13 +57,12 @@
  void
  createFlatbufferBinary(const mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

  // Fills the m_is_output_scalar array with information on whether each
  // output type is a scalar.
  void collectOutputTypes(mlir::ModuleOp &&module);

  // Prints module to console for debug purposes.
  static void printModule(mlir::OwningOpRef<mlir::ModuleOp> &mlir_module);

  // Checks if a particular type is scalar.
  bool isScalarType(mlir::Type type);

  // MLIR context handle.
  std::unique_ptr<mlir::MLIRContext> m_context;

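
A hedged usage sketch of the public surface touched here (the caller, the include path, and the surrounding PJRT plumbing are assumptions, not code from this repository):

#include <cassert>
#include <string_view>

#include "module_builder.h" // path assumed; adjust to the project layout

// Hypothetical caller: build a module, then query which program outputs are
// scalars so rank-0 tensor results can be unwrapped by the PJRT layer.
void describeOutputs(ModuleBuilder &builder, std::string_view code,
                     std::string_view format) {
  tt_pjrt_status status = builder.buildModule(code, format);
  assert(tt_pjrt_status_is_ok(status) && "module build failed");

  for (size_t i = 0; i < builder.getNumOutputs(); ++i) {
    // isOutputScalar now asserts internally that the index is in range.
    bool is_scalar = builder.isOutputScalar(static_cast<int>(i));
    (void)is_scalar; // e.g. pick a scalar vs. tensor buffer layout here
  }
}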
