diff --git a/src/ethereum_test_tools/tests/test_vm.py b/src/ethereum_test_tools/tests/test_vm.py index 6b0a60ec05..0ce958e69f 100644 --- a/src/ethereum_test_tools/tests/test_vm.py +++ b/src/ethereum_test_tools/tests/test_vm.py @@ -382,6 +382,23 @@ def test_macros(): pytest.param( Op.POP(Op.CALL(1, 2, 3, 4, 5, 6, 7)), 0, 0, 7, 0, id="POP(CALL(1, 2, 3, 4, 5, 6, 7))" ), + pytest.param(Op.ADD + Op.RETF, 2, 1, 2, 2, id="Op.ADD + Op.RETF"), + pytest.param( + (Op.ADD + Op.RETF).with_min_stack_height(2), + 2, + 1, + 2, + 2, + id="(Op.ADD + Op.RETF).with_min_stack_height(2)", + ), + pytest.param( + (Op.ADD + Op.RETF).with_min_stack_height(3), + 2, + 1, + 3, + 3, + id="(Op.ADD + Op.RETF).with_min_stack_height(3)", + ), ], ) def test_bytecode_properties( diff --git a/src/ethereum_test_tools/vm/bytecode.py b/src/ethereum_test_tools/vm/bytecode.py index a40678dd66..31f6f32bc5 100644 --- a/src/ethereum_test_tools/vm/bytecode.py +++ b/src/ethereum_test_tools/vm/bytecode.py @@ -137,15 +137,15 @@ def __add__(self, other: "Bytecode | int | None") -> "Bytecode": # Edge case for sum() function return self assert isinstance(other, Bytecode), "Can only concatenate Bytecode instances" - # Figure out the stack height after executing the two opcodes. + # Figure out the stack height after executing the two codes. a_pop, a_push = self.popped_stack_items, self.pushed_stack_items a_min, a_max = self.min_stack_height, self.max_stack_height b_pop, b_push = other.popped_stack_items, other.pushed_stack_items b_min, b_max = other.min_stack_height, other.max_stack_height a_out = a_min - a_pop + a_push - c_pop = max(0, a_pop + (b_pop - a_push)) - c_push = max(0, a_push + b_push - b_pop) + c_pop = a_pop + (b_pop - a_push) if b_pop > a_push else a_pop + c_push = a_push + b_push - b_pop if a_push + b_push > b_pop else b_push c_min = a_min if a_out >= b_min else (b_min - a_out) + a_min c_max = max(a_max + max(0, b_min - a_out), b_max + max(0, a_out - b_min)) @@ -180,6 +180,22 @@ def __mul__(self, other: int) -> "Bytecode": output += self return output + def with_min_stack_height(self, min_stack_height: int) -> "Bytecode": + """ + Set the minimum stack height required by the opcode. + """ + assert min_stack_height >= 0 + return ( + Bytecode( + b"", + popped_stack_items=0, + pushed_stack_items=0, + min_stack_height=min_stack_height, + max_stack_height=min_stack_height, + ) + + self + ) + def hex(self) -> str: """ Return the hexadecimal representation of the opcode byte representation.