[lower_to_mlir] Use named_attrs from tt-forge ops instead of hardcoding attributes for each op while lowering to MLIR.
dgolubovicTT committed Aug 8, 2024
1 parent eda3539 commit 8cc7c46
Showing 2 changed files with 30 additions and 12 deletions.
38 changes: 28 additions & 10 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -45,6 +45,7 @@ using namespace tt;
/**
* @brief Implementation of TT-MLIR emission from the TTForge graph.
*/

class MLIRGenerator
{
public:
@@ -109,6 +110,25 @@ class MLIRGenerator
symbolTable_[node->name()] = {value, node};
}

// Convert a TTForge attribute to an MLIR attribute.
mlir::Attribute convert_to_mlir_attribute(const tt::BudaOpAttr& value) {
return std::visit([this](auto&& arg) -> mlir::Attribute {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::string>) {
return builder_.getStringAttr(arg);
} else if constexpr (std::is_same_v<T, bool>) {
return builder_.getBoolAttr(arg);
} else if constexpr (std::is_same_v<T, int>) {
return builder_.getI32IntegerAttr(arg);
} else if constexpr (std::is_same_v<T, float>) {
return builder_.getF32FloatAttr(arg);
} else {
// Fail loudly on any variant alternative without an explicit conversion above.
throw std::runtime_error("Unhandled attribute type");
}
}, value);
}
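// Illustration (not part of the commit): how the visitor above resolves each
// variant alternative, assuming tt::BudaOpAttr is a std::variant over (at
// least) std::string, bool, int and float, as the branches imply:
//
//   convert_to_mlir_attribute(tt::BudaOpAttr{3});     // -> i32 IntegerAttr 3
//   convert_to_mlir_attribute(tt::BudaOpAttr{true});  // -> BoolAttr true
//   convert_to_mlir_attribute(tt::BudaOpAttr{0.5f});  // -> f32 FloatAttr 0.5
//   convert_to_mlir_attribute(tt::BudaOpAttr{std::string("mean")}); // -> StringAttr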

/// Emit a new function in MLIR.
/// A function represents a set of TTForge operations that are executed to produce output results.
/// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function.
@@ -204,15 +224,15 @@ class MLIRGenerator
::llvm::ArrayRef<::llvm::StringRef> operation_attributes = TTIROp::getAttributeNames();
for(auto attribute_name: operation_attributes)
{
- if(attribute_name.equals("operand_constraints"))
+ if(attribute_name == "operand_constraints")
{
// Create operation constraint attributes
mlir::NamedAttribute operand_constraints_attribute = builder_.getNamedAttr(
"operand_constraints",
builder_.getArrayAttr(get_mlir_operand_constraint_attributes(graph, op_node)));
attributes.push_back(operand_constraints_attribute);
}
- else if(attribute_name.equals(mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()))
+ else if(attribute_name == mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr())
{
// Create operation segment sizes attributes
mlir::NamedAttribute operand_segment_sizes_attribute = builder_.getNamedAttr(
@@ -225,15 +245,13 @@
}
}

- // Workaround for now, need to figure out how to handle this properly
- if(op_node->op_name() == "softmax")
- {
-     log_info("Softmax");
-     int32_t dimension = std::get<int>(op_node->op_attrs()[0]);
-     mlir::NamedAttribute dimension_attribute = builder_.getNamedAttr(
-         "dimension",
-         builder_.getSI32IntegerAttr(dimension));
-     attributes.push_back(dimension_attribute);
- }
+ for(const auto & attribute: op_node->op_type().named_attrs)
+ {
+     // Convert the TTForge attribute to an MLIR attribute.
+     auto mlir_attribute = convert_to_mlir_attribute(attribute.second);
+     mlir::NamedAttribute named_attribute = builder_.getNamedAttr(
+         attribute.first, mlir_attribute);
+     attributes.push_back(named_attribute);
+ }

auto op = builder_.create<TTIROp>(
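Net effect: attribute lowering is now generic. Any op whose frontend supplies named attributes gets them converted and attached by the loop above, with no per-op C++ branch. A minimal sketch of what softmax lowering now does, assuming its named_attrs holds the single entry {"dimension": dim} (as the nn.py change below arranges):

    // Hedged sketch, not part of the commit: one named attribute per entry.
    for (const auto &attribute : op_node->op_type().named_attrs)
    {
        attributes.push_back(builder_.getNamedAttr(
            attribute.first, convert_to_mlir_attribute(attribute.second)));
    }
    // For softmax this attaches dimension = <dim> : i32. One subtle change:
    // the removed branch built the value with getSI32IntegerAttr (signed si32),
    // while convert_to_mlir_attribute uses getI32IntegerAttr (signless i32).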
4 changes: 2 additions & 2 deletions pybuda/pybuda/op/nn.py
@@ -51,7 +51,7 @@ def Softmax(
Tensor
Buda tensor
"""
return op("softmax", name, operandA, attrs=(dim, stable)).get_tensor()
return op("softmax", name, operandA, attrs=(dim, stable), dimension=dim).get_tensor()


def LogSoftmax(
@@ -82,7 +82,7 @@ def LogSoftmax(
Tensor
Buda tensor
"""
return op("log_softmax", name, operandA, attrs=(dim, stable)).get_tensor()
return op("log_softmax", name, operandA, attrs=(dim, stable), dimension=dim).get_tensor()

def Layernorm(
name: str,
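For reference, a hedged sketch of the frontend call after this change; the call site, tensor names, and import path are hypothetical, and the extra dimension=dim keyword is assumed to land in the op's named_attrs, which the C++ loop above consumes:

    import pybuda  # assumed top-level package for pybuda/pybuda/op/nn.py

    # Hypothetical usage: dim now travels both positionally in attrs (legacy
    # path) and as the named attribute "dimension" (read by lower_to_mlir).
    out = pybuda.op.nn.Softmax("softmax_0", activations, dim=-1, stable=True)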
