Skip to content

Commit

Permalink
add rewrites for hoisting constants from ite expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <[email protected]>
  • Loading branch information
NikolajBjorner committed Nov 22, 2024
1 parent b4e768c commit 5025c3c
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
66 changes: 66 additions & 0 deletions src/ast/rewriter/arith_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,72 @@ bool arith_rewriter::is_arith_term(expr * n) const {
return n->get_kind() == AST_APP && to_app(n)->get_family_id() == get_fid();
}

br_status arith_rewriter::mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result) {
numeral v1, v2;
bool is_int;
bool is_num1 = m_util.is_numeral(t, v1, is_int);
bool is_num2 = m_util.is_numeral(e, v2, is_int);
if (is_num1 && is_num2 && v1 == 0 && v2 != 1) {
result = m_util.mk_mul(e, m.mk_ite(c, t, m_util.mk_numeral(rational(1), is_int)));
return BR_DONE;
}
if (is_num1 && is_num2 && v2 == 0 && v1 != 1) {
result = m_util.mk_mul(t, m.mk_ite(c, m_util.mk_numeral(rational(1), is_int), e));
return BR_DONE;
}
if (is_num1 && is_num2 && is_int && gcd(v1, v2) != 1) {
auto g = gcd(v1, v2);
if (g > 0 && v1 < 0 && v2 < 0)
g = -g;

result = m_util.mk_numeral(g, is_int);
result = m_util.mk_mul(result, m.mk_ite(c, m_util.mk_numeral(v1/g, true), m_util.mk_numeral(v2/g, true)));
return BR_REWRITE2;
}
if (is_num1 && is_num2 && v1 != 0 && v2 != 0 && v1 != v2) {
if (v1 > v2)
result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), m_util.mk_numeral(rational::zero(), is_int)));
else
result = m_util.mk_add(e, m.mk_ite(c, m_util.mk_numeral(rational::zero(), is_int), m_util.mk_numeral(v2 - v1, is_int)));
return BR_DONE;
}
expr* x, *y;
if (is_num1 && m_util.is_mul(e, x, y) && m_util.is_numeral(x, v2, is_int) && v2 != 0) {
if (v1 == 0) {
result = m_util.mk_mul(x, m.mk_ite(c, t, y));
return BR_DONE;
}
if (is_int && divides(v2, v1)) {
result = m_util.mk_mul(x, m.mk_ite(c, m_util.mk_numeral(v1/v2, true), y));
return BR_DONE;
}

}
if (is_num2 && m_util.is_mul(t, x, y) && m_util.is_numeral(x, v1, is_int) && v1 != 0) {
if (v2 == 0) {
result = m_util.mk_mul(x, m.mk_ite(c, y, e));
return BR_DONE;
}
if (is_int && divides(v1, v2)) {
result = m_util.mk_mul(x, m.mk_ite(c, y, m_util.mk_numeral(v2/v1, true)));
return BR_DONE;
}

}
if (is_num1 && m_util.is_add(e, x, y) && m_util.is_numeral(x, v2, is_int)) {
result = m_util.mk_add(x, m.mk_ite(c, m_util.mk_numeral(v1 - v2, is_int), y));
return BR_REWRITE2;
}
if (is_num2 && m_util.is_add(t, x, y) && m_util.is_numeral(x, v1, is_int)) {
result = m_util.mk_add(x, m.mk_ite(c, y, m_util.mk_numeral(v2 - v1, is_int)));
return BR_REWRITE2;
}



return BR_FAILED;
}

br_status arith_rewriter::mk_eq_core(expr * arg1, expr * arg2, expr_ref & result) {
br_status st = BR_FAILED;
if (m_eq2ineq) {
Expand Down
1 change: 1 addition & 0 deletions src/ast/rewriter/arith_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class arith_rewriter : public poly_rewriter<arith_rewriter_core> {
br_status mk_lt_core(expr * arg1, expr * arg2, expr_ref & result);
br_status mk_ge_core(expr * arg1, expr * arg2, expr_ref & result);
br_status mk_gt_core(expr * arg1, expr * arg2, expr_ref & result);
br_status mk_ite_core(expr* c, expr* t, expr* e, expr_ref & result);

br_status mk_add_core(unsigned num_args, expr * const * args, expr_ref & result);
br_status mk_mul_core(unsigned num_args, expr * const * args, expr_ref & result);
Expand Down
2 changes: 2 additions & 0 deletions src/ast/rewriter/poly_rewriter_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,9 @@ bool poly_rewriter<Config>::hoist_ite(expr_ref& e) {
++i;
}
if (!pinned.empty()) {
TRACE("poly_rewriter", tout << e << "\n");
e = mk_add_app(adds.size(), adds.data());
TRACE("poly_rewriter", tout << e << "\n");
return true;
}
return false;
Expand Down
10 changes: 8 additions & 2 deletions src/ast/rewriter/th_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg {
family_id s_fid = args[1]->get_sort()->get_family_id();
if (s_fid == m_bv_rw.get_fid())
st = m_bv_rw.mk_ite_core(args[0], args[1], args[2], result);
if (st == BR_FAILED && s_fid == m_a_rw.get_fid())
st = m_a_rw.mk_ite_core(args[0], args[1], args[2], result);
CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n");
if (st != BR_FAILED)
return st;
}
Expand All @@ -197,7 +200,9 @@ struct th_rewriter_cfg : public default_rewriter_cfg {
return st;
}

return m_b_rw.mk_app_core(f, num, args, result);
st = m_b_rw.mk_app_core(f, num, args, result);
CTRACE("th_rewriter_step", st != BR_FAILED, tout << result << "\n");
return st;
}
if (fid == m_a_rw.get_fid() && OP_LE == f->get_decl_kind() && m_seq_rw.u().has_seq()) {
st = m_seq_rw.mk_le_core(args[0], args[1], result);
Expand Down Expand Up @@ -315,7 +320,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg {
return pull_ite_core<true>(f, to_app(args[1]), to_app(args[0]), result);
}
family_id fid = f->get_family_id();
if (num == 2 && (fid == m().get_basic_family_id() || fid == m_a_rw.get_fid() || fid == m_bv_rw.get_fid())) {
if (num == 2 && (fid == m().get_basic_family_id() || fid == m_bv_rw.get_fid())) {
// (f v3 (ite c v1 v2)) --> (ite v (f v3 v1) (f v3 v2))
if (m().is_value(args[0]) && is_ite_value_tree(args[1]))
return pull_ite_core<true>(f, to_app(args[1]), to_app(args[0]), result);
Expand Down Expand Up @@ -554,6 +559,7 @@ struct th_rewriter_cfg : public default_rewriter_cfg {
result = m().mk_app(f_prime, common, m().mk_ite(c, new_t, new_e));
else
result = m().mk_app(f_prime, m().mk_ite(c, new_t, new_e), common);
TRACE("push_ite", tout << result << "\n";);
return BR_DONE;
}
TRACE("push_ite", tout << "failed\n";);
Expand Down

0 comments on commit 5025c3c

Please sign in to comment.