diff --git a/include/math/linear.h b/include/math/linear.h index 6fb372964..0f4e354e5 100644 --- a/include/math/linear.h +++ b/include/math/linear.h @@ -25,7 +25,7 @@ template struct LinearExpr { // `LinearExpr`s are the same, but std::map is too slow. So, we are using // std::vector and sort each factor by its hash std::vector> coeff_; - T bias_; + T bias_ = 0; bool isConst() const { return coeff_.empty(); } @@ -126,8 +126,8 @@ bool hasIdenticalCoeff(const LinearExpr &lhs, const LinearExpr &rhs) { * directinos */ template -requires std::integral || std::floating_point - Expr lin2expr(const LinearExpr &lin) { + requires std::integral || std::floating_point +Expr lin2expr(const LinearExpr &lin) { Expr b = makeIntConst(lin.bias_); for (auto &&item : lin.coeff_) { diff --git a/src/pass/simplify.cc b/src/pass/simplify.cc index 4a2b2eb73..60b44cacd 100644 --- a/src/pass/simplify.cc +++ b/src/pass/simplify.cc @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -808,6 +809,32 @@ Expr SimplifyPass::visit(const IfExpr &_op) { // intrinsic(..., type_a) : intrinsic(..., type_b)`. } + auto thenLin = linear(op->thenCase_), elseLin = linear(op->elseCase_); + LinearExpr common; + for (size_t i = 0, j = 0, m = thenLin.coeff_.size(), + n = elseLin.coeff_.size(); + i < m && j < n; i++) { + // LinearExpr<...>::coeff_ is sorted by hash + while (j < n && + elseLin.coeff_[j].a_->hash() < thenLin.coeff_[i].a_->hash()) { + j++; + } + if (thenLin.coeff_[i].k_ == elseLin.coeff_[j].k_ && + HashComparator{}(thenLin.coeff_[i].a_, elseLin.coeff_[j].a_)) { + common.coeff_.emplace_back(thenLin.coeff_[i]); + thenLin.coeff_[i++].k_ = elseLin.coeff_[j++].k_ = 0; + } + } + if (thenLin.bias_ == elseLin.bias_) { + common.bias_ = thenLin.bias_; + thenLin.bias_ = elseLin.bias_ = 0; + } + if (!common.coeff_.empty() || common.bias_ != 0) { + return makeAdd( + lin2expr(common), + makeIfExpr(op->cond_, lin2expr(thenLin), lin2expr(elseLin))); + } + return op; } diff --git a/test/20.pass/test_simplify.py b/test/20.pass/test_simplify.py index c93d57125..9cb59ba33 100644 --- a/test/20.pass/test_simplify.py +++ b/test/20.pass/test_simplify.py @@ -937,6 +937,27 @@ def test_if_expr_in_cond(p): assert std.match(ast) +@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify]) +def test_sink_if_expr_into_linear_expression(p): + with ft.VarDef([("a", (), "int32", "input", "cpu"), + ("b", (), "int32", "input", "cpu"), + ("c", (), "int32", "input", "cpu"), + ("y", (), "int32", "output", "cpu")]) as (a, b, c, y): + y[...] = ft.if_then_else(c < 0, 3 * a, 3 * a + 2 * b) + ast = ft.pop_ast(verbose=True) + ast = p(ast) + print(ast) + + with ft.VarDef([("a", (), "int32", "input", "cpu"), + ("b", (), "int32", "input", "cpu"), + ("c", (), "int32", "input", "cpu"), + ("y", (), "int32", "output", "cpu")]) as (a, b, c, y): + y[...] = 3 * a + ft.if_then_else(c < 0, 0, 2 * b) + std = ft.pop_ast() + + assert std.match(ast) + + @pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify]) def test_accessible_after_writing_if(p): with ft.VarDef([("x", (4,), "int32", "inout", "cpu"),