diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 913d7d6f7aa80d..22fd0f40a36b5c 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" @@ -91,6 +92,20 @@ inline void SetProgramInt64Attr(std::shared_ptr program, attr_name, pir::Int64Attribute::get(pir::IrContext::Instance(), value)); } +std::string GetValueInfo(Value v) { + std::stringstream ss; + ss << "define_op_name=" << v.dyn_cast().owner()->name(); + ss << ", index=" << v.dyn_cast().index(); + ss << ", dtype=" << v.type(); + if (v.type().isa()) { + ss << ", place=" + << v.type() + .dyn_cast() + .place(); + } + return ss.str(); +} + void BindProgram(py::module *m) { py::class_> program(*m, "Program", R"DOC( Create Python Program. Program is an abstraction of model structure, divided into @@ -353,7 +368,14 @@ void BindValue(py::module *m) { return self.impl() == other.Value::impl(); }) .def("__hash__", - [](const Value &self) { return std::hash{}(self); }); + [](const Value &self) { return std::hash{}(self); }) + .def("__str__", [](const Value &self) -> py::str { + std::ostringstream print_stream; + print_stream << "Value("; + print_stream << GetValueInfo(self); + print_stream << ")"; + return print_stream.str(); + }); } void BindOpOperand(py::module *m) { @@ -472,6 +494,19 @@ void BindOpResult(py::module *m) { }) .def("__hash__", [](OpResult &self) { return std::hash{}(self); }) + .def("__str__", + [](OpResult &self) -> py::str { + std::ostringstream print_stream; + print_stream << "OpResult("; + print_stream << GetValueInfo(self); + if (GetOpResultBoolAttr(self, kAttrStopGradients)) { + print_stream << ", stop_gradient=True"; + } else { + print_stream << ", stop_gradient=False"; + } + print_stream << ")"; + return print_stream.str(); + }) .def( "get_defining_op", [](const OpResult &self) -> pir::Operation * { diff --git a/paddle/pir/core/ir_printer.cc b/paddle/pir/core/ir_printer.cc index 52c49be8121046..260d42e035e4d4 100644 --- a/paddle/pir/core/ir_printer.cc +++ b/paddle/pir/core/ir_printer.cc @@ -317,6 +317,11 @@ void Operation::Print(std::ostream& os) { printer.PrintOperation(this); } +void Value::Print(std::ostream& os) const { + IrPrinter printer(os); + printer.PrintValue(*this); +} + void Type::Print(std::ostream& os) const { BasicIrPrinter printer(os); printer.PrintType(*this); diff --git a/paddle/pir/core/value.h b/paddle/pir/core/value.h index 81a1717540e3db..00c7aa123746eb 100644 --- a/paddle/pir/core/value.h +++ b/paddle/pir/core/value.h @@ -72,6 +72,8 @@ class IR_API Value { OpOperand first_use() const; + void Print(std::ostream &os) const; + bool use_empty() const; bool HasOneUse() const; diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index 34aa4c90c873fb..b9a6fb92ac5482 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -103,6 +103,11 @@ def test_value(self): ) # test value == opresult self.assertEqual(add_op.operands_source()[0], matmul_op.results()[0]) + # test opresult print + self.assertTrue( + 'dtype=pd_op.tensor<4x4xf32>' + in add_op.operands_source()[0].__str__() + ) # test opresult == value self.assertEqual( add_op.operands()[0].source(), add_op.operands_source()[0] @@ -110,10 +115,13 @@ def test_value(self): # test opresult == opresult self.assertEqual(add_op.operands()[0].source(), matmul_op.results()[0]) + # test opresult print self.assertEqual( tanh_op.operands()[0].source().get_defining_op().name(), "pd_op.add" ) - + self.assertTrue( + 'pd_op.tensor<4x4xf32>' in tanh_op.operands()[0].source().__str__() + ) add_op.replace_all_uses_with(matmul_op.results()) self.assertEqual( tanh_op.operands()[0].source().get_defining_op().name(),