From 29e3930c213a30b932bbd2cdc149aa988d7e369e Mon Sep 17 00:00:00 2001 From: Oba Date: Tue, 23 Jul 2024 22:27:41 +0200 Subject: [PATCH] fix: assert value is 0 at the end of load_bytecode (#1290) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Time spent on this PR: ## Pull request type Please check the type of change your PR introduces: - [x] Bugfix - [ ] Feature - [ ] Code style update (formatting, renaming) - [ ] Refactoring (no functional changes, no api changes) - [ ] Build related changes - [ ] Documentation content changes - [ ] Other (please describe): ## What is the current behavior? Do not assert the value is 0 at the end of load_bytecode leading to prover being able to Resolves #1279, resolves #1293 ## What is the new behavior? assert value is 0 at the end of load_bytecode - - - This change is [Reviewable](https://reviewable.io/reviews/kkrt-labs/kakarot/1290) --------- Co-authored-by: Clément Walter --- src/kakarot/accounts/library.cairo | 6 + tests/conftest.py | 2 +- tests/fixtures/starknet.py | 41 ++++--- .../kakarot/accounts/test_account_contract.py | 114 +++++++++++------- tests/utils/hints.py | 24 ++++ 5 files changed, 123 insertions(+), 64 deletions(-) diff --git a/src/kakarot/accounts/library.cairo b/src/kakarot/accounts/library.cairo index 6a874ce27..99c6b944e 100644 --- a/src/kakarot/accounts/library.cairo +++ b/src/kakarot/accounts/library.cairo @@ -652,11 +652,17 @@ namespace Internals { jmp cond if remaining_bytes != 0; + with_attr error_message("Value is not empty") { + assert value = 0; + } let bytecode = cast([fp], felt*); return (bytecode=bytecode); cond: jmp body if count != 0; + with_attr error_message("Value is not empty") { + assert value = 0; + } jmp read; } } diff --git a/tests/conftest.py b/tests/conftest.py index 3d8c04274..b5e9b6e4b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,7 @@ def seed(request): pytest_plugins = ["tests.fixtures.starknet"] settings.register_profile("ci", deadline=None, max_examples=1000) -settings.register_profile("dev", max_examples=10) +settings.register_profile("dev", deadline=None, max_examples=10) settings.register_profile("debug", max_examples=10, verbosity=Verbosity.verbose) settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default")) logger.info(f"Using Hypothesis profile: {os.getenv('HYPOTHESIS_PROFILE', 'default')}") diff --git a/tests/fixtures/starknet.py b/tests/fixtures/starknet.py index 86e5fa29d..ccbde8251 100644 --- a/tests/fixtures/starknet.py +++ b/tests/fixtures/starknet.py @@ -72,15 +72,7 @@ def cairo_compile(path): @pytest.fixture(scope="module") -def cairo_run(request) -> list: - """ - Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs. - Returns the output of the cairo program put in the output memory segment. - - When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped. - - Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates like the addition of the output segment. - """ +def cairo_program(request) -> list: cairo_file = Path(request.node.fspath).with_suffix(".cairo") if not cairo_file.exists(): raise ValueError(f"Missing cairo file: {cairo_file}") @@ -89,23 +81,36 @@ def cairo_run(request) -> list: program = cairo_compile(cairo_file) stop = perf_counter() logger.info(f"{cairo_file} compiled in {stop - start:.2f}s") + return program + + +@pytest.fixture(scope="module") +def cairo_run(request, cairo_program) -> list: + """ + Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs. + Returns the output of the cairo program put in the output memory segment. + + When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped. + + Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates like the addition of the output segment. + """ def _factory(entrypoint, **kwargs) -> list: implicit_args = list( - program.identifiers.get_by_full_name( + cairo_program.identifiers.get_by_full_name( ScopedName(path=["__main__", entrypoint, "ImplicitArgs"]) ).members.keys() ) args = list( - program.identifiers.get_by_full_name( + cairo_program.identifiers.get_by_full_name( ScopedName(path=["__main__", entrypoint, "Args"]) ).members.keys() ) - return_data = program.identifiers.get_by_full_name( + return_data = cairo_program.identifiers.get_by_full_name( ScopedName(path=["__main__", entrypoint, "Return"]) ) # Fix builtins runner based on the implicit args since the compiler doesn't find them - program.builtins = [ + cairo_program.builtins = [ builtin # This list is extracted from the builtin runners # Builtins have to be declared in this order @@ -125,7 +130,7 @@ def _factory(entrypoint, **kwargs) -> list: memory = MemoryDict() runner = CairoRunner( - program=program, + program=cairo_program, layout=request.config.getoption("layout"), memory=memory, proof_mode=request.config.getoption("proof_mode"), @@ -165,7 +170,7 @@ def _factory(entrypoint, **kwargs) -> list: runner.execution_public_memory = list(range(len(stack))) runner.initialize_state( - entrypoint=program.identifiers.get_by_full_name( + entrypoint=cairo_program.identifiers.get_by_full_name( ScopedName(path=["__main__", entrypoint]) ).pc, stack=stack, @@ -178,7 +183,7 @@ def _factory(entrypoint, **kwargs) -> list: "syscall_handler": SyscallHandler(), }, static_locals={ - "debug_info": debug_info(program), + "debug_info": debug_info(cairo_program), "serde": serde, "Opcodes": Opcodes, }, @@ -234,7 +239,7 @@ def _factory(entrypoint, **kwargs) -> list: ) if request.config.getoption("profile_cairo"): tracer_data = TracerData( - program=program, + program=cairo_program, memory=runner.relocated_memory, trace=runner.relocated_trace, debug_info=runner.get_relocated_debug_info(), @@ -253,7 +258,7 @@ def _factory(entrypoint, **kwargs) -> list: write_binary_memory( fp, runner.relocated_memory, - math.ceil(program.prime.bit_length() / 8), + math.ceil(cairo_program.prime.bit_length() / 8), ) rc_min, rc_max = runner.get_perm_range_check_limits() diff --git a/tests/src/kakarot/accounts/test_account_contract.py b/tests/src/kakarot/accounts/test_account_contract.py index 750c86ef0..b71bc93dc 100644 --- a/tests/src/kakarot/accounts/test_account_contract.py +++ b/tests/src/kakarot/accounts/test_account_contract.py @@ -6,6 +6,8 @@ import rlp from eth_account.account import Account from eth_utils import keccak +from hypothesis import given, settings +from hypothesis.strategies import binary from starkware.starknet.public.abi import ( get_selector_from_name, get_storage_var_address, @@ -15,6 +17,7 @@ from tests.utils.constants import CHAIN_ID, TRANSACTION_GAS_LIMIT, TRANSACTIONS from tests.utils.errors import cairo_error from tests.utils.helpers import generate_random_private_key, rlp_encode_signed_data +from tests.utils.hints import patch_hint from tests.utils.syscall_handler import SyscallHandler from tests.utils.uint256 import int_to_uint256 @@ -92,7 +95,7 @@ def test_should_write_bytecode(self, cairo_run, bytecode): SyscallHandler.mock_storage.assert_has_calls(calls) class TestBytecode: - @pytest.fixture + def storage(self, bytecode): chunks = wrap(bytecode.hex(), 2 * 31) @@ -105,9 +108,9 @@ def _storage(address): return _storage - def test_should_read_bytecode(self, cairo_run, bytecode, storage): + def test_should_read_bytecode(self, cairo_run, bytecode): with patch.object( - SyscallHandler, "mock_storage", side_effect=storage + SyscallHandler, "mock_storage", side_effect=self.storage(bytecode) ) as mock_storage: output_len, output = cairo_run("test__bytecode") chunk_counts, remainder = divmod(len(bytecode), 31) @@ -116,6 +119,24 @@ def test_should_read_bytecode(self, cairo_run, bytecode, storage): mock_storage.assert_has_calls(calls) assert output[:output_len] == list(bytecode) + @given(bytecode=binary(min_size=1, max_size=400)) + @settings(max_examples=5) + def test_should_raise_when_read_bytecode_zellic_issue_1279( + self, cairo_program, cairo_run, bytecode + ): + with ( + patch_hint( + cairo_program, + "memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'", + "memory[ids.output] = res = (int(ids.value) % PRIME + 1) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'", + ), + patch.object( + SyscallHandler, "mock_storage", side_effect=self.storage(bytecode) + ), + ): + with cairo_error(message="Value is not empty"): + output_len, output = cairo_run("test__bytecode") + class TestNonce: @SyscallHandler.patch("Ownable_owner", 0xDEAD) def test_should_assert_only_owner(self, cairo_run): @@ -367,8 +388,9 @@ def test_should_raise_invalid_signature_for_invalid_chain_id_when_tx_type0_not_p signed.v, ] - with cairo_error(message="Invalid signature."), SyscallHandler.patch( - "Account_evm_address", address + with ( + cairo_error(message="Invalid signature."), + SyscallHandler.patch("Account_evm_address", address), ): cairo_run( "test__execute_from_outside", @@ -403,8 +425,9 @@ def test_should_raise_invalid_chain_id_tx_type_different_from_0( signed.v, ] - with cairo_error(message="Invalid chain id"), SyscallHandler.patch( - "Account_evm_address", address + with ( + cairo_error(message="Invalid chain id"), + SyscallHandler.patch("Account_evm_address", address), ): cairo_run( "test__execute_from_outside", @@ -429,8 +452,9 @@ def test_should_raise_invalid_nonce(self, cairo_run, transaction): encoded_unsigned_tx = rlp_encode_signed_data(transaction) tx_data = list(encoded_unsigned_tx) - with cairo_error(message="Invalid nonce"), SyscallHandler.patch( - "Account_evm_address", address + with ( + cairo_error(message="Invalid nonce"), + SyscallHandler.patch("Account_evm_address", address), ): cairo_run( "test__execute_from_outside", @@ -454,12 +478,10 @@ def test_raise_not_enough_ETH_balance(self, cairo_run, transaction): encoded_unsigned_tx = rlp_encode_signed_data(transaction) tx_data = list(encoded_unsigned_tx) - with cairo_error( - message="Not enough ETH to pay msg.value + max gas fees" - ), SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch( - "Account_nonce", transaction.get("nonce", 0) + with ( + cairo_error(message="Not enough ETH to pay msg.value + max gas fees"), + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)), ): cairo_run( "test__execute_from_outside", @@ -486,12 +508,10 @@ def test_raise_transaction_gas_limit_too_high(self, cairo_run, transaction): encoded_unsigned_tx = rlp_encode_signed_data(transaction) tx_data = list(encoded_unsigned_tx) - with cairo_error( - message="Transaction gas_limit > Block gas_limit" - ), SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch( - "Account_nonce", transaction.get("nonce", 0) + with ( + cairo_error(message="Transaction gas_limit > Block gas_limit"), + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)), ): cairo_run( "test__execute_from_outside", @@ -523,9 +543,11 @@ def test_raise_max_fee_per_gas_too_low(self, cairo_run, transaction): encoded_unsigned_tx = rlp_encode_signed_data(transaction) tx_data = list(encoded_unsigned_tx) - with cairo_error(message="Max fee per gas too low"), SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)): + with ( + cairo_error(message="Max fee per gas too low"), + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)), + ): cairo_run( "test__execute_from_outside", tx_data=tx_data, @@ -565,12 +587,10 @@ def test_raise_max_priority_fee_too_high(self, cairo_run): signed.v, ] - with cairo_error( - message="Max priority fee greater than max fee per gas" - ), SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch( - "Account_nonce", transaction["nonce"] + with ( + cairo_error(message="Max priority fee greater than max fee per gas"), + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction["nonce"]), ): cairo_run( "test__execute_from_outside", @@ -606,15 +626,17 @@ def test_pass_authorized_pre_eip155_transaction(self, cairo_run): int.from_bytes(keccak(encoded_unsigned_tx), "big") ) - with SyscallHandler.patch( - "Account_evm_address", int(ARACHNID_PROXY_DEPLOYER, 16) - ), SyscallHandler.patch( - "Account_authorized_message_hashes", - tx_hash_low, - tx_hash_high, - 0x1, - ), SyscallHandler.patch( - "Account_nonce", 0 + with ( + SyscallHandler.patch( + "Account_evm_address", int(ARACHNID_PROXY_DEPLOYER, 16) + ), + SyscallHandler.patch( + "Account_authorized_message_hashes", + tx_hash_low, + tx_hash_high, + 0x1, + ), + SyscallHandler.patch("Account_nonce", 0), ): output_len, output = cairo_run( "test__execute_from_outside", @@ -661,9 +683,10 @@ def test_pass_all_transactions_types(self, cairo_run, seed, transaction): encoded_unsigned_tx = rlp_encode_signed_data(transaction) tx_data = list(encoded_unsigned_tx) - with SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)): + with ( + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)), + ): output_len, output = cairo_run( "test__execute_from_outside", tx_data=tx_data, @@ -713,9 +736,10 @@ def test_should_pass_all_data_len(self, cairo_run, bytecode): signed.v, ] - with SyscallHandler.patch( - "Account_evm_address", address - ), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)): + with ( + SyscallHandler.patch("Account_evm_address", address), + SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)), + ): output_len, output = cairo_run( "test__execute_from_outside", tx_data=tx_data, diff --git a/tests/utils/hints.py b/tests/utils/hints.py index eb7c852f6..940066394 100644 --- a/tests/utils/hints.py +++ b/tests/utils/hints.py @@ -1,6 +1,9 @@ from collections import defaultdict +from contextlib import contextmanager +from unittest.mock import patch from starkware.cairo.common.dict import DictTracker +from starkware.cairo.lang.compiler.program import CairoHint def debug_info(program): @@ -45,3 +48,24 @@ def new_default_dict( current_ptr=base, ) return base + + +@contextmanager +def patch_hint(program, hint, new_hint): + patched_hints = { + k: [ + ( + hint_ + if hint_.code != hint + else CairoHint( + accessible_scopes=hint_.accessible_scopes, + flow_tracking_data=hint_.flow_tracking_data, + code=new_hint, + ) + ) + for hint_ in v + ] + for k, v in program.hints.items() + } + with patch.object(program, "hints", new=patched_hints): + yield