Skip to content

Commit

Permalink
Removed Asserts in Python Bindings (#1545)
Browse files Browse the repository at this point in the history
* Removed Asserts in Python Bindings

* Changed explorer use of asserted binding
  • Loading branch information
vprajapati-tt authored Dec 10, 2024
1 parent f2f2e97 commit f0e03af
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
15 changes: 9 additions & 6 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,17 @@ void populateTTModule(py::module &m) {
.withElementType(unwrap(ctx), unwrap(elementType));
})
.def("getLayout",
[](MlirType &type) {
assert(isa<RankedTensorType>(
unwrap(type))); // Make sure that this is operating on a
// RankedTensorType object
[](MlirType &type) -> std::variant<tt::MetalLayoutAttr, py::object> {
// Make sure that this is operating on a RankedTensorType object
if (not isa<RankedTensorType>(unwrap(type))) {
return py::none();
}
RankedTensorType tensor =
mlir::cast<RankedTensorType>(unwrap(type));
assert(tensor.getEncoding()); // Make sure that this Tensor has an
// encoding value
// Make sure that this Tensor has an encoding value
if (not tensor.getEncoding()) {
return py::none();
}
tt::MetalLayoutAttr layout =
mlir::cast<tt::MetalLayoutAttr>(tensor.getEncoding());
return layout;
Expand Down
16 changes: 9 additions & 7 deletions python/TTNNModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ void populateTTNNModule(py::module &m) {
.def_property_readonly(
"memref",
[](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); })
.def_property_readonly(
"memory_layout_as_int", [](tt::ttnn::TTNNLayoutAttr self) {
if (!self.getMemLayout()) {
assert(false && "Memory layout is not set");
}
return static_cast<uint32_t>(self.getMemLayout().getValue());
});
.def_property_readonly("memory_layout_as_int",
[](tt::ttnn::TTNNLayoutAttr self)
-> std::variant<uint32_t, py::object> {
if (!self.getMemLayout()) {
return py::none();
}
return static_cast<uint32_t>(
self.getMemLayout().getValue());
});
}
} // namespace mlir::ttmlir::python
12 changes: 7 additions & 5 deletions tools/explorer/tt_adapter/src/tt_adapter/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,14 @@ def parse_ttnn_ttnn_layout(attr):
layout = ttnn.ir.TTNNLayoutAttr.maybe_downcast(attr)
result = []
result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear)))
result.append(
graph_builder.KeyValue(
key="memory_layout",
value=str(ttnn.TensorMemoryLayout(layout.memory_layout_as_int)),
memory_layout = layout.memory_layout_as_int
if memory_layout is not None:
result.append(
graph_builder.KeyValue(
key="memory_layout",
value=str(ttnn.TensorMemoryLayout(memory_layout)),
)
)
)
result.append(
graph_builder.KeyValue(
key="grid_shape", value="x".join(map(str, layout.grid_attr.shape))
Expand Down

0 comments on commit f0e03af

Please sign in to comment.