Skip to content

Commit

Permalink
More rules for IfExpr in pass/simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 18, 2024
1 parent c182842 commit 0906acd
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
6 changes: 3 additions & 3 deletions include/math/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <class T> struct LinearExpr {
// `LinearExpr`s are the same, but std::map is too slow. So, we are using
// std::vector and sort each factor by its hash
std::vector<Scale<T>> coeff_;
T bias_;
T bias_ = 0;

bool isConst() const { return coeff_.empty(); }

Expand Down Expand Up @@ -126,8 +126,8 @@ bool hasIdenticalCoeff(const LinearExpr<T> &lhs, const LinearExpr<T> &rhs) {
* directinos
*/
template <class T>
requires std::integral<T> || std::floating_point<T>
Expr lin2expr(const LinearExpr<T> &lin) {
requires std::integral<T> || std::floating_point<T>
Expr lin2expr(const LinearExpr<T> &lin) {
Expr b = makeIntConst(lin.bias_);

for (auto &&item : lin.coeff_) {
Expand Down
27 changes: 27 additions & 0 deletions src/pass/simplify.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <algorithm>
#include <unordered_set>

#include <analyze/analyze_linear.h>
#include <analyze/as_dnf.h>
#include <except.h>
#include <math/min_max.h>
Expand Down Expand Up @@ -808,6 +809,32 @@ Expr SimplifyPass::visit(const IfExpr &_op) {
// intrinsic(..., type_a) : intrinsic(..., type_b)`.
}

auto thenLin = linear(op->thenCase_), elseLin = linear(op->elseCase_);
LinearExpr<int64_t> common;
for (size_t i = 0, j = 0, m = thenLin.coeff_.size(),
n = elseLin.coeff_.size();
i < m && j < n; i++) {
// LinearExpr<...>::coeff_ is sorted by hash
while (j < n &&
elseLin.coeff_[j].a_->hash() < thenLin.coeff_[i].a_->hash()) {
j++;
}
if (thenLin.coeff_[i].k_ == elseLin.coeff_[j].k_ &&
HashComparator{}(thenLin.coeff_[i].a_, elseLin.coeff_[j].a_)) {
common.coeff_.emplace_back(thenLin.coeff_[i]);
thenLin.coeff_[i++].k_ = elseLin.coeff_[j++].k_ = 0;

Check notice

Code scanning / CodeQL

For loop variable changed in body Note

Loop counters should not be modified in the body of the
loop
.
}
}
if (thenLin.bias_ == elseLin.bias_) {
common.bias_ = thenLin.bias_;
thenLin.bias_ = elseLin.bias_ = 0;
}
if (!common.coeff_.empty() || common.bias_ != 0) {
return makeAdd(
lin2expr(common),
makeIfExpr(op->cond_, lin2expr(thenLin), lin2expr(elseLin)));
}

return op;
}

Expand Down
21 changes: 21 additions & 0 deletions test/20.pass/test_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,27 @@ def test_if_expr_in_cond(p):
assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify])
def test_sink_if_expr_into_linear_expression(p):
with ft.VarDef([("a", (), "int32", "input", "cpu"),
("b", (), "int32", "input", "cpu"),
("c", (), "int32", "input", "cpu"),
("y", (), "int32", "output", "cpu")]) as (a, b, c, y):
y[...] = ft.if_then_else(c < 0, 3 * a, 3 * a + 2 * b)
ast = ft.pop_ast(verbose=True)
ast = p(ast)
print(ast)

with ft.VarDef([("a", (), "int32", "input", "cpu"),
("b", (), "int32", "input", "cpu"),
("c", (), "int32", "input", "cpu"),
("y", (), "int32", "output", "cpu")]) as (a, b, c, y):
y[...] = 3 * a + ft.if_then_else(c < 0, 0, 2 * b)
std = ft.pop_ast()

assert std.match(ast)


@pytest.mark.parametrize('p', [ft.pb_simplify, ft.simplify, ft.z3_simplify])
def test_accessible_after_writing_if(p):
with ft.VarDef([("x", (4,), "int32", "inout", "cpu"),
Expand Down

0 comments on commit 0906acd

Please sign in to comment.