Skip to content

Commit

Permalink
feat: smt solving refinement
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark committed Oct 11, 2023
1 parent a4e0c39 commit 4cb6bad
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 150 deletions.
5 changes: 2 additions & 3 deletions examples/simple/test/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ contract VaultTest is SymTest {
vault.setTotalShares(svm.createUint256("S1"));
}

// NOTE: currently timeout when --smt-div is enabled, while producing invalid counterexamples when --smt-div is not given
function prove_deposit(uint assets) public {
/// @custom:halmos --solver-timeout-assertion 10000
function check_deposit(uint assets) public {
uint A1 = vault.totalAssets();
uint S1 = vault.totalShares();

Expand All @@ -40,7 +40,6 @@ contract VaultTest is SymTest {
assert(A1 * S2 <= A2 * S1); // no counterexample
}

/// @custom:halmos --smt-div
function check_mint(uint shares) public {
uint A1 = vault.totalAssets();
uint S1 = vault.totalShares();
Expand Down
46 changes: 6 additions & 40 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,17 +689,12 @@ def run(
if is_valid:
print(red(f"Counterexample: {render_model(model)}"))
counterexamples.append(model)
elif args.print_potential_counterexample:
else:
warn(
COUNTEREXAMPLE_INVALID,
f"Counterexample (potentially invalid): {render_model(model)}",
)
counterexamples.append(model)
else:
warn(
COUNTEREXAMPLE_INVALID,
f"Counterexample (potentially invalid): (not displayed, use --print-potential-counterexample)",
)
else:
warn(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}")

Expand Down Expand Up @@ -958,11 +953,6 @@ def solve(query: str, args: Namespace) -> Tuple[CheckSatResult, Model]:
def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext:
args, idx, sexpr = fn_args.args, fn_args.idx, fn_args.sexpr
res, model = solve(sexpr, args)
# solver = SolverFor("QF_AUFBV", ctx=Context())
# solver.set(timeout=args.solver_timeout_assertion)
# solver.from_string(sexpr)
# res = solver.check()
# model = solver.model() if res == sat else None

if res == sat and not is_model_valid(model):
res, model = solve(refine(sexpr), args)
Expand All @@ -980,37 +970,24 @@ def refine(query: str) -> str:
# replace uninterpreted abstraction with actual symbols for assertion solving
# TODO: replace `(evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))`
# as bvudiv is undefined when y = 0; also similarly for evm_bvurem
return re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query)
query = re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query)
return re.sub(r"\(\s*declare-fun\s+evm_(bv[a-z]+)(_[0-9]+)?\b", r"(declare-fun dummy_\1\2", query)


def gen_model(args: Namespace, idx: int, ex: Exec) -> ModelWithContext:
if args.verbose >= 1:
print(f"Checking path condition (path id: {idx+1})")

model = None

ex.solver.set(timeout=args.solver_timeout_assertion)
res = ex.solver.check()
if res == sat:
model = ex.solver.model()

if res == unknown and args.solver_fresh:
if args.verbose >= 1:
print(f" Checking again with a fresh solver")
res, model = solve(ex.solver.to_smt2(), args)
model = ex.solver.model() if res == sat else None

if res == sat and not is_model_valid(model):
if args.verbose >= 1:
print(f" Checking again with refinement")
res, model = solve(refine(ex.solver.to_smt2()), args)
# sol2 = SolverFor("QF_AUFBV", ctx=Context())
# sol2.set(timeout=args.solver_timeout_assertion)
# sol2.from_string(refine(ex.solver.to_smt2()))
# res = sol2.check()
# if res == sat:
# model = sol2.model()

if is_unknown(res, model) and args.solver_subprocess:

if args.solver_subprocess and is_unknown(res, model):
if args.verbose >= 1:
print(f" Checking again in an external process")
fname = f"/tmp/{uuid.uuid4().hex}.smt2"
Expand Down Expand Up @@ -1071,10 +1048,6 @@ def package_result(

def is_model_valid(model: AnyModel) -> bool:
for decl in model:
if str(decl) == "evm_bvudiv" and str(model[decl]) == "[else -> bvudiv_i(Var(0x0), Var(0x1))]":
continue
if str(decl) == "evm_bvsdiv" and str(model[decl]) == "[else -> bvsdiv_i(Var(0x0), Var(0x1))]":
continue
if str(decl).startswith("evm_"):
return False
return True
Expand Down Expand Up @@ -1103,13 +1076,6 @@ def mk_options(args: Namespace) -> Dict:
"verbose": args.verbose,
"debug": args.debug,
"log": args.log,
"add": not args.no_smt_add,
"sub": not args.no_smt_sub,
"mul": not args.no_smt_mul,
"div": args.smt_div,
"mod": args.smt_mod,
"divByConst": args.smt_div_by_const,
"modByConst": args.smt_mod_by_const,
"expByConst": args.smt_exp_by_const,
"timeout": args.solver_timeout_branching,
"sym_jump": args.symbolic_jump,
Expand Down
22 changes: 0 additions & 22 deletions src/halmos/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,6 @@ def mk_arg_parser() -> argparse.ArgumentParser:
# smt solver options
group_solver = parser.add_argument_group("Solver options")

group_solver.add_argument(
"--no-smt-add", action="store_true", help="do not interpret `+`"
)
group_solver.add_argument(
"--no-smt-sub", action="store_true", help="do not interpret `-`"
)
group_solver.add_argument(
"--no-smt-mul", action="store_true", help="do not interpret `*`"
)
group_solver.add_argument("--smt-div", action="store_true", help="interpret `/`")
group_solver.add_argument("--smt-mod", action="store_true", help="interpret `mod`")
group_solver.add_argument(
"--smt-div-by-const", action="store_true", help="interpret division by constant"
)
group_solver.add_argument(
"--smt-mod-by-const", action="store_true", help="interpret constant modulo"
)
group_solver.add_argument(
"--smt-exp-by-const",
metavar="N",
Expand Down Expand Up @@ -248,10 +231,5 @@ def mk_arg_parser() -> argparse.ArgumentParser:
group_experimental.add_argument(
"--symbolic-jump", action="store_true", help="support symbolic jump destination"
)
group_experimental.add_argument(
"--print-potential-counterexample",
action="store_true",
help="print potentially invalid counterexamples",
)

return parser
127 changes: 47 additions & 80 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,133 +1516,100 @@ def mk_mod(self, ex: Exec, x: Any, y: Any) -> Any:
def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
w1 = b2i(w1)
w2 = b2i(w2)

if op == EVM.ADD:
if self.options.get("add"):
return w1 + w2
if is_bv_value(w1) and is_bv_value(w2):
return w1 + w2
else:
return self.mk_add(w1, w2)
elif op == EVM.SUB:
if self.options.get("sub"):
return w1 - w2
if is_bv_value(w1) and is_bv_value(w2):
return w1 - w2
else:
return f_sub(w1, w2)
elif op == EVM.MUL:
if self.options.get("mul"):
return w1 * w2
if is_bv_value(w1) and is_bv_value(w2):
return w1 * w2
elif is_bv_value(w1):
i1: int = int(str(w1)) # must be concrete
if i1 == 0:
return w1
elif is_power_of_two(i1):
return w2 << int(math.log(i1, 2))
else:
return self.mk_mul(w1, w2)
elif is_bv_value(w2):
i2: int = int(str(w2)) # must be concrete
if i2 == 0:
return w2
elif is_power_of_two(i2):
return w1 << int(math.log(i2, 2))
else:
return self.mk_mul(w1, w2)
else:
return self.mk_mul(w1, w2)
elif op == EVM.DIV:
return w1 + w2

if op == EVM.SUB:
return w1 - w2

if op == EVM.MUL:
return w1 * w2

if op == EVM.DIV:
div_for_overflow_check = self.div_xy_y(w1, w2)
if div_for_overflow_check is not None: # xy/x or xy/y
return div_for_overflow_check
# if self.options.get("div"):
# return UDiv(w1, w2) # unsigned div (bvudiv)

if is_bv_value(w1) and is_bv_value(w2):
return UDiv(w1, w2)
elif is_bv_value(w2):
return UDiv(w1, w2) # unsigned div (bvudiv)

if is_bv_value(w2):
# concrete denominator case
i2: int = w2.as_long()
if i2 == 0:
return w2
elif i2 == 1:

if i2 == 1:
return w1
elif is_power_of_two(i2):

if is_power_of_two(i2):
return LShR(w1, int(math.log(i2, 2)))
# elif self.options.get("divByConst"):
# return UDiv(w1, w2)
else:
return self.mk_div(ex, w1, w2)
else:
return self.mk_div(ex, w1, w2)
elif op == EVM.MOD:
# if self.options.get("mod"):
# return URem(w1, w2)

return self.mk_div(ex, w1, w2)

if op == EVM.MOD:
if is_bv_value(w1) and is_bv_value(w2):
return URem(w1, w2) # bvurem
elif is_bv_value(w2):

if is_bv_value(w2):
i2: int = int(str(w2))
if i2 == 0 or i2 == 1:
return con(0, w2.size())
elif is_power_of_two(i2):

if is_power_of_two(i2):
bitsize = int(math.log(i2, 2))
return ZeroExt(w2.size() - bitsize, Extract(bitsize - 1, 0, w1))
# elif self.options.get("modByConst"):
# return URem(w1, w2)
else:
return self.mk_mod(ex, w1, w2)
else:
return self.mk_mod(ex, w1, w2)
elif op == EVM.SDIV:
# if self.options.get("div"):
# return w1 / w2 # bvsdiv

return self.mk_mod(ex, w1, w2)

if op == EVM.SDIV:
if is_bv_value(w1) and is_bv_value(w2):
return w1 / w2 # bvsdiv

if is_bv_value(w2):
# concrete denominator case
i2: int = w2.as_long()

if i2 == 0:
return w2 # div by 0 is 0

if i2 == 1:
return w1 # div by 1 is identity

# if self.options.get("divByConst"):
# return w1 / w2 # bvsdiv

# fall back to uninterpreted function :(
return f_sdiv(w1, w2)

elif op == EVM.SMOD:
if op == EVM.SMOD:
if is_bv_value(w1) and is_bv_value(w2):
return SRem(w1, w2) # bvsrem # vs: w1 % w2 (bvsmod w1 w2)
else:
return f_smod(w1, w2)
elif op == EVM.EXP:

# TODO: if is_bv_value(w2):

return f_smod(w1, w2)

if op == EVM.EXP:
if is_bv_value(w1) and is_bv_value(w2):
i1: int = int(str(w1)) # must be concrete
i2: int = int(str(w2)) # must be concrete
return con(i1**i2)
elif is_bv_value(w2):

if is_bv_value(w2):
i2: int = int(str(w2))
if i2 == 0:
return con(1)
elif i2 == 1:

if i2 == 1:
return w1
elif i2 <= self.options.get("expByConst"):

if i2 <= self.options.get("expByConst"):
exp = w1
for _ in range(i2 - 1):
exp = exp * w1
return exp
else:
return f_exp(w1, w2)
else:
return f_exp(w1, w2)
else:
raise ValueError(op)

return f_exp(w1, w2)

raise ValueError(op)

def arith2(self, ex: Exec, op: int, w1: Word, w2: Word, w3: Word) -> Word:
w1 = b2i(w1)
Expand Down
6 changes: 3 additions & 3 deletions tests/test/Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ contract MathTest {
}
}

// NOTE: currently timeout when --smt-div is enabled; producing invalid counterexamples when --smt-div is not given
function prove_deposit(uint a, uint A1, uint S1) public pure {
/// @custom:halmos --solver-timeout-assertion 10000
function check_deposit(uint a, uint A1, uint S1) public pure {
uint s = (a * S1) / A1;

uint A2 = A1 + a;
Expand All @@ -21,7 +21,7 @@ contract MathTest {
assert(A1 * S2 <= A2 * S1); // no counterexample
}

/// @custom:halmos --smt-div --solver-timeout-assertion=0
/// @custom:halmos --solver-timeout-assertion 0
function check_mint(uint s, uint A1, uint S1) public pure {
uint a = (s * A1) / S1;

Expand Down
1 change: 0 additions & 1 deletion tests/test/SignedDiv.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ abstract contract TestMulWad is Test {

function createWadMul() internal virtual returns (WadMul);

/// @custom:halmos --smt-div
function check_wadMul_solEquivalent(int256 x, int256 y) external {
bytes memory encodedCall = abi.encodeWithSelector(WadMul.wadMul.selector, x, y);

Expand Down
2 changes: 1 addition & 1 deletion tests/test/Solver.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pragma solidity >=0.8.0 <0.9.0;

// from https://github.com/a16z/halmos/issues/57

/// @custom:halmos --print-potential-counterexample --solver-timeout-assertion 10000
/// @custom:halmos --solver-timeout-assertion 10000
contract SolverTest {

function foo(uint x) public pure returns (uint) {
Expand Down

0 comments on commit 4cb6bad

Please sign in to comment.