Skip to content

Commit

Permalink
[Frontend] Refactor operator overloading (#617)
Browse files Browse the repository at this point in the history
* [Frontend] Refactor operator overloading

Now all operator overloadings are defined in expr.py, decoupling expr.py
and FFI.

* `register_operators` should be private
  • Loading branch information
roastduck authored May 5, 2024
1 parent 1b427cc commit 2456038
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 259 deletions.
122 changes: 38 additions & 84 deletions ffi/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,90 +186,6 @@ void init_ffi_ast_expr(py::module_ &m) {
.def(py::init([](int64_t val) { return makeIntConst(val); }))
.def(py::init([](float val) { return makeFloatConst(val); }))
.def(py::init([](const FrontendVar &var) { return var.asLoad(); }))
.def(
"__add__",
[](const Expr &lhs, const Expr &rhs) { return makeAdd(lhs, rhs); },
py::is_operator())
.def(
"__radd__",
[](const Expr &rhs, const Expr &lhs) { return makeAdd(lhs, rhs); },
py::is_operator())
.def(
"__sub__",
[](const Expr &lhs, const Expr &rhs) { return makeSub(lhs, rhs); },
py::is_operator())
.def(
"__rsub__",
[](const Expr &rhs, const Expr &lhs) { return makeSub(lhs, rhs); },
py::is_operator())
.def(
"__mul__",
[](const Expr &lhs, const Expr &rhs) { return makeMul(lhs, rhs); },
py::is_operator())
.def(
"__rmul__",
[](const Expr &rhs, const Expr &lhs) { return makeMul(lhs, rhs); },
py::is_operator())
.def(
"__truediv__",
[](const Expr &lhs, const Expr &rhs) {
return makeRealDiv(lhs, rhs);
},
py::is_operator())
.def(
"__rtruediv__",
[](const Expr &rhs, const Expr &lhs) {
return makeRealDiv(lhs, rhs);
},
py::is_operator())
.def(
"__floordiv__",
[](const Expr &lhs, const Expr &rhs) {
return makeFloorDiv(lhs, rhs);
},
py::is_operator())
.def(
"__rfloordiv__",
[](const Expr &rhs, const Expr &lhs) {
return makeFloorDiv(lhs, rhs);
},
py::is_operator())
.def(
"__mod__",
[](const Expr &lhs, const Expr &rhs) { return makeMod(lhs, rhs); },
py::is_operator())
.def(
"__rmod__",
[](const Expr &rhs, const Expr &lhs) { return makeMod(lhs, rhs); },
py::is_operator())
.def(
"__lt__",
[](const Expr &lhs, const Expr &rhs) { return makeLT(lhs, rhs); },
py::is_operator())
.def(
"__le__",
[](const Expr &lhs, const Expr &rhs) { return makeLE(lhs, rhs); },
py::is_operator())
.def(
"__gt__",
[](const Expr &lhs, const Expr &rhs) { return makeGT(lhs, rhs); },
py::is_operator())
.def(
"__ge__",
[](const Expr &lhs, const Expr &rhs) { return makeGE(lhs, rhs); },
py::is_operator())
.def(
"__eq__",
[](const Expr &lhs, const Expr &rhs) { return makeEQ(lhs, rhs); },
py::is_operator())
.def(
"__ne__",
[](const Expr &lhs, const Expr &rhs) { return makeNE(lhs, rhs); },
py::is_operator())
.def(
"__neg__",
[](const Expr &expr) { return makeSub(makeIntConst(0), expr); },
py::is_operator())
.def_property_readonly(
"dtype", [](const Expr &op) -> DataType { return op->dtype(); });
py::implicitly_convertible<int, ExprNode>();
Expand All @@ -293,6 +209,10 @@ void init_ffi_ast_expr(py::module_ &m) {
[](const std::string &_1, const std::vector<Expr> &_2,
const DataType _3) { return makeLoad(_1, _2, _3); },
"var"_a, "indices"_a, "load_type"_a);
m.def(
"makeMod",
[](const Expr &_1, const Expr &_2) { return makeMod(_1, _2); }, "lhs"_a,
"rhs"_a);
m.def(
"makeRemainder",
[](const Expr &_1, const Expr &_2) { return makeRemainder(_1, _2); },
Expand Down Expand Up @@ -353,6 +273,22 @@ void init_ffi_ast_expr(py::module_ &m) {
"makeCast",
[](const Expr &_1, const DataType &_2) { return makeCast(_1, _2); },
"expr"_a, "dtype"_a);
m.def(
"makeAdd",
[](const Expr &_1, const Expr &_2) { return makeAdd(_1, _2); }, "lhs"_a,
"rhs"_a);
m.def(
"makeSub",
[](const Expr &_1, const Expr &_2) { return makeSub(_1, _2); }, "lhs"_a,
"rhs"_a);
m.def(
"makeMul",
[](const Expr &_1, const Expr &_2) { return makeMul(_1, _2); }, "lhs"_a,
"rhs"_a);
m.def(
"makeRealDiv",
[](const Expr &_1, const Expr &_2) { return makeRealDiv(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeFloorDiv",
[](const Expr &_1, const Expr &_2) { return makeFloorDiv(_1, _2); },
Expand All @@ -367,6 +303,24 @@ void init_ffi_ast_expr(py::module_ &m) {
return makeRoundTowards0Div(_1, _2);
},
"lhs"_a, "rhs"_a);
m.def(
"makeLT", [](const Expr &_1, const Expr &_2) { return makeLT(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeLE", [](const Expr &_1, const Expr &_2) { return makeLE(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeGT", [](const Expr &_1, const Expr &_2) { return makeGT(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeGE", [](const Expr &_1, const Expr &_2) { return makeGE(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeEQ", [](const Expr &_1, const Expr &_2) { return makeEQ(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeNE", [](const Expr &_1, const Expr &_2) { return makeNE(_1, _2); },
"lhs"_a, "rhs"_a);
m.def(
"makeIntrinsic",
[](const std::string &_1, const std::vector<Expr> &_2,
Expand Down
Loading

0 comments on commit 2456038

Please sign in to comment.