From 36cd00758a28801680f498d74af4caf4b8041606 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Sat, 4 May 2024 23:17:33 +0800 Subject: [PATCH] [Frontend] Refactor operator overloading Now all operator overloadings are defined in expr.py, decoupling expr.py and FFI. --- ffi/expr.cc | 122 ++++------- python/freetensor/core/expr.py | 372 +++++++++++++++++---------------- 2 files changed, 235 insertions(+), 259 deletions(-) diff --git a/ffi/expr.cc b/ffi/expr.cc index bd88fe045..3cfaa8709 100644 --- a/ffi/expr.cc +++ b/ffi/expr.cc @@ -186,90 +186,6 @@ void init_ffi_ast_expr(py::module_ &m) { .def(py::init([](int64_t val) { return makeIntConst(val); })) .def(py::init([](float val) { return makeFloatConst(val); })) .def(py::init([](const FrontendVar &var) { return var.asLoad(); })) - .def( - "__add__", - [](const Expr &lhs, const Expr &rhs) { return makeAdd(lhs, rhs); }, - py::is_operator()) - .def( - "__radd__", - [](const Expr &rhs, const Expr &lhs) { return makeAdd(lhs, rhs); }, - py::is_operator()) - .def( - "__sub__", - [](const Expr &lhs, const Expr &rhs) { return makeSub(lhs, rhs); }, - py::is_operator()) - .def( - "__rsub__", - [](const Expr &rhs, const Expr &lhs) { return makeSub(lhs, rhs); }, - py::is_operator()) - .def( - "__mul__", - [](const Expr &lhs, const Expr &rhs) { return makeMul(lhs, rhs); }, - py::is_operator()) - .def( - "__rmul__", - [](const Expr &rhs, const Expr &lhs) { return makeMul(lhs, rhs); }, - py::is_operator()) - .def( - "__truediv__", - [](const Expr &lhs, const Expr &rhs) { - return makeRealDiv(lhs, rhs); - }, - py::is_operator()) - .def( - "__rtruediv__", - [](const Expr &rhs, const Expr &lhs) { - return makeRealDiv(lhs, rhs); - }, - py::is_operator()) - .def( - "__floordiv__", - [](const Expr &lhs, const Expr &rhs) { - return makeFloorDiv(lhs, rhs); - }, - py::is_operator()) - .def( - "__rfloordiv__", - [](const Expr &rhs, const Expr &lhs) { - return makeFloorDiv(lhs, rhs); - }, - py::is_operator()) - .def( - "__mod__", - [](const Expr &lhs, const Expr &rhs) { return makeMod(lhs, rhs); }, - py::is_operator()) - .def( - "__rmod__", - [](const Expr &rhs, const Expr &lhs) { return makeMod(lhs, rhs); }, - py::is_operator()) - .def( - "__lt__", - [](const Expr &lhs, const Expr &rhs) { return makeLT(lhs, rhs); }, - py::is_operator()) - .def( - "__le__", - [](const Expr &lhs, const Expr &rhs) { return makeLE(lhs, rhs); }, - py::is_operator()) - .def( - "__gt__", - [](const Expr &lhs, const Expr &rhs) { return makeGT(lhs, rhs); }, - py::is_operator()) - .def( - "__ge__", - [](const Expr &lhs, const Expr &rhs) { return makeGE(lhs, rhs); }, - py::is_operator()) - .def( - "__eq__", - [](const Expr &lhs, const Expr &rhs) { return makeEQ(lhs, rhs); }, - py::is_operator()) - .def( - "__ne__", - [](const Expr &lhs, const Expr &rhs) { return makeNE(lhs, rhs); }, - py::is_operator()) - .def( - "__neg__", - [](const Expr &expr) { return makeSub(makeIntConst(0), expr); }, - py::is_operator()) .def_property_readonly( "dtype", [](const Expr &op) -> DataType { return op->dtype(); }); py::implicitly_convertible(); @@ -293,6 +209,10 @@ void init_ffi_ast_expr(py::module_ &m) { [](const std::string &_1, const std::vector &_2, const DataType _3) { return makeLoad(_1, _2, _3); }, "var"_a, "indices"_a, "load_type"_a); + m.def( + "makeMod", + [](const Expr &_1, const Expr &_2) { return makeMod(_1, _2); }, "lhs"_a, + "rhs"_a); m.def( "makeRemainder", [](const Expr &_1, const Expr &_2) { return makeRemainder(_1, _2); }, @@ -353,6 +273,22 @@ void init_ffi_ast_expr(py::module_ &m) { "makeCast", [](const Expr &_1, const DataType &_2) { return makeCast(_1, _2); }, "expr"_a, "dtype"_a); + m.def( + "makeAdd", + [](const Expr &_1, const Expr &_2) { return makeAdd(_1, _2); }, "lhs"_a, + "rhs"_a); + m.def( + "makeSub", + [](const Expr &_1, const Expr &_2) { return makeSub(_1, _2); }, "lhs"_a, + "rhs"_a); + m.def( + "makeMul", + [](const Expr &_1, const Expr &_2) { return makeMul(_1, _2); }, "lhs"_a, + "rhs"_a); + m.def( + "makeRealDiv", + [](const Expr &_1, const Expr &_2) { return makeRealDiv(_1, _2); }, + "lhs"_a, "rhs"_a); m.def( "makeFloorDiv", [](const Expr &_1, const Expr &_2) { return makeFloorDiv(_1, _2); }, @@ -367,6 +303,24 @@ void init_ffi_ast_expr(py::module_ &m) { return makeRoundTowards0Div(_1, _2); }, "lhs"_a, "rhs"_a); + m.def( + "makeLT", [](const Expr &_1, const Expr &_2) { return makeLT(_1, _2); }, + "lhs"_a, "rhs"_a); + m.def( + "makeLE", [](const Expr &_1, const Expr &_2) { return makeLE(_1, _2); }, + "lhs"_a, "rhs"_a); + m.def( + "makeGT", [](const Expr &_1, const Expr &_2) { return makeGT(_1, _2); }, + "lhs"_a, "rhs"_a); + m.def( + "makeGE", [](const Expr &_1, const Expr &_2) { return makeGE(_1, _2); }, + "lhs"_a, "rhs"_a); + m.def( + "makeEQ", [](const Expr &_1, const Expr &_2) { return makeEQ(_1, _2); }, + "lhs"_a, "rhs"_a); + m.def( + "makeNE", [](const Expr &_1, const Expr &_2) { return makeNE(_1, _2); }, + "lhs"_a, "rhs"_a); m.def( "makeIntrinsic", [](const std::string &_1, const std::vector &_2, diff --git a/python/freetensor/core/expr.py b/python/freetensor/core/expr.py index 925b48c4c..6926b2485 100644 --- a/python/freetensor/core/expr.py +++ b/python/freetensor/core/expr.py @@ -9,7 +9,7 @@ 'VarRef', 'VarRefFromVarDef', 'VarVersionRef', 'add', 'sub', 'mul', 'truediv', 'floordiv', 'ceildiv', 'round_towards_0_div', 'mod', 'remainder', 'min', 'max', 'l_and', 'l_or', 'lt', 'le', 'gt', 'ge', 'eq', 'ne', 'l_not', - 'abs', 'sqrt', 'exp', 'ln', 'square', 'sigmoid', 'sin', 'cos', 'tan', + 'neg', 'abs', 'sqrt', 'exp', 'ln', 'square', 'sigmoid', 'sin', 'cos', 'tan', 'tanh', 'floor', 'ceil', 'unbound', 'if_then_else', 'cast', 'intrinsic', 'any', 'load_at_version', 'ndim', 'shape', 'dtype', 'mtype' ] @@ -17,7 +17,6 @@ import collections import builtins import math -from numbers import Number from typing import Sequence, Union from dataclasses import dataclass @@ -187,18 +186,6 @@ def _parse_key(self, key): ffiIdx.append(ffi.FrontendVarIdx(idx)) return ffiIdx - def __add__(self, other): - if self.ndim > 0: - from .. import libop - return libop.add(self, other) - return self.as_load() + other - - def __radd__(self, other): - if self.ndim > 0: - from .. import libop - return libop.add(other, self) - return other + self.as_load() - def __iadd__(self, other): if self.ndim > 0: from .. import libop @@ -209,18 +196,6 @@ def __iadd__(self, other): self.as_reduce_to(ffi.ReduceOp.Add, top.get_metadata(), other)) return AlreadyMadeReduceTo - def __sub__(self, other): - if self.ndim > 0: - from .. import libop - return libop.sub(self, other) - return self.as_load() - other - - def __rsub__(self, other): - if self.ndim > 0: - from .. import libop - return libop.sub(other, self) - return other - self.as_load() - def __isub__(self, other): if self.ndim > 0: from .. import libop @@ -231,18 +206,6 @@ def __isub__(self, other): self.as_reduce_to(ffi.ReduceOp.Add, top.get_metadata(), -other)) return AlreadyMadeReduceTo - def __mul__(self, other): - if self.ndim > 0: - from .. import libop - return libop.mul(self, other) - return self.as_load() * other - - def __rmul__(self, other): - if self.ndim > 0: - from .. import libop - return libop.mul(other, self) - return other * self.as_load() - def __imul__(self, other): if self.ndim > 0: from .. import libop @@ -253,18 +216,6 @@ def __imul__(self, other): self.as_reduce_to(ffi.ReduceOp.Mul, top.get_metadata(), other)) return AlreadyMadeReduceTo - def __truediv__(self, other): - if self.ndim > 0: - from .. import libop - return libop.truediv(self, other) - return self.as_load() / other - - def __rtruediv__(self, other): - if self.ndim > 0: - from .. import libop - return libop.truediv(other, self) - return other / self.as_load() - def __itruediv__(self, other): if self.ndim > 0: from .. import libop @@ -275,18 +226,6 @@ def __itruediv__(self, other): self.as_reduce_to(ffi.ReduceOp.Mul, top.get_metadata(), 1. / other)) return AlreadyMadeReduceTo - def __floordiv__(self, other): - if self.ndim > 0: - from .. import libop - return libop.floordiv(self, other) - return self.as_load() // other - - def __rfloordiv__(self, other): - if self.ndim > 0: - from .. import libop - return libop.floordiv(other, self) - return other // self.as_load() - def __ifloordiv__(self, other): if self.ndim > 0: from .. import libop @@ -294,18 +233,6 @@ def __ifloordiv__(self, other): return AlreadyMadeReduceTo return NotImplemented # Fallback to x = x // y - def __mod__(self, other): - if self.ndim > 0: - from .. import libop - return libop.mod(self, other) - return self.as_load() % other - - def __rmod__(self, other): - if self.ndim > 0: - from .. import libop - return libop.mod(other, self) - return other % self.as_load() - def __imod__(self, other): if self.ndim > 0: from .. import libop @@ -313,42 +240,6 @@ def __imod__(self, other): return AlreadyMadeReduceTo return NotImplemented # Fallback to x = x % y - def __lt__(self, other): - if self.ndim > 0: - from .. import libop - return libop.lt(self, other) - return self.as_load() < other - - def __le__(self, other): - if self.ndim > 0: - from .. import libop - return libop.le(self, other) - return self.as_load() <= other - - def __gt__(self, other): - if self.ndim > 0: - from .. import libop - return libop.gt(self, other) - return self.as_load() > other - - def __ge__(self, other): - if self.ndim > 0: - from .. import libop - return libop.ge(self, other) - return self.as_load() >= other - - def __eq__(self, other): - if self.ndim > 0: - from .. import libop - return libop.eq(self, other) - return self.as_load() == other - - def __ne__(self, other): - if self.ndim > 0: - from .. import libop - return libop.ne(self, other) - return self.as_load() != other - def __neg__(self): if self.ndim > 0: from .. import libop @@ -421,10 +312,14 @@ def __init__(self, indices, True) -def _istensor(x): +def _is_tensor(x): return isinstance(x, VarRef) and x.ndim > 0 +def _is_runtime_scalar(x): + return isinstance(x, VarRef) and x.ndim == 0 or isinstance(x, ffi.Expr) + + ###################################### # Binary Operators @@ -448,6 +343,11 @@ def add(lhs, rhs): VarRef or Number The sum ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.add(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeAdd(lhs, rhs) return lhs + rhs @@ -470,6 +370,11 @@ def sub(lhs, rhs): VarRef or Number The difference ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.sub(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeSub(lhs, rhs) return lhs - rhs @@ -492,6 +397,11 @@ def mul(lhs, rhs): VarRef or Number The product ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.mul(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeMul(lhs, rhs) return lhs * rhs @@ -514,6 +424,11 @@ def truediv(lhs, rhs): VarRef or Number The quotient ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.truediv(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeRealDiv(lhs, rhs) return lhs / rhs @@ -539,6 +454,11 @@ def floordiv(lhs, rhs): VarRef or Number The quotient ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.floordiv(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeFloorDiv(lhs, rhs) return lhs // rhs @@ -563,12 +483,14 @@ def ceildiv(lhs, rhs): VarRef or Number The quotient ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.ceildiv(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeCeilDiv(lhs, rhs) if type(lhs) is int and type(rhs) is int: return lhs // rhs + (lhs % rhs > 0) - return ffi.makeCeilDiv(lhs, rhs) + raise TypeError("ceildiv is only supported for integer operands") def round_towards_0_div(lhs, rhs): @@ -594,17 +516,23 @@ def round_towards_0_div(lhs, rhs): VarRef or Number The quotient ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.round_towards_0_div(lhs, rhs) - return ffi.makeRoundTowards0Div(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeRoundTowards0Div(lhs, rhs) + if type(lhs) is int and type(rhs) is int: + return a // b if (a < 0) == (b < 0) else -(-a // b) + raise TypeError( + "round_towards_0_div is only supported for integer operands") def mod(lhs, rhs): ''' `lhs` modulus `rhs` - The result is always non-negative (following Python convention, instead of C). + This function follows floored division (following Python convention, instead of C). + See https://en.wikipedia.org/wiki/Modulo for details. This function is recommended over `remainder`, as it enjoys more optimizations For scalar operands, it emit an expression node in AST. For non-scalar operands, @@ -622,6 +550,11 @@ def mod(lhs, rhs): VarRef or Number The modulo ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.mod(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeMod(lhs, rhs) return lhs % rhs @@ -629,7 +562,8 @@ def remainder(lhs, rhs): ''' Remainder of `lhs` dividing `rhs` - The result can be positive or negative (following C convention, instead of Python). + This function follows truncated division (following C convention, instead of Python). + See https://en.wikipedia.org/wiki/Modulo for details. End users are encouraged to use `lhs % rhs` instead, which follows Python convetion, and enjoys better optimization in FreeTensor @@ -648,10 +582,14 @@ def remainder(lhs, rhs): VarRef or Number The remainder ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.remainder(lhs, rhs) - return ffi.makeRemainder(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeRemainder(lhs, rhs) + if type(lhs) is int and type(rhs) is int: + return lhs - rhs * round_towards_0_div(lhs, rhs) + raise TypeError("remainder is only supported for integer operands") def min(lhs, rhs): @@ -673,12 +611,12 @@ def min(lhs, rhs): VarRef or Number The minimum ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.min(lhs, rhs) - if isinstance(lhs, Number) and isinstance(rhs, Number): - return builtins.min(lhs, rhs) - return ffi.makeMin(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeMin(lhs, rhs) + return builtins.min(lhs, rhs) def max(lhs, rhs): @@ -700,12 +638,12 @@ def max(lhs, rhs): VarRef or Number The maximum ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.max(lhs, rhs) - if isinstance(lhs, Number) and isinstance(rhs, Number): - return builtins.max(lhs, rhs) - return ffi.makeMax(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeMax(lhs, rhs) + return builtins.max(lhs, rhs) def l_and(lhs, rhs): @@ -729,13 +667,14 @@ def l_and(lhs, rhs): VarRef or Number The logical and ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.l_and(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeLAnd(lhs, rhs) if type(lhs) is bool and type(rhs) is bool: return lhs and rhs - else: - return ffi.makeLAnd(lhs, rhs) + raise TypeError("l_and is only supported for boolean operands") def l_or(lhs, rhs): @@ -759,13 +698,14 @@ def l_or(lhs, rhs): VarRef or Number The logical or ''' - if _istensor(lhs) or _istensor(rhs): + if _is_tensor(lhs) or _is_tensor(rhs): from .. import libop return libop.l_or(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeLOr(lhs, rhs) if type(lhs) is bool and type(rhs) is bool: return lhs or rhs - else: - return ffi.makeLOr(lhs, rhs) + raise TypeError("l_or is only supported for boolean operands") def lt(lhs, rhs): @@ -787,6 +727,11 @@ def lt(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.lt(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeLT(lhs, rhs) return lhs < rhs @@ -809,6 +754,11 @@ def le(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.le(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeLE(lhs, rhs) return lhs <= rhs @@ -831,6 +781,11 @@ def gt(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.gt(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeGT(lhs, rhs) return lhs > rhs @@ -853,6 +808,11 @@ def ge(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.ge(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeGE(lhs, rhs) return lhs >= rhs @@ -875,6 +835,11 @@ def eq(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.eq(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeEQ(lhs, rhs) return lhs == rhs @@ -897,6 +862,11 @@ def ne(lhs, rhs): VarRef or Number The comparison ''' + if _is_tensor(lhs) or _is_tensor(rhs): + from .. import libop + return libop.ne(lhs, rhs) + if _is_runtime_scalar(lhs) or _is_runtime_scalar(rhs): + return ffi.makeNE(lhs, rhs) return lhs != rhs @@ -921,13 +891,39 @@ def l_not(expr): VarRef or Number The logical not ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.l_not(expr) + if _is_runtime_scalar(expr): + return ffi.makeLNot(expr) if type(expr) is bool: return not expr - else: - return ffi.makeLNot(expr) + raise TypeError("l_not is only supported for boolean operands") + + +def neg(expr): + ''' + Negation + + For scalar operands, it emit an expression node in AST. For non-scalar operands, + it calls libop.neg + + Parameters + ---------- + expr : VarRef or Number + The operand + + Returns + ------- + VarRef or Number + The negation + ''' + if _is_tensor(expr): + from .. import libop + return libop.neg(expr) + if _is_runtime_scalar(expr): + return ffi.makeSub(0, expr) + return -expr def abs(expr): @@ -947,12 +943,12 @@ def abs(expr): VarRef or Number The absolute value ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.abs(expr) - if isinstance(expr, Number): - return builtins.abs(expr) - return ffi.makeAbs(expr) + if _is_runtime_scalar(expr): + return ffi.makeAbs(expr) + return builtins.abs(expr) def sqrt(expr): @@ -972,12 +968,12 @@ def sqrt(expr): VarRef or Number The square root ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.sqrt(expr) - if isinstance(expr, Number): - return math.sqrt(expr) - return ffi.makeSqrt(expr) + if _is_runtime_scalar(expr): + return ffi.makeSqrt(expr) + return math.sqrt(expr) def exp(expr): @@ -997,12 +993,12 @@ def exp(expr): VarRef or Number The exponent ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.exp(expr) - if isinstance(expr, Number): - return math.exp(expr) - return ffi.makeExp(expr) + if _is_runtime_scalar(expr): + return ffi.makeExp(expr) + return math.exp(expr) def ln(expr): @@ -1022,12 +1018,12 @@ def ln(expr): VarRef or Number The exponent ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.ln(expr) - if isinstance(expr, Number): - return math.log(expr) # Defaults to ln without the base - return ffi.makeLn(expr) + if _is_runtime_scalar(expr): + return ffi.makeLn(expr) + return math.log(expr) # Defaults to ln without the base def square(expr): @@ -1047,12 +1043,12 @@ def square(expr): VarRef or Number The square ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.square(expr) - if isinstance(expr, Number): - return expr * expr - return ffi.makeSquare(expr) + if _is_runtime_scalar(expr): + return ffi.makeSquare(expr) + return expr * expr def sigmoid(expr): @@ -1072,7 +1068,7 @@ def sigmoid(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.sigmoid(expr) return ffi.makeSigmoid(expr) @@ -1095,12 +1091,12 @@ def sin(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.sin(expr) - if isinstance(expr, Number): - return math.sin(expr) - return ffi.makeSin(expr) + if _is_runtime_scalar(expr): + return ffi.makeSin(expr) + return math.sin(expr) def cos(expr): @@ -1120,12 +1116,12 @@ def cos(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.cos(expr) - if isinstance(expr, Number): - return math.cos(expr) - return ffi.makeCos(expr) + if _is_runtime_scalar(expr): + return ffi.makeCos(expr) + return math.cos(expr) def tan(expr): @@ -1145,12 +1141,12 @@ def tan(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.tan(expr) - if isinstance(expr, Number): - return math.tan(expr) - return ffi.makeTan(expr) + if _is_runtime_scalar(expr): + return ffi.makeTan(expr) + return math.tan(expr) def tanh(expr): @@ -1170,12 +1166,12 @@ def tanh(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.tanh(expr) - if isinstance(expr, Number): - return math.tanh(expr) - return ffi.makeTanh(expr) + if _is_runtime_scalar(expr): + return ffi.makeTanh(expr) + return math.tanh(expr) def floor(expr): @@ -1195,7 +1191,7 @@ def floor(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.floor(expr) return ffi.makeFloor(expr) @@ -1218,7 +1214,7 @@ def ceil(expr): VarRef or Number The result ''' - if _istensor(expr): + if _is_tensor(expr): from .. import libop return libop.ceil(expr) return ffi.makeCeil(expr) @@ -1394,3 +1390,29 @@ def mtype(var): return var.mtype else: return 'byvalue' + + +def register_operators(cls): + cls.__add__ = add + cls.__radd__ = lambda self, other: add(other, self) + cls.__sub__ = sub + cls.__rsub__ = lambda self, other: sub(other, self) + cls.__mul__ = mul + cls.__rmul__ = lambda self, other: mul(other, self) + cls.__truediv__ = truediv + cls.__rtruediv__ = lambda self, other: truediv(other, self) + cls.__floordiv__ = floordiv + cls.__rfloordiv__ = lambda self, other: floordiv(other, self) + cls.__mod__ = mod + cls.__rmod__ = lambda self, other: mod(other, self) + cls.__lt__ = lt + cls.__le__ = le + cls.__gt__ = gt + cls.__ge__ = ge + cls.__eq__ = eq + cls.__ne__ = ne + cls.__neg__ = neg + + +register_operators(ffi.FrontendVar) +register_operators(ffi.Expr)