diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 93044120..a9758424 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -69,6 +69,10 @@ 264: Function("f_evm_bvurem_264", BitVecSort264, BitVecSort264, BitVecSort264), 512: Function("f_evm_bvurem_512", BitVecSort512, BitVecSort512, BitVecSort512), } +f_mul = { + 256: Function("f_evm_bvmul", BitVecSort256, BitVecSort256, BitVecSort256), + 512: Function("f_evm_bvmul_512", BitVecSort512, BitVecSort512, BitVecSort512), +} f_sdiv = Function("f_evm_bvsdiv", BitVecSort256, BitVecSort256, BitVecSort256) f_smod = Function("f_evm_bvsrem", BitVecSort256, BitVecSort256, BitVecSort256) f_exp = Function("f_evm_exp", BitVecSort256, BitVecSort256, BitVecSort256) @@ -1337,6 +1341,10 @@ def mk_mod(self, ex: Exec, x: Any, y: Any) -> Any: # ex.path.append(Or(y == con(0), ULT(term, y))) # (x % y) < y if y != 0 return term + def mk_mul(self, ex: Exec, x: Any, y: Any) -> Any: + term = f_mul[x.size()](x, y) + return term + def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: w1 = b2i(w1) w2 = b2i(w2) @@ -1348,7 +1356,35 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: return w1 - w2 if op == EVM.MUL: - return w1 * w2 + if is_bv_value(w1) and is_bv_value(w2): + return w1 * w2 + + if is_bv_value(w1): + i1: int = w1.as_long() + if i1 == 0: + return w1 + + if i1 == 1: + return w2 + + if is_power_of_two(i1): + return w2 << int(math.log(i1, 2)) + + if is_bv_value(w2): + i2: int = w2.as_long() + if i2 == 0: + return w2 + + if i2 == 1: + return w1 + + if is_power_of_two(i2): + return w1 << int(math.log(i2, 2)) + + if is_bv_value(w1) or is_bv_value(w2): + return w1 * w2 + + return self.mk_mul(ex, w1, w2) if op == EVM.DIV: div_for_overflow_check = self.div_xy_y(w1, w2) @@ -1440,7 +1476,7 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word: if i2 <= self.options.smt_exp_by_const: exp = w1 for _ in range(i2 - 1): - exp = exp * w1 + exp = self.arith(ex, EVM.MUL, exp, w1) return exp return f_exp(w1, w2) diff --git a/tests/test_sevm.py b/tests/test_sevm.py index 2aa17dac..d8111ed3 100644 --- a/tests/test_sevm.py +++ b/tests/test_sevm.py @@ -9,6 +9,7 @@ from halmos.sevm import ( con, Contract, + f_mul, f_div, f_sdiv, f_mod, @@ -146,7 +147,7 @@ def byte_of(i, x): [ (o(EVM.PUSH0), [], con(0)), (o(EVM.ADD), [x, y], x + y), - (o(EVM.MUL), [x, y], x * y), + (o(EVM.MUL), [x, y], f_mul[x.size()](x, y)), (o(EVM.SUB), [x, y], x - y), (o(EVM.DIV), [x, y], f_div(x, y)), (o(EVM.DIV), [con(5), con(3)], con(1)), @@ -199,13 +200,13 @@ def byte_of(i, x): ( o(EVM.MULMOD), [x, y, con(2**3)], - ZeroExt(253, Extract(2, 0, ZeroExt(256, x) * ZeroExt(256, y))), + ZeroExt(253, Extract(2, 0, f_mul[512](ZeroExt(256, x), ZeroExt(256, y)))), ), ( o(EVM.MULMOD), [x, y, z], Extract( - 255, 0, f_mod[512](ZeroExt(256, x) * ZeroExt(256, y), ZeroExt(256, z)) + 255, 0, f_mod[512](f_mul[512](ZeroExt(256, x), ZeroExt(256, y)), ZeroExt(256, z)) ), ), (o(EVM.MULMOD), [con(10), con(10), con(8)], con(4)), @@ -221,7 +222,7 @@ def byte_of(i, x): (o(EVM.EXP), [x, y], f_exp(x, y)), (o(EVM.EXP), [x, con(0)], con(1)), (o(EVM.EXP), [x, con(1)], x), - (o(EVM.EXP), [x, con(2)], x * x), + (o(EVM.EXP), [x, con(2)], f_mul[x.size()](x, x)), (o(EVM.SIGNEXTEND), [con(0), y], SignExt(248, Extract(7, 0, y))), (o(EVM.SIGNEXTEND), [con(1), y], SignExt(240, Extract(15, 0, y))), (o(EVM.SIGNEXTEND), [con(30), y], SignExt(8, Extract(247, 0, y))),