Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nikolaj Bjorner <[email protected]>
  • Loading branch information
NikolajBjorner committed Nov 25, 2024
1 parent 4559b23 commit 7ed185a
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 49 deletions.
111 changes: 82 additions & 29 deletions src/ast/sls/sls_arith_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::find_lin_moves(sat::literal lit) {
m_updates.reset();
auto* ineq = atom(lit.var());
auto* ineq = get_ineq(lit.var());
num_t a, b;
if (!ineq)
return false;
Expand Down Expand Up @@ -582,7 +582,7 @@ namespace sls {
num_t d(1), d2;
bool first = true;
for (auto a : ctx.get_clause(cl)) {
auto const* ineq = atom(a.var());
auto const* ineq = get_ineq(a.var());
if (!ineq)
continue;
d2 = dtt(a.sign(), *ineq);
Expand All @@ -601,7 +601,7 @@ namespace sls {
num_t d(1), d2;
bool first = true;
for (auto lit : ctx.get_clause(cl)) {
auto const* ineq = atom(lit.var());
auto const* ineq = get_ineq(lit.var());
if (!ineq)
continue;
d2 = dtt(lit.sign(), *ineq, v, new_value);
Expand Down Expand Up @@ -667,8 +667,8 @@ namespace sls {
}

buffer<sat::bool_var> to_flip;
for (auto const& [coeff, bv] : vi.m_bool_vars) {
auto& ineq = *atom(bv);
for (auto const& [coeff, bv] : vi.m_ineqs) {
auto& ineq = *get_ineq(bv);
bool old_sign = sign(bv);
sat::literal lit(bv, old_sign);
SASSERT(ctx.is_true(lit));
Expand All @@ -684,9 +684,9 @@ namespace sls {
m_last_var = v;

for (auto bv : to_flip) {
if (dtt(sign(bv), *atom(bv)) != 0)
if (dtt(sign(bv), *get_ineq(bv)) != 0)
ctx.flip(bv);
SASSERT(dtt(sign(bv), *atom(bv)) == 0);
SASSERT(dtt(sign(bv), *get_ineq(bv)) == 0);
}

IF_VERBOSE(10, verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n");
Expand Down Expand Up @@ -933,12 +933,12 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::init_bool_var(sat::bool_var bv) {
expr* e = ctx.atom(bv);
if (m_bool_vars.get(bv, nullptr))
if (m_ineqs.get(bv, nullptr))
return;
if (!e)
return;
expr* x, * y;
m_bool_vars.reserve(bv + 1);
m_ineqs.reserve(bv + 1);
if (a.is_le(e, x, y) || a.is_ge(e, y, x)) {
auto& ineq = new_ineq(ineq_kind::LE, num_t(0));
add_args(ineq, x, num_t(1));
Expand All @@ -963,8 +963,8 @@ namespace sls {
add_args(ineq, y, num_t(-1));
init_ineq(bv, ineq);
}
else if (m.is_distinct(e) && a.is_int_real(to_app(e)->get_arg(0))) {
NOT_IMPLEMENTED_YET();
else if (is_distinct(e)) {
verbose_stream() << "distinct " << mk_pp(e, m) << "\n";
}
else if (a.is_is_int(e, x))
{
Expand Down Expand Up @@ -1004,7 +1004,7 @@ namespace sls {
// compute the value of the linear term, and accumulate non-linear sub-terms
i.m_args_value = i.m_coeff;
for (auto const& [coeff, v] : i.m_args) {
m_vars[v].m_bool_vars.push_back({ coeff, bv });
m_vars[v].m_ineqs.push_back({ coeff, bv });
i.m_args_value += coeff * value(v);
if (is_mul(v)) {
auto const& [w, monomial] = get_mul(v);
Expand Down Expand Up @@ -1044,21 +1044,28 @@ namespace sls {
}

// attach i to bv
m_bool_vars.set(bv, &i);
m_ineqs.set(bv, &i);
}

template<typename num_t>
void arith_base<num_t>::init_bool_var_assignment(sat::bool_var v) {
auto* ineq = atom(v);
auto* ineq = get_ineq(v);
if (ineq && ineq->is_true() != ctx.is_true(v))
ctx.flip(v);
if (is_distinct(ctx.atom(v)) && eval_distinct(ctx.atom(v)) != ctx.is_true(v))
ctx.flip(v);
}

template<typename num_t>
void arith_base<num_t>::propagate_literal(sat::literal lit) {
if (!ctx.is_true(lit))
return;
auto const* ineq = atom(lit.var());
expr* e = ctx.atom(lit.var());
if (is_distinct(e) && eval_distinct(e) != ctx.is_true(lit)) {
repair_distinct(e);
return;
}
auto const* ineq = get_ineq(lit.var());
if (!ineq)
return;
if (ineq->is_true() != lit.sign())
Expand Down Expand Up @@ -1136,7 +1143,7 @@ namespace sls {
void arith_base<num_t>::repair_up(app* e) {
if (m.is_bool(e)) {
auto v = ctx.atom2bool_var(e);
auto const* ineq = atom(v);
auto const* ineq = get_ineq(v);
if (ineq && ineq->is_true() != ctx.is_true(v))
ctx.flip(v);
return;
Expand Down Expand Up @@ -1333,7 +1340,7 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::initialize_unit(sat::literal lit) {
init_bool_var(lit.var());
auto* ineq = atom(lit.var());
auto* ineq = get_ineq(lit.var());
if (!ineq)
return;

Expand Down Expand Up @@ -1623,11 +1630,11 @@ namespace sls {
double arith_base<num_t>::compute_score(var_t x, num_t const& delta) {
int result = 0;
int breaks = 0;
for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) {
for (auto const& [coeff, bv] : m_vars[x].m_ineqs) {
bool old_sign = sign(bv);
auto lit = sat::literal(bv, old_sign);
auto dtt_old = dtt(old_sign, *atom(bv));
auto dtt_new = dtt(old_sign, *atom(bv), coeff, delta);
auto dtt_old = dtt(old_sign, *get_ineq(bv));
auto dtt_new = dtt(old_sign, *get_ineq(bv), coeff, delta);
#if 1
if (dtt_new == 0 && dtt_old != 0)
result += 1;
Expand Down Expand Up @@ -1711,7 +1718,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::find_nl_moves(sat::literal lit) {
m_updates.reset();
auto* ineq = atom(lit.var());
auto* ineq = get_ineq(lit.var());
num_t a, b;
if (!ineq)
return false;
Expand Down Expand Up @@ -1766,7 +1773,7 @@ namespace sls {
template<typename num_t>
bool arith_base<num_t>::find_reset_moves(sat::literal lit) {
m_updates.reset();
auto* ineq = atom(lit.var());
auto* ineq = get_ineq(lit.var());
num_t a, b;
if (!ineq)
return false;
Expand Down Expand Up @@ -1892,7 +1899,7 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::check_ineqs() {
for (unsigned bv = 0; bv < ctx.num_bool_vars(); ++bv) {
auto const* ineq = atom(bv);
auto const* ineq = get_ineq(bv);
if (!ineq)
continue;
num_t d = dtt(sign(bv), *ineq);
Expand All @@ -1918,6 +1925,45 @@ namespace sls {
mk_term(arg);
}

template<typename num_t>
bool arith_base<num_t>::is_distinct(expr* e) {
return m.is_distinct(e) &&
to_app(e)->get_num_args() > 0 &&
a.is_int_real(to_app(e)->get_arg(0));
}

template<typename num_t>
bool arith_base<num_t>::eval_distinct(expr* e) {
auto const& args = *to_app(e);
for (unsigned i = 0; i < args.get_num_args(); ++i)
for (unsigned j = i + 1; j < args.get_num_args(); ++j) {
auto v1 = mk_term(args.get_arg(i));
auto v2 = mk_term(args.get_arg(j));
if (value(v1) == value(v2))
return false;
}
return true;
}

template<typename num_t>
void arith_base<num_t>::repair_distinct(expr* e) {
auto const& args = *to_app(e);
for (unsigned i = 0; i < args.get_num_args(); ++i)
for (unsigned j = i + 1; j < args.get_num_args(); ++j) {
auto v1 = mk_term(args.get_arg(i));
auto v2 = mk_term(args.get_arg(j));
if (value(v1) == value(v2)) {
auto new_value = value(v1) + num_t(1);
if (new_value == value(v2))
new_value += num_t(1);
if (!is_fixed(v2))
update(v2, new_value);
else if (!is_fixed(v1))
update(v1, new_value);
}
}
}

template<typename num_t>
bool arith_base<num_t>::set_value(expr* e, expr* v) {
if (!a.is_int_real(e))
Expand Down Expand Up @@ -1956,7 +2002,14 @@ namespace sls {
for (auto lit : clause.m_clause) {
if (!ctx.is_true(lit))
continue;
auto ineq = atom(lit.var());
if (is_distinct(ctx.atom(lit.var()))) {
if (eval_distinct(ctx.atom(lit.var())) != lit.sign()) {
sat = true;
break;
}
continue;
}
auto ineq = get_ineq(lit.var());
if (!ineq) {
sat = true;
break;
Expand All @@ -1972,7 +2025,7 @@ namespace sls {
verbose_stream() << clause << "\n";
for (auto lit : clause.m_clause) {
verbose_stream() << lit << " (" << ctx.is_true(lit) << ") ";
auto ineq = atom(lit.var());
auto ineq = get_ineq(lit.var());
if (!ineq)
continue;
verbose_stream() << *ineq << "\n";
Expand Down Expand Up @@ -2069,9 +2122,9 @@ namespace sls {
out << " ";
}

if (!vi.m_bool_vars.empty()) {
if (!vi.m_ineqs.empty()) {
out << " bool: ";
for (auto [c, bv] : vi.m_bool_vars)
for (auto [c, bv] : vi.m_ineqs)
out << c << "@" << bv << " ";
}
return out;
Expand All @@ -2080,7 +2133,7 @@ namespace sls {
template<typename num_t>
std::ostream& arith_base<num_t>::display(std::ostream& out) const {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
auto ineq = atom(v);
auto ineq = get_ineq(v);
if (ineq)
out << v << ": " << *ineq << "\n";
}
Expand Down Expand Up @@ -2176,7 +2229,7 @@ namespace sls {
template<typename num_t>
void arith_base<num_t>::invariant() {
for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) {
auto ineq = atom(v);
auto ineq = get_ineq(v);
if (ineq)
invariant(*ineq);
}
Expand Down
9 changes: 6 additions & 3 deletions src/ast/sls/sls_arith_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ namespace sls {
var_sort m_sort;
arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP;
unsigned m_def_idx = UINT_MAX;
vector<std::pair<num_t, sat::bool_var>> m_bool_vars;
vector<std::pair<num_t, sat::bool_var>> m_ineqs;
unsigned_vector m_muls;
unsigned_vector m_adds;
optional<bound> m_lo, m_hi;
Expand Down Expand Up @@ -159,7 +159,7 @@ namespace sls {

stats m_stats;
config m_config;
scoped_ptr_vector<ineq> m_bool_vars;
scoped_ptr_vector<ineq> m_ineqs;
vector<var_info> m_vars;
vector<mul_def> m_muls;
vector<add_def> m_adds;
Expand All @@ -181,6 +181,9 @@ namespace sls {

unsigned get_num_vars() const { return m_vars.size(); }

bool is_distinct(expr* e);
bool eval_distinct(expr* e);
void repair_distinct(expr* e);
bool eval_is_correct(var_t v);
bool repair_mul(mul_def const& md);
bool repair_add(add_def const& ad);
Expand Down Expand Up @@ -219,7 +222,7 @@ namespace sls {
// double reward(sat::literal lit);

bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); }
ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); }
ineq* get_ineq(sat::bool_var bv) const { return m_ineqs.get(bv, nullptr); }
num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); }
num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const;
num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const;
Expand Down
Loading

0 comments on commit 7ed185a

Please sign in to comment.