From 4cb6badf3484aed7a0ef2ce9aa5ae83a798de063 Mon Sep 17 00:00:00 2001 From: Daejun Park Date: Tue, 10 Oct 2023 17:24:28 -0700 Subject: [PATCH] feat: smt solving refinement --- examples/simple/test/Vault.t.sol | 5 +- src/halmos/__main__.py | 46 ++--------- src/halmos/parser.py | 22 ------ src/halmos/sevm.py | 127 ++++++++++++------------------- tests/test/Math.t.sol | 6 +- tests/test/SignedDiv.t.sol | 1 - tests/test/Solver.t.sol | 2 +- 7 files changed, 59 insertions(+), 150 deletions(-) diff --git a/examples/simple/test/Vault.t.sol b/examples/simple/test/Vault.t.sol index fddbae8e..3d674e20 100644 --- a/examples/simple/test/Vault.t.sol +++ b/examples/simple/test/Vault.t.sol @@ -26,8 +26,8 @@ contract VaultTest is SymTest { vault.setTotalShares(svm.createUint256("S1")); } - // NOTE: currently timeout when --smt-div is enabled, while producing invalid counterexamples when --smt-div is not given - function prove_deposit(uint assets) public { + /// @custom:halmos --solver-timeout-assertion 10000 + function check_deposit(uint assets) public { uint A1 = vault.totalAssets(); uint S1 = vault.totalShares(); @@ -40,7 +40,6 @@ contract VaultTest is SymTest { assert(A1 * S2 <= A2 * S1); // no counterexample } - /// @custom:halmos --smt-div function check_mint(uint shares) public { uint A1 = vault.totalAssets(); uint S1 = vault.totalShares(); diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index f3117bce..d1e68bbc 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -689,17 +689,12 @@ def run( if is_valid: print(red(f"Counterexample: {render_model(model)}")) counterexamples.append(model) - elif args.print_potential_counterexample: + else: warn( COUNTEREXAMPLE_INVALID, f"Counterexample (potentially invalid): {render_model(model)}", ) counterexamples.append(model) - else: - warn( - COUNTEREXAMPLE_INVALID, - f"Counterexample (potentially invalid): (not displayed, use --print-potential-counterexample)", - ) else: warn(COUNTEREXAMPLE_UNKNOWN, f"Counterexample: {result}") @@ -958,11 +953,6 @@ def solve(query: str, args: Namespace) -> Tuple[CheckSatResult, Model]: def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext: args, idx, sexpr = fn_args.args, fn_args.idx, fn_args.sexpr res, model = solve(sexpr, args) -# solver = SolverFor("QF_AUFBV", ctx=Context()) -# solver.set(timeout=args.solver_timeout_assertion) -# solver.from_string(sexpr) -# res = solver.check() -# model = solver.model() if res == sat else None if res == sat and not is_model_valid(model): res, model = solve(refine(sexpr), args) @@ -980,37 +970,24 @@ def refine(query: str) -> str: # replace uninterpreted abstraction with actual symbols for assertion solving # TODO: replace `(evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))` # as bvudiv is undefined when y = 0; also similarly for evm_bvurem - return re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query) + query = re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query) + return re.sub(r"\(\s*declare-fun\s+evm_(bv[a-z]+)(_[0-9]+)?\b", r"(declare-fun dummy_\1\2", query) def gen_model(args: Namespace, idx: int, ex: Exec) -> ModelWithContext: if args.verbose >= 1: print(f"Checking path condition (path id: {idx+1})") - model = None - ex.solver.set(timeout=args.solver_timeout_assertion) res = ex.solver.check() - if res == sat: - model = ex.solver.model() - - if res == unknown and args.solver_fresh: - if args.verbose >= 1: - print(f" Checking again with a fresh solver") - res, model = solve(ex.solver.to_smt2(), args) + model = ex.solver.model() if res == sat else None if res == sat and not is_model_valid(model): if args.verbose >= 1: print(f" Checking again with refinement") res, model = solve(refine(ex.solver.to_smt2()), args) - # sol2 = SolverFor("QF_AUFBV", ctx=Context()) - # sol2.set(timeout=args.solver_timeout_assertion) - # sol2.from_string(refine(ex.solver.to_smt2())) - # res = sol2.check() - # if res == sat: - # model = sol2.model() - - if is_unknown(res, model) and args.solver_subprocess: + + if args.solver_subprocess and is_unknown(res, model): if args.verbose >= 1: print(f" Checking again in an external process") fname = f"/tmp/{uuid.uuid4().hex}.smt2" @@ -1071,10 +1048,6 @@ def package_result( def is_model_valid(model: AnyModel) -> bool: for decl in model: - if str(decl) == "evm_bvudiv" and str(model[decl]) == "[else -> bvudiv_i(Var(0x0), Var(0x1))]": - continue - if str(decl) == "evm_bvsdiv" and str(model[decl]) == "[else -> bvsdiv_i(Var(0x0), Var(0x1))]": - continue if str(decl).startswith("evm_"): return False return True @@ -1103,13 +1076,6 @@ def mk_options(args: Namespace) -> Dict: "verbose": args.verbose, "debug": args.debug, "log": args.log, - "add": not args.no_smt_add, - "sub": not args.no_smt_sub, - "mul": not args.no_smt_mul, - "div": args.smt_div, - "mod": args.smt_mod, - "divByConst": args.smt_div_by_const, - "modByConst": args.smt_mod_by_const, "expByConst": args.smt_exp_by_const, "timeout": args.solver_timeout_branching, "sym_jump": args.symbolic_jump, diff --git a/src/halmos/parser.py b/src/halmos/parser.py index b3720a23..768d2823 100644 --- a/src/halmos/parser.py +++ b/src/halmos/parser.py @@ -166,23 +166,6 @@ def mk_arg_parser() -> argparse.ArgumentParser: # smt solver options group_solver = parser.add_argument_group("Solver options") - group_solver.add_argument( - "--no-smt-add", action="store_true", help="do not interpret `+`" - ) - group_solver.add_argument( - "--no-smt-sub", action="store_true", help="do not interpret `-`" - ) - group_solver.add_argument( - "--no-smt-mul", action="store_true", help="do not interpret `*`" - ) - group_solver.add_argument("--smt-div", action="store_true", help="interpret `/`") - group_solver.add_argument("--smt-mod", action="store_true", help="interpret `mod`") - group_solver.add_argument( - "--smt-div-by-const", action="store_true", help="interpret division by constant" - ) - group_solver.add_argument( - "--smt-mod-by-const", action="store_true", help="interpret constant modulo" - ) group_solver.add_argument( "--smt-exp-by-const", metavar="N", @@ -248,10 +231,5 @@ def mk_arg_parser() -> argparse.ArgumentParser: group_experimental.add_argument( "--symbolic-jump", action="store_true", help="support symbolic jump destination" ) - group_experimental.add_argument( - "--print-potential-counterexample", - action="store_true", - help="print potentially invalid counterexamples", - ) return parser diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 4db80c0d..84ff6754 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -1516,133 +1516,100 @@ def mk_mod(self, ex: Exec, x: Any, y: Any) -> Any: def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: w1 = b2i(w1) w2 = b2i(w2) + if op == EVM.ADD: - if self.options.get("add"): - return w1 + w2 - if is_bv_value(w1) and is_bv_value(w2): - return w1 + w2 - else: - return self.mk_add(w1, w2) - elif op == EVM.SUB: - if self.options.get("sub"): - return w1 - w2 - if is_bv_value(w1) and is_bv_value(w2): - return w1 - w2 - else: - return f_sub(w1, w2) - elif op == EVM.MUL: - if self.options.get("mul"): - return w1 * w2 - if is_bv_value(w1) and is_bv_value(w2): - return w1 * w2 - elif is_bv_value(w1): - i1: int = int(str(w1)) # must be concrete - if i1 == 0: - return w1 - elif is_power_of_two(i1): - return w2 << int(math.log(i1, 2)) - else: - return self.mk_mul(w1, w2) - elif is_bv_value(w2): - i2: int = int(str(w2)) # must be concrete - if i2 == 0: - return w2 - elif is_power_of_two(i2): - return w1 << int(math.log(i2, 2)) - else: - return self.mk_mul(w1, w2) - else: - return self.mk_mul(w1, w2) - elif op == EVM.DIV: + return w1 + w2 + + if op == EVM.SUB: + return w1 - w2 + + if op == EVM.MUL: + return w1 * w2 + + if op == EVM.DIV: div_for_overflow_check = self.div_xy_y(w1, w2) if div_for_overflow_check is not None: # xy/x or xy/y return div_for_overflow_check - # if self.options.get("div"): - # return UDiv(w1, w2) # unsigned div (bvudiv) + if is_bv_value(w1) and is_bv_value(w2): - return UDiv(w1, w2) - elif is_bv_value(w2): + return UDiv(w1, w2) # unsigned div (bvudiv) + + if is_bv_value(w2): # concrete denominator case i2: int = w2.as_long() if i2 == 0: return w2 - elif i2 == 1: + + if i2 == 1: return w1 - elif is_power_of_two(i2): + + if is_power_of_two(i2): return LShR(w1, int(math.log(i2, 2))) - # elif self.options.get("divByConst"): - # return UDiv(w1, w2) - else: - return self.mk_div(ex, w1, w2) - else: - return self.mk_div(ex, w1, w2) - elif op == EVM.MOD: - # if self.options.get("mod"): - # return URem(w1, w2) + + return self.mk_div(ex, w1, w2) + + if op == EVM.MOD: if is_bv_value(w1) and is_bv_value(w2): return URem(w1, w2) # bvurem - elif is_bv_value(w2): + + if is_bv_value(w2): i2: int = int(str(w2)) if i2 == 0 or i2 == 1: return con(0, w2.size()) - elif is_power_of_two(i2): + + if is_power_of_two(i2): bitsize = int(math.log(i2, 2)) return ZeroExt(w2.size() - bitsize, Extract(bitsize - 1, 0, w1)) - # elif self.options.get("modByConst"): - # return URem(w1, w2) - else: - return self.mk_mod(ex, w1, w2) - else: - return self.mk_mod(ex, w1, w2) - elif op == EVM.SDIV: - # if self.options.get("div"): - # return w1 / w2 # bvsdiv + + return self.mk_mod(ex, w1, w2) + + if op == EVM.SDIV: if is_bv_value(w1) and is_bv_value(w2): return w1 / w2 # bvsdiv if is_bv_value(w2): # concrete denominator case i2: int = w2.as_long() - if i2 == 0: return w2 # div by 0 is 0 if i2 == 1: return w1 # div by 1 is identity - # if self.options.get("divByConst"): - # return w1 / w2 # bvsdiv - # fall back to uninterpreted function :( return f_sdiv(w1, w2) - elif op == EVM.SMOD: + if op == EVM.SMOD: if is_bv_value(w1) and is_bv_value(w2): return SRem(w1, w2) # bvsrem # vs: w1 % w2 (bvsmod w1 w2) - else: - return f_smod(w1, w2) - elif op == EVM.EXP: + + # TODO: if is_bv_value(w2): + + return f_smod(w1, w2) + + if op == EVM.EXP: if is_bv_value(w1) and is_bv_value(w2): i1: int = int(str(w1)) # must be concrete i2: int = int(str(w2)) # must be concrete return con(i1**i2) - elif is_bv_value(w2): + + if is_bv_value(w2): i2: int = int(str(w2)) if i2 == 0: return con(1) - elif i2 == 1: + + if i2 == 1: return w1 - elif i2 <= self.options.get("expByConst"): + + if i2 <= self.options.get("expByConst"): exp = w1 for _ in range(i2 - 1): exp = exp * w1 return exp - else: - return f_exp(w1, w2) - else: - return f_exp(w1, w2) - else: - raise ValueError(op) + + return f_exp(w1, w2) + + raise ValueError(op) def arith2(self, ex: Exec, op: int, w1: Word, w2: Word, w3: Word) -> Word: w1 = b2i(w1) diff --git a/tests/test/Math.t.sol b/tests/test/Math.t.sol index 8aa2d348..a56f3255 100644 --- a/tests/test/Math.t.sol +++ b/tests/test/Math.t.sol @@ -10,8 +10,8 @@ contract MathTest { } } - // NOTE: currently timeout when --smt-div is enabled; producing invalid counterexamples when --smt-div is not given - function prove_deposit(uint a, uint A1, uint S1) public pure { + /// @custom:halmos --solver-timeout-assertion 10000 + function check_deposit(uint a, uint A1, uint S1) public pure { uint s = (a * S1) / A1; uint A2 = A1 + a; @@ -21,7 +21,7 @@ contract MathTest { assert(A1 * S2 <= A2 * S1); // no counterexample } - /// @custom:halmos --smt-div --solver-timeout-assertion=0 + /// @custom:halmos --solver-timeout-assertion 0 function check_mint(uint s, uint A1, uint S1) public pure { uint a = (s * A1) / S1; diff --git a/tests/test/SignedDiv.t.sol b/tests/test/SignedDiv.t.sol index 84742735..a8064bdb 100644 --- a/tests/test/SignedDiv.t.sol +++ b/tests/test/SignedDiv.t.sol @@ -74,7 +74,6 @@ abstract contract TestMulWad is Test { function createWadMul() internal virtual returns (WadMul); - /// @custom:halmos --smt-div function check_wadMul_solEquivalent(int256 x, int256 y) external { bytes memory encodedCall = abi.encodeWithSelector(WadMul.wadMul.selector, x, y); diff --git a/tests/test/Solver.t.sol b/tests/test/Solver.t.sol index 04a2402f..df3fdce6 100644 --- a/tests/test/Solver.t.sol +++ b/tests/test/Solver.t.sol @@ -3,7 +3,7 @@ pragma solidity >=0.8.0 <0.9.0; // from https://github.com/a16z/halmos/issues/57 -/// @custom:halmos --print-potential-counterexample --solver-timeout-assertion 10000 +/// @custom:halmos --solver-timeout-assertion 10000 contract SolverTest { function foo(uint x) public pure returns (uint) {