[PIR] Print value info on python (PaddlePaddle#57471)
* fix bug

* rewrite __str__ in value and opresult to print info

* fix bug

* change as reviewed comments

* change as reviewed comments

* fix print str
chen2016013 authored and iosmers committed Sep 21, 2023
1 parent fa12e0f commit e244f06
Showing 4 changed files with 52 additions and 2 deletions.
37 changes: 36 additions & 1 deletion paddle/fluid/pybind/ir.cc
@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
@@ -91,6 +92,20 @@ inline void SetProgramInt64Attr(std::shared_ptr<Program> program,
attr_name, pir::Int64Attribute::get(pir::IrContext::Instance(), value));
}

std::string GetValueInfo(Value v) {
std::stringstream ss;
ss << "define_op_name=" << v.dyn_cast<OpResult>().owner()->name();
ss << ", index=" << v.dyn_cast<OpResult>().index();
ss << ", dtype=" << v.type();
if (v.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
ss << ", place="
<< v.type()
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.place();
}
return ss.str();
}

void BindProgram(py::module *m) {
py::class_<Program, std::shared_ptr<Program>> program(*m, "Program", R"DOC(
Create Python Program. Program is an abstraction of model structure, divided into
@@ -353,7 +368,14 @@ void BindValue(py::module *m) {
return self.impl() == other.Value::impl();
})
.def("__hash__",
[](const Value &self) { return std::hash<pir::Value>{}(self); });
[](const Value &self) { return std::hash<pir::Value>{}(self); })
.def("__str__", [](const Value &self) -> py::str {
std::ostringstream print_stream;
print_stream << "Value(";
print_stream << GetValueInfo(self);
print_stream << ")";
return print_stream.str();
});
}

void BindOpOperand(py::module *m) {
@@ -472,6 +494,19 @@ void BindOpResult(py::module *m) {
})
.def("__hash__",
[](OpResult &self) { return std::hash<pir::Value>{}(self); })
.def("__str__",
[](OpResult &self) -> py::str {
std::ostringstream print_stream;
print_stream << "OpResult(";
print_stream << GetValueInfo(self);
if (GetOpResultBoolAttr(self, kAttrStopGradients)) {
print_stream << ", stop_gradient=True";
} else {
print_stream << ", stop_gradient=False";
}
print_stream << ")";
return print_stream.str();
})
.def(
"get_defining_op",
[](const OpResult &self) -> pir::Operation * {
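For reference, a minimal usage sketch (not part of this commit) of the new `__str__` bindings from Python. The program setup, the `translate_to_new_ir` entry point, and the `block().ops` accessors are assumptions modelled on the Paddle new-IR test suite; the printed strings are illustrative and exact module paths may differ by Paddle version.

import paddle
from paddle import ir  # assumed module exposing translate_to_new_ir

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', [4, 4], dtype='float32')
    x.stop_gradient = False
    y = paddle.matmul(x, x)
    z = paddle.add(y, y)

newir_program = ir.translate_to_new_ir(main_program.desc)  # assumed entry point
matmul_op = newir_program.block().ops[1]  # ops[0] is assumed to be the pd_op.data op for x
add_op = newir_program.block().ops[2]

# OpResult.__str__ reports the defining op, result index, dtype and stop_gradient, e.g.
#   OpResult(define_op_name=pd_op.matmul, index=0, dtype=pd_op.tensor<4x4xf32>, stop_gradient=False)
print(matmul_op.results()[0])

# Value.__str__ prints the same core info without the stop_gradient flag, e.g.
#   Value(define_op_name=pd_op.matmul, index=0, dtype=pd_op.tensor<4x4xf32>)
print(add_op.operands_source()[0])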
5 changes: 5 additions & 0 deletions paddle/pir/core/ir_printer.cc
@@ -317,6 +317,11 @@ void Operation::Print(std::ostream& os) {
printer.PrintOperation(this);
}

void Value::Print(std::ostream& os) const {
IrPrinter printer(os);
printer.PrintValue(*this);
}

void Type::Print(std::ostream& os) const {
BasicIrPrinter printer(os);
printer.PrintType(*this);
2 changes: 2 additions & 0 deletions paddle/pir/core/value.h
@@ -72,6 +72,8 @@ class IR_API Value {

OpOperand first_use() const;

void Print(std::ostream &os) const;

bool use_empty() const;

bool HasOneUse() const;
10 changes: 9 additions & 1 deletion test/ir/new_ir/test_ir_pybind.py
@@ -103,17 +103,25 @@ def test_value(self):
)
# test value == opresult
self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0])
# test opresult print
self.assertTrue(
'dtype=pd_op.tensor<4x4xf32>'
in add_op.operands_source()[0].__str__()
)
# test opresult == value
self.assertEqual(
add_op.operands()[0].source(), add_op.operands_source()[0]
)
# test opresult == opresult
self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0])

# test opresult print
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd_op.add"
)

self.assertTrue(
'pd_op.tensor<4x4xf32>' in tanh_op.operands()[0].source().__str__()
)
add_op.replace_all_uses_with(matmul_op.results())
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(),
