diff --git a/README.md b/README.md index 7af8d7db..e6c9ef4e 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ We recommend starting from [Setting up the environment](https://cairo-lang.org/d # Installation instructions You should be able to download the python package zip file directly from -[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.13.2) +[github](https://github.com/starkware-libs/cairo-lang/releases/tag/v0.13.3) and install it using ``pip``. See [Setting up the environment](https://cairo-lang.org/docs/quickstart.html). @@ -54,7 +54,7 @@ Once the docker image is built, you can fetch the python package zip file using: ```bash > container_id=$(docker create cairo) -> docker cp ${container_id}:/app/cairo-lang-0.13.2.zip . +> docker cp ${container_id}:/app/cairo-lang-0.13.3.zip . > docker rm -v ${container_id} ``` diff --git a/src/starkware/cairo/common/BUILD b/src/starkware/cairo/common/BUILD index 299c9320..b3b0be46 100644 --- a/src/starkware/cairo/common/BUILD +++ b/src/starkware/cairo/common/BUILD @@ -9,6 +9,7 @@ cairo_library( "bitwise.cairo", "bool.cairo", "cairo_builtins.cairo", + "copy_indices.cairo", "default_dict.cairo", "dict.cairo", "dict_access.cairo", @@ -22,6 +23,7 @@ cairo_library( "invoke.cairo", "keccak.cairo", "keccak_state.cairo", + "log2_ceil.cairo", "math.cairo", "math_cmp.cairo", "memcpy.cairo", diff --git a/src/starkware/cairo/common/copy_indices.cairo b/src/starkware/cairo/common/copy_indices.cairo new file mode 100644 index 00000000..6a0e9810 --- /dev/null +++ b/src/starkware/cairo/common/copy_indices.cairo @@ -0,0 +1,36 @@ +// Copies len field elements from src to dst at the given indices. +// I.e., dst = [src[i] for i in indices]. +func copy_indices(dst: felt*, src: felt*, indices: felt*, len: felt) { + struct LoopFrame { + dst: felt*, + indices: felt*, + } + + if (len == 0) { + return (); + } + + %{ vm_enter_scope({'n': ids.len}) %} + tempvar frame = LoopFrame(dst=dst, indices=indices); + + loop: + let frame = [cast(ap - LoopFrame.SIZE, LoopFrame*)]; + assert [frame.dst] = src[[frame.indices]]; + + let continue_copying = [ap]; + // Reserve space for continue_copying. + let next_frame = cast(ap + 1, LoopFrame*); + next_frame.dst = frame.dst + 1, ap++; + next_frame.indices = frame.indices + 1, ap++; + %{ + n -= 1 + ids.continue_copying = 1 if n > 0 else 0 + %} + static_assert next_frame + LoopFrame.SIZE == ap + 1; + jmp loop if continue_copying != 0, ap++; + // Assert that the loop executed len times. + len = cast(next_frame.indices, felt) - cast(indices, felt); + + %{ vm_exit_scope() %} + return (); +} diff --git a/src/starkware/cairo/common/log2_ceil.cairo b/src/starkware/cairo/common/log2_ceil.cairo new file mode 100644 index 00000000..7f17606b --- /dev/null +++ b/src/starkware/cairo/common/log2_ceil.cairo @@ -0,0 +1,28 @@ +from starkware.cairo.common.math import assert_in_range, assert_not_zero +from starkware.cairo.common.pow import pow + +// Returns the ceil value of the log2 of the given value. +// Enforces that 1 <= value <= RANGE_CHECK_BOUND. +func log2_ceil{range_check_ptr}(value: felt) -> felt { + alloc_locals; + assert_not_zero(value); + if (value == 1) { + return 0; + } + + local res; + %{ + from starkware.python.math_utils import log2_ceil + ids.res = log2_ceil(ids.value) + %} + + // Verify that 1 <= 2**(res - 1) < value <= 2**res <= RANGE_CHECK_BOUND. + // The RANGE_CHECK_BOUND bound is required by the `assert_in_range` function. 
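+ // E.g., for value = 6 the hint gives res = 3, and the checks below verify
+ // 2**2 = 4 < 6 <= 8 = 2**3.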
+ assert_in_range(res, 1, 128 + 1); + let (lower_bound) = pow(2, res - 1); + let min = lower_bound + 1; + let max = 2 * lower_bound; + assert_in_range(value, min, max + 1); + + return res; +} diff --git a/src/starkware/cairo/lang/BUILD b/src/starkware/cairo/lang/BUILD index 47d9027b..80a5aa17 100644 --- a/src/starkware/cairo/lang/BUILD +++ b/src/starkware/cairo/lang/BUILD @@ -10,6 +10,7 @@ py_library( ], data = [ "@" + CAIRO_COMPILER_ARCHIVE, + "//src/starkware/starknet/core/os:starknet_os_program_cairo_lib", ], visibility = ["//visibility:public"], deps = [ diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 9beb74d4..288adf53 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.13.2 +0.13.3 diff --git a/src/starkware/cairo/lang/ide/vscode-cairo/package.json b/src/starkware/cairo/lang/ide/vscode-cairo/package.json index d165ea5a..eb29a186 100644 --- a/src/starkware/cairo/lang/ide/vscode-cairo/package.json +++ b/src/starkware/cairo/lang/ide/vscode-cairo/package.json @@ -2,7 +2,7 @@ "name": "cairo", "displayName": "CairoZero", "description": "Support Cairo syntax", - "version": "0.13.2", + "version": "0.13.3", "engines": { "vscode": "^1.30.0" }, diff --git a/src/starkware/starknet/business_logic/execution/execute_entry_point.py b/src/starkware/starknet/business_logic/execution/execute_entry_point.py index c377a4f8..672b15aa 100644 --- a/src/starkware/starknet/business_logic/execution/execute_entry_point.py +++ b/src/starkware/starknet/business_logic/execution/execute_entry_point.py @@ -40,7 +40,7 @@ from starkware.starknet.core.os.deprecated_syscall_handler import DeprecatedBlSyscallHandler from starkware.starknet.core.os.syscall_handler import BusinessLogicSyscallHandler from starkware.starknet.definitions import fields -from starkware.starknet.definitions.constants import GasCost +from starkware.starknet.definitions.constants import VERSIONED_CONSTANTS, GasCost from starkware.starknet.definitions.error_codes import StarknetErrorCode from starkware.starknet.definitions.execution_mode import ExecutionMode from starkware.starknet.definitions.general_config import ( @@ -133,7 +133,7 @@ def sync_execute_for_testing( state=state, resources_manager=ExecutionResourcesManager.empty(), tx_execution_context=TransactionExecutionContext.create_for_testing( - n_steps=general_config.invoke_tx_max_n_steps + n_steps=VERSIONED_CONSTANTS.invoke_tx_max_n_steps, ), general_config=general_config, support_reverted=support_reverted, @@ -146,10 +146,13 @@ async def execute_for_testing( resources_manager: Optional[ExecutionResourcesManager] = None, tx_execution_context: Optional[TransactionExecutionContext] = None, execution_mode: ExecutionMode = ExecutionMode.EXECUTE, + invoke_tx_max_n_steps: Optional[int] = None, ) -> CallInfo: + if invoke_tx_max_n_steps is None: + invoke_tx_max_n_steps = VERSIONED_CONSTANTS.invoke_tx_max_n_steps if tx_execution_context is None: tx_execution_context = TransactionExecutionContext.create_for_testing( - n_steps=general_config.invoke_tx_max_n_steps, execution_mode=execution_mode + n_steps=invoke_tx_max_n_steps, execution_mode=execution_mode ) if resources_manager is None: diff --git a/src/starkware/starknet/business_logic/fact_state/BUILD b/src/starkware/starknet/business_logic/fact_state/BUILD index 729ce2c6..451d6f7d 100644 --- a/src/starkware/starknet/business_logic/fact_state/BUILD +++ b/src/starkware/starknet/business_logic/fact_state/BUILD @@ -15,6 +15,7 @@ py_library( 
"//src/starkware/cairo/lang/vm:cairo_vm_crypto_lib", "//src/starkware/python:starkware_python_utils_lib", "//src/starkware/starknet/business_logic/state:starknet_business_logic_state_lib", + "//src/starkware/starknet/core/aggregator:cairo_aggregator_lib", "//src/starkware/starknet/definitions:starknet_definitions_lib", "//src/starkware/starknet/definitions:starknet_general_config_lib", "//src/starkware/starkware_utils:starkware_config_utils_lib", diff --git a/src/starkware/starknet/business_logic/fact_state/state.py b/src/starkware/starknet/business_logic/fact_state/state.py index 9d01c0e6..e6e1d908 100644 --- a/src/starkware/starknet/business_logic/fact_state/state.py +++ b/src/starkware/starknet/business_logic/fact_state/state.py @@ -35,6 +35,7 @@ from starkware.starknet.business_logic.state.state import CachedState from starkware.starknet.business_logic.state.state_api import StateReader from starkware.starknet.business_logic.state.state_api_objects import BlockInfo +from starkware.starknet.core.aggregator.output_parser import ContractChanges, OsStateDiff from starkware.starknet.definitions import constants, fields from starkware.starknet.definitions.data_availability_mode import DataAvailabilityMode from starkware.starknet.definitions.general_config import StarknetGeneralConfig @@ -479,14 +480,6 @@ def backward_compatibility_before_data_availability_modes( return data - async def write(self, storage: Storage) -> bytes: - """ - Writes an entry containing the state diff to the storage under its hash, as a fact object. - """ - hash_value = self.calculate_hash() - await self.set(storage=storage, suffix=hash_value) - return hash_value - @classmethod def empty(cls, block_info: BlockInfo): """ @@ -500,6 +493,32 @@ def empty(cls, block_info: BlockInfo): block_info=block_info, ) + async def write(self, storage: Storage, batch_id: int) -> bytes: + """ + Writes the state diff to the storage under the given batch_id. + Returns the key suffix (serialized batch_id). + """ + suffix = str(batch_id).encode("ascii") + await self.set(storage=storage, suffix=suffix) + return suffix + + @classmethod + def create_l1_da_mode( + cls, + address_to_class_hash: Mapping[int, int], + nonces: Mapping[int, int], + storage_updates: Mapping[int, Mapping[int, int]], + declared_classes: Mapping[int, int], + block_info: BlockInfo, + ) -> "StateDiff": + return cls( + address_to_class_hash=address_to_class_hash, + nonces={DataAvailabilityMode.L1: nonces}, + storage_updates={DataAvailabilityMode.L1: storage_updates}, + declared_classes=declared_classes, + block_info=block_info, + ) + @classmethod def from_cached_state(cls, cached_state: CachedState) -> "StateDiff": state_cache = cached_state.cache @@ -582,6 +601,91 @@ async def commit( block_info=self.block_info, ) + def get_os_encoded_length(self) -> int: + """ + Returns the length of the OS encoded representation of the state diff. + See src/starkware/starknet/core/os/state/output.cairo. + """ + return len(self.to_os_state_diff().encode()) + + def get_marginal_os_encoded_length(self, previous_state_diff: Optional["StateDiff"]) -> int: + """ + Returns the marginal addition of self to the given state diff's length. 
+
+ E.g., the following are equivalent:
+ * (a + b).get_os_encoded_length()
+ * a.get_os_encoded_length() + b.get_marginal_os_encoded_length(a)
+ """
+ if previous_state_diff is None:
+ return self.get_os_encoded_length()
+
+ pre_squash_size = previous_state_diff.get_os_encoded_length()
+ post_squash_size = previous_state_diff.squash(other=self).get_os_encoded_length()
+ return post_squash_size - pre_squash_size
+
+ def to_os_state_diff(self) -> OsStateDiff:
+ self.assert_l1_da_mode()
+ nonces = self.nonces.get(DataAvailabilityMode.L1, {})
+ storage_updates = self.storage_updates.get(DataAvailabilityMode.L1, {})
+ modified_contracts = sorted(
+ self.address_to_class_hash.keys() | nonces.keys() | storage_updates.keys()
+ )
+ return OsStateDiff(
+ contracts=[
+ ContractChanges(
+ addr=addr,
+ new_nonce=nonces.get(addr, None),
+ new_class_hash=self.address_to_class_hash.get(addr, None),
+ storage_changes={
+ key: (None, value) for key, value in storage_updates.get(addr, {}).items()
+ },
+ # Only relevant for `full_output` mode.
+ prev_nonce=None,
+ prev_class_hash=None,
+ )
+ for addr in modified_contracts
+ ],
+ classes={
+ class_hash: (None, compiled_class_hash)
+ for class_hash, compiled_class_hash in self.declared_classes.items()
+ },
+ )
+
+ @classmethod
+ def from_os_state_diff(cls, os_state_diff: OsStateDiff, block_info: BlockInfo) -> "StateDiff":
+ contracts = os_state_diff.contracts
+ return cls.create_l1_da_mode(
+ address_to_class_hash={
+ contract.addr: contract.new_class_hash
+ for contract in contracts
+ if contract.new_class_hash is not None
+ },
+ nonces={
+ contract.addr: contract.new_nonce
+ for contract in contracts
+ if contract.new_nonce is not None
+ },
+ storage_updates={
+ contract.addr: {
+ key: value for key, (_prev, value) in contract.storage_changes.items()
+ }
+ for contract in contracts
+ if len(contract.storage_changes) > 0
+ },
+ declared_classes={
+ class_hash: compiled_class_hash
+ for class_hash, (_prev, compiled_class_hash) in os_state_diff.classes.items()
+ },
+ block_info=block_info,
+ )
+
+ def assert_l1_da_mode(self):
+ supported_da_modes = {DataAvailabilityMode.L1}
+ unsupported_da_modes = (
+ set(self.nonces.keys() | self.storage_updates.keys()) - supported_da_modes
+ )
+ assert len(unsupported_da_modes) == 0, f"Unsupported DA modes: {unsupported_da_modes}."
+
@marshmallow_dataclass.dataclass(frozen=True)
class DeprecatedStateDiff(EverestStateDiff, DBObject):
diff --git a/src/starkware/starknet/business_logic/state/state_api_objects.py b/src/starkware/starknet/business_logic/state/state_api_objects.py
index f142868a..1fec8b43 100644
--- a/src/starkware/starknet/business_logic/state/state_api_objects.py
+++ b/src/starkware/starknet/business_logic/state/state_api_objects.py
@@ -92,7 +92,7 @@ class BlockInfo(ValidatedMarshmallowDataclass):
# The sequencer address of this block.
sequencer_address: Optional[int] = field(metadata=fields.optional_sequencer_address_metadata)
- # The version of Starknet system (e.g., "0.13.2").
+ # The version of Starknet system (e.g., "0.13.3").
starknet_version: Optional[str] = field(metadata=fields.starknet_version_metadata)
# Indicates whether to use KZG commitment scheme for the block's Data Availability.
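
For illustration, here is a minimal usage sketch of the `StateDiff.get_marginal_os_encoded_length` API added above. The helper function, its `pending_diffs` input and the `da_budget` parameter are hypothetical and not part of this change; only `squash` and the two length methods come from the diff. By the telescoping identity in the docstring, `total_len` always equals the OS-encoded length of the squash of the accepted prefix.

```python
from typing import List, Optional

from starkware.starknet.business_logic.fact_state.state import StateDiff


def fit_diffs_into_budget(pending_diffs: List[StateDiff], da_budget: int) -> int:
    """
    Returns how many of the given diffs (a hypothetical per-block list) fit within
    `da_budget` OS-encoded words, paying only the marginal cost of each diff.
    """
    accumulated: Optional[StateDiff] = None
    total_len = 0
    for i, diff in enumerate(pending_diffs):
        marginal = diff.get_marginal_os_encoded_length(previous_state_diff=accumulated)
        if total_len + marginal > da_budget:
            return i
        total_len += marginal
        accumulated = diff if accumulated is None else accumulated.squash(other=diff)
    return len(pending_diffs)
```
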
@@ -105,12 +105,12 @@ def rename_old_gas_price_fields( return rename_old_gas_price_fields(data=data) @classmethod - def empty(cls, sequencer_address: Optional[int]) -> "BlockInfo": + def empty(cls, sequencer_address: Optional[int], block_number: int = -1) -> "BlockInfo": """ Returns an empty BlockInfo object; i.e., the one before the first in the chain. """ return cls( - block_number=-1, + block_number=block_number, block_timestamp=0, # As gas prices must be non-zero, just use 1 for all prices. l1_gas_price=ResourcePrice(price_in_wei=1, price_in_fri=1), diff --git a/src/starkware/starknet/business_logic/transaction/deprecated_objects.py b/src/starkware/starknet/business_logic/transaction/deprecated_objects.py index 1221f03c..a5f5fed2 100644 --- a/src/starkware/starknet/business_logic/transaction/deprecated_objects.py +++ b/src/starkware/starknet/business_logic/transaction/deprecated_objects.py @@ -282,7 +282,8 @@ def run_validate_entrypoint( resources_manager=resources_manager, general_config=general_config, tx_execution_context=self.get_execution_context( - n_steps=general_config.validate_max_n_steps, execution_mode=ExecutionMode.VALIDATE + n_steps=general_config.get_validate_max_n_steps(), + execution_mode=ExecutionMode.VALIDATE, ), ) @@ -336,7 +337,8 @@ def charge_fee( general_config=general_config, state=state, tx_execution_context=self.get_execution_context( - n_steps=general_config.invoke_tx_max_n_steps, execution_mode=ExecutionMode.EXECUTE + n_steps=constants.VERSIONED_CONSTANTS.invoke_tx_max_n_steps, + execution_mode=ExecutionMode.EXECUTE, ), actual_fee=actual_fee, ) @@ -885,7 +887,8 @@ def run_constructor_entrypoint( resources_manager=resources_manager, general_config=general_config, tx_execution_context=self.get_execution_context( - n_steps=general_config.validate_max_n_steps, execution_mode=ExecutionMode.VALIDATE + n_steps=general_config.get_validate_max_n_steps(), + execution_mode=ExecutionMode.VALIDATE, ), ) @@ -1098,7 +1101,7 @@ def invoke_constructor( signature=[], max_fee=0, nonce=0, - n_steps=general_config.invoke_tx_max_n_steps, + n_steps=constants.VERSIONED_CONSTANTS.invoke_tx_max_n_steps, version=self.version, execution_mode=ExecutionMode.EXECUTE, ) @@ -1384,7 +1387,7 @@ def run_execute_entrypoint( resources_manager=resources_manager, general_config=general_config, tx_execution_context=self.get_execution_context( - n_steps=general_config.invoke_tx_max_n_steps, + n_steps=constants.VERSIONED_CONSTANTS.invoke_tx_max_n_steps, execution_mode=ExecutionMode.EXECUTE, ), ) @@ -1510,7 +1513,7 @@ def _apply_specific_concurrent_changes( resources_manager=resources_manager, general_config=general_config, tx_execution_context=self.get_execution_context( - n_steps=general_config.invoke_tx_max_n_steps + n_steps=constants.VERSIONED_CONSTANTS.invoke_tx_max_n_steps, ), ) diff --git a/src/starkware/starknet/core/aggregator/BUILD b/src/starkware/starknet/core/aggregator/BUILD index b1adf299..1f02c289 100644 --- a/src/starkware/starknet/core/aggregator/BUILD +++ b/src/starkware/starknet/core/aggregator/BUILD @@ -14,7 +14,7 @@ py_library( "//src/starkware/cairo/lang/vm:cairo_relocatable_lib", "//src/starkware/cairo/lang/vm:cairo_vm_lib", "//src/starkware/python:starkware_python_utils_lib", - "//src/starkware/starknet/core/os:kzg_manager_lib", + "//src/starkware/starknet/core/os/data_availability:compression", "//src/starkware/starknet/definitions:starknet_definitions_lib", ], ) @@ -29,6 +29,9 @@ cairo_binary( "--debug_info_with_source", ], compiled_program_name = "aggregator.json", + 
hint_deps = [ + "//src/starkware/starknet/core/os:kzg_manager_lib", + ], main = "main.cairo", deps = [ "//src/starkware/starknet/core/os:output", @@ -43,7 +46,7 @@ pytest_test( "aggregator_test.py", ], data = [ - ":aggregator.json", + ":aggregator", "//src/starkware/starknet/core/os/data_availability:bls_field", ], deps = [ @@ -56,6 +59,7 @@ pytest_test( "//src/starkware/cairo/lang/compiler:cairo_compile_lib", "//src/starkware/cairo/lang/vm:cairo_vm_lib", "//src/starkware/python:starkware_python_test_utils_lib", + "//src/starkware/starknet/core/os/data_availability:compression", ], ) diff --git a/src/starkware/starknet/core/aggregator/aggregator_test.py b/src/starkware/starknet/core/aggregator/aggregator_test.py index 919420f3..c664513b 100644 --- a/src/starkware/starknet/core/aggregator/aggregator_test.py +++ b/src/starkware/starknet/core/aggregator/aggregator_test.py @@ -1,6 +1,7 @@ +import itertools import os from enum import Enum, auto -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import pytest @@ -16,12 +17,15 @@ from starkware.cairo.lang.vm.vm_exceptions import VmException from starkware.python.test_utils import maybe_raises from starkware.starknet.core.aggregator.output_parser import ( + N_UPDATES_SMALL_PACKING_BOUND, ContractChanges, OsOutput, + OsStateDiff, TaskOutput, parse_bootloader_output, ) from starkware.starknet.core.aggregator.utils import OsOutputToCairo +from starkware.starknet.core.os.data_availability.compression import compress # Dummy values for the test. OS_PROGRAM_HASH = 0x7E0B89C77D0003C05511B9F0E1416F1328C2132E41E056B2EF3BC950135360F @@ -102,16 +106,8 @@ def aggregator_program() -> Program: return Program.loads(data=open(AGGREGATOR_COMPILED_PATH).read()) -def contract_header_packed_word( - n_updates: int, prev_nonce: int, new_nonce: int, class_updated: int, full_output: bool -) -> int: - """ - Returns the second word of the contract header. - """ - if full_output: - return n_updates + prev_nonce * 2**64 + new_nonce * 2**128 + class_updated * 2**192 - else: - return n_updates + new_nonce * 2**64 + class_updated * 2**128 +def remove_none_values(array: Sequence[Optional[int]]) -> List[int]: + return [x for x in array if x is not None] class FailureModifier(Enum): @@ -126,7 +122,7 @@ class FailureModifier(Enum): def block0_output(full_output: bool): - res = [ + partial_res_with_nones = [ # initial_root. ROOT0, # final_root. @@ -152,11 +148,14 @@ def block0_output(full_output: bool): # Messages to L2. len(MSG_TO_L2_0), *MSG_TO_L2_0, + ] + partial_res = remove_none_values(partial_res_with_nones) + da_with_nones = [ # Number of contracts. 2, # Contract addr. CONTRACT_ADDR0, - contract_header_packed_word( + ContractChanges.encode_header( n_updates=3, prev_nonce=0, new_nonce=1, class_updated=1, full_output=full_output ), # Class hash. @@ -175,7 +174,7 @@ def block0_output(full_output: bool): # Contract whose block0 changes are fully reverted by block1. # Contract addr. CONTRACT_ADDR1, - contract_header_packed_word( + ContractChanges.encode_header( n_updates=1, prev_nonce=10, new_nonce=10, class_updated=1, full_output=full_output ), # Class hash. 
@@ -195,12 +194,15 @@ def block0_output(full_output: bool): COMPILED_CLASS_HASH1_0 if full_output else None, COMPILED_CLASS_HASH1_1, ] - return [x for x in res if x is not None] + da = remove_none_values(da_with_nones) + if not full_output: + da = compress(data=da) + return partial_res + da def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier.NONE): maybe_wrong = lambda x, modifier0: x + (10 if modifier == modifier0 else 0) - res = [ + partial_res_with_nones = [ # initial_root. maybe_wrong(ROOT1, FailureModifier.ROOT), # final_root. @@ -226,11 +228,14 @@ def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier # Messages to L2. len(MSG_TO_L2_1), *MSG_TO_L2_1, + ] + partial_res = remove_none_values(partial_res_with_nones) + da_with_nones = [ # Number of contracts. 3, # Contract addr. CONTRACT_ADDR0, - contract_header_packed_word( + ContractChanges.encode_header( n_updates=2, prev_nonce=1, new_nonce=2, class_updated=0, full_output=full_output ), # Class hash. @@ -246,7 +251,7 @@ def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier # Contract whose block0 changes are fully reverted by block1. # Contract addr. CONTRACT_ADDR1, - contract_header_packed_word( + ContractChanges.encode_header( n_updates=1, prev_nonce=10, new_nonce=10, class_updated=1, full_output=full_output ), # Class hash. @@ -259,8 +264,12 @@ def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier # Contract that only appears in this block (block1). # Contract addr. CONTRACT_ADDR2, - contract_header_packed_word( - n_updates=1, prev_nonce=7, new_nonce=8, class_updated=0, full_output=full_output + ContractChanges.encode_header( + n_updates=1 + N_UPDATES_SMALL_PACKING_BOUND, + prev_nonce=7, + new_nonce=8, + class_updated=0, + full_output=full_output, ), # Class hash. CLASS_HASH2_0 if full_output else None, @@ -269,6 +278,15 @@ def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier STORAGE_KEY4, STORAGE_VALUE4_0 if full_output else None, STORAGE_VALUE4_2, + # Write 256 values to test contract header packing with a large number of updates. + *itertools.chain.from_iterable( + ( + STORAGE_KEY4 + 1 + i, + 0 if full_output else None, + 1, + ) + for i in range(N_UPDATES_SMALL_PACKING_BOUND) + ), # Number of classes. 1, CLASS_HASH0_0, @@ -279,7 +297,10 @@ def block1_output(full_output: bool, modifier: FailureModifier = FailureModifier ), COMPILED_CLASS_HASH0_2, ] - return [x for x in res if x is not None] + da = remove_none_values(da_with_nones) + if not full_output: + da = compress(data=da) + return partial_res + da def combined_output(full_output: bool, use_kzg_da: bool = False): @@ -315,16 +336,16 @@ def combined_output(full_output: bool, use_kzg_da: bool = False): *MSG_TO_L2_1, *([] if use_kzg_da else da), ] - return [x for x in res if x is not None] + return remove_none_values(res) def combined_output_da(full_output: bool): - res = [ + res_with_nones = [ # Number of contracts. 3 if full_output else 2, # Contract addr. CONTRACT_ADDR0, - contract_header_packed_word( + ContractChanges.encode_header( n_updates=4, prev_nonce=0, new_nonce=2, class_updated=1, full_output=full_output ), # Class hash. @@ -346,7 +367,7 @@ def combined_output_da(full_output: bool): # Contract addr. 
CONTRACT_ADDR1 if full_output else None, ( - contract_header_packed_word( + ContractChanges.encode_header( n_updates=0, prev_nonce=10, new_nonce=10, class_updated=0, full_output=full_output ) if full_output @@ -357,8 +378,12 @@ def combined_output_da(full_output: bool): CLASS_HASH1_0 if full_output else None, # Contract addr. CONTRACT_ADDR2, - contract_header_packed_word( - n_updates=1, prev_nonce=7, new_nonce=8, class_updated=0, full_output=full_output + ContractChanges.encode_header( + n_updates=1 + N_UPDATES_SMALL_PACKING_BOUND, + prev_nonce=7, + new_nonce=8, + class_updated=0, + full_output=full_output, ), # Class hash. CLASS_HASH2_0 if full_output else None, @@ -367,6 +392,14 @@ def combined_output_da(full_output: bool): STORAGE_KEY4, STORAGE_VALUE4_0 if full_output else None, STORAGE_VALUE4_2, + *itertools.chain.from_iterable( + ( + STORAGE_KEY4 + 1 + i, + 0 if full_output else None, + 1, + ) + for i in range(N_UPDATES_SMALL_PACKING_BOUND) + ), # Number of classes. 2, # Class updates. @@ -377,7 +410,10 @@ def combined_output_da(full_output: bool): COMPILED_CLASS_HASH1_0 if full_output else None, COMPILED_CLASS_HASH1_1, ] - return [x for x in res if x is not None] + res = remove_none_values(res_with_nones) + if not full_output: + return compress(data=res) + return res def combined_kzg_info(da: List[int]) -> List[int]: @@ -423,52 +459,54 @@ def test_output_parser(full_output: bool): full_output=1 if full_output else 0, messages_to_l1=MSG_TO_L1_0, messages_to_l2=MSG_TO_L2_0, - contracts=[ - ContractChanges( - addr=CONTRACT_ADDR0, - prev_nonce=0 if full_output else None, - new_nonce=1, - prev_class_hash=CLASS_HASH0_0 if full_output else None, - new_class_hash=CLASS_HASH0_1, - storage_changes={ - STORAGE_KEY0: ( - STORAGE_VALUE0_0 if full_output else None, - STORAGE_VALUE0_1, - ), - STORAGE_KEY1: ( - STORAGE_VALUE1_0 if full_output else None, - STORAGE_VALUE1_1, - ), - STORAGE_KEY2: ( - STORAGE_VALUE2_0 if full_output else None, - STORAGE_VALUE2_1, - ), - }, - ), - ContractChanges( - addr=CONTRACT_ADDR1, - prev_nonce=10 if full_output else None, - new_nonce=10, - prev_class_hash=CLASS_HASH1_0 if full_output else None, - new_class_hash=CLASS_HASH1_1, - storage_changes={ - STORAGE_KEY0: ( - STORAGE_VALUE0_0 if full_output else None, - STORAGE_VALUE0_1, - ), - }, - ), - ], - classes={ - CLASS_HASH0_0: ( - COMPILED_CLASS_HASH0_0 if full_output else None, - COMPILED_CLASS_HASH0_1, - ), - CLASS_HASH1_0: ( - COMPILED_CLASS_HASH1_0 if full_output else None, - COMPILED_CLASS_HASH1_1, - ), - }, + state_diff=OsStateDiff( + contracts=[ + ContractChanges( + addr=CONTRACT_ADDR0, + prev_nonce=0 if full_output else None, + new_nonce=1, + prev_class_hash=CLASS_HASH0_0 if full_output else None, + new_class_hash=CLASS_HASH0_1, + storage_changes={ + STORAGE_KEY0: ( + STORAGE_VALUE0_0 if full_output else None, + STORAGE_VALUE0_1, + ), + STORAGE_KEY1: ( + STORAGE_VALUE1_0 if full_output else None, + STORAGE_VALUE1_1, + ), + STORAGE_KEY2: ( + STORAGE_VALUE2_0 if full_output else None, + STORAGE_VALUE2_1, + ), + }, + ), + ContractChanges( + addr=CONTRACT_ADDR1, + prev_nonce=10 if full_output else None, + new_nonce=10 if full_output else None, + prev_class_hash=CLASS_HASH1_0 if full_output else None, + new_class_hash=CLASS_HASH1_1, + storage_changes={ + STORAGE_KEY0: ( + STORAGE_VALUE0_0 if full_output else None, + STORAGE_VALUE0_1, + ), + }, + ), + ], + classes={ + CLASS_HASH0_0: ( + COMPILED_CLASS_HASH0_0 if full_output else None, + COMPILED_CLASS_HASH0_1, + ), + CLASS_HASH1_0: ( + COMPILED_CLASS_HASH1_0 if 
full_output else None, + COMPILED_CLASS_HASH1_1, + ), + }, + ), ), ), TaskOutput( @@ -486,57 +524,63 @@ def test_output_parser(full_output: bool): full_output=1 if full_output else 0, messages_to_l1=MSG_TO_L1_1, messages_to_l2=MSG_TO_L2_1, - contracts=[ - ContractChanges( - addr=CONTRACT_ADDR0, - prev_nonce=1 if full_output else None, - new_nonce=2, - prev_class_hash=CLASS_HASH0_1 if full_output else None, - new_class_hash=CLASS_HASH0_1 if full_output else None, - storage_changes={ - STORAGE_KEY0: ( - STORAGE_VALUE0_1 if full_output else None, - STORAGE_VALUE0_2, - ), - STORAGE_KEY3: ( - STORAGE_VALUE3_0 if full_output else None, - STORAGE_VALUE3_2, - ), - }, - ), - ContractChanges( - addr=CONTRACT_ADDR1, - prev_nonce=10 if full_output else None, - new_nonce=10, - prev_class_hash=CLASS_HASH1_1 if full_output else None, - new_class_hash=CLASS_HASH1_0, - storage_changes={ - STORAGE_KEY0: ( - STORAGE_VALUE0_1 if full_output else None, - STORAGE_VALUE0_0, - ), - }, - ), - ContractChanges( - addr=CONTRACT_ADDR2, - prev_nonce=7 if full_output else None, - new_nonce=8, - prev_class_hash=CLASS_HASH2_0 if full_output else None, - new_class_hash=CLASS_HASH2_0 if full_output else None, - storage_changes={ - STORAGE_KEY4: ( - STORAGE_VALUE4_0 if full_output else None, - STORAGE_VALUE4_2, - ), - }, - ), - ], - classes={ - CLASS_HASH0_0: ( - COMPILED_CLASS_HASH0_1 if full_output else None, - COMPILED_CLASS_HASH0_2, - ), - }, + state_diff=OsStateDiff( + contracts=[ + ContractChanges( + addr=CONTRACT_ADDR0, + prev_nonce=1 if full_output else None, + new_nonce=2, + prev_class_hash=CLASS_HASH0_1 if full_output else None, + new_class_hash=CLASS_HASH0_1 if full_output else None, + storage_changes={ + STORAGE_KEY0: ( + STORAGE_VALUE0_1 if full_output else None, + STORAGE_VALUE0_2, + ), + STORAGE_KEY3: ( + STORAGE_VALUE3_0 if full_output else None, + STORAGE_VALUE3_2, + ), + }, + ), + ContractChanges( + addr=CONTRACT_ADDR1, + prev_nonce=10 if full_output else None, + new_nonce=10 if full_output else None, + prev_class_hash=CLASS_HASH1_1 if full_output else None, + new_class_hash=CLASS_HASH1_0, + storage_changes={ + STORAGE_KEY0: ( + STORAGE_VALUE0_1 if full_output else None, + STORAGE_VALUE0_0, + ), + }, + ), + ContractChanges( + addr=CONTRACT_ADDR2, + prev_nonce=7 if full_output else None, + new_nonce=8, + prev_class_hash=CLASS_HASH2_0 if full_output else None, + new_class_hash=CLASS_HASH2_0 if full_output else None, + storage_changes={ + STORAGE_KEY4: ( + STORAGE_VALUE4_0 if full_output else None, + STORAGE_VALUE4_2, + ), + **{ + STORAGE_KEY4 + 1 + i: (0 if full_output else None, 1) + for i in range(N_UPDATES_SMALL_PACKING_BOUND) + }, + }, + ), + ], + classes={ + CLASS_HASH0_0: ( + COMPILED_CLASS_HASH0_1 if full_output else None, + COMPILED_CLASS_HASH0_2, + ), + }, + ), ), ), ] diff --git a/src/starkware/starknet/core/aggregator/output_parser.py b/src/starkware/starknet/core/aggregator/output_parser.py index d791c0de..0c754c0c 100644 --- a/src/starkware/starknet/core/aggregator/output_parser.py +++ b/src/starkware/starknet/core/aggregator/output_parser.py @@ -2,8 +2,14 @@ import itertools from typing import Dict, Iterator, List, Optional, Tuple +from starkware.starknet.core.os.data_availability.compression import compress, decompress from starkware.starknet.definitions.constants import OsOutputConstant +N_UPDATES_BOUND = 2**64 +N_UPDATES_SMALL_PACKING_BOUND = 2**8 +NONCE_BOUND = 2**64 +FLAG_BOUND = 2**1 + @dataclasses.dataclass class ContractChanges: @@ -13,17 +19,116 @@ class ContractChanges: # The address 
of the contract. addr: int - # The previous nonce of the contract (for account contracts, optional). + # The previous nonce of the contract (for account contracts, if full output). prev_nonce: Optional[int] - # The new nonce of the contract (for account contracts). - new_nonce: int - # The previous class hash (if changed). + # The new nonce of the contract (for account contracts, if changed or full output). + new_nonce: Optional[int] + # The previous class hash (if full output). prev_class_hash: Optional[int] - # The new class hash (if changed). + # The new class hash (if changed or full output). new_class_hash: Optional[int] # A map from storage key to its prev value (optional) and new value. storage_changes: Dict[int, Tuple[Optional[int], int]] + def encode(self, full_output: bool = False) -> List[int]: + """ + Returns the OS encoding of the contract diff. + """ + was_class_updated = self.prev_class_hash != self.new_class_hash + header_packed_word = self.encode_header( + n_updates=len(self.storage_changes), + prev_nonce=self.prev_nonce, + new_nonce=self.new_nonce, + class_updated=was_class_updated, + full_output=full_output, + ) + res = [self.addr, header_packed_word] + + if full_output: + assert self.prev_class_hash is not None, "Prev class_hash is missing with full_output." + assert self.new_class_hash is not None, "New class_hash is missing with full_output." + res += [self.prev_class_hash, self.new_class_hash] + else: + if was_class_updated: + assert self.new_class_hash is not None + res.append(self.new_class_hash) + + res += encode_key_value_pairs(self.storage_changes) + return res + + @staticmethod + def encode_header( + n_updates: int, + prev_nonce: Optional[int], + new_nonce: Optional[int], + class_updated: int, + full_output: bool, + ) -> int: + """ + Returns the encoded contract header word. + """ + if full_output: + assert prev_nonce is not None, "Prev nonce is missing with full_output." + assert new_nonce is not None, "New nonce is missing with full_output." + packed_nonces = prev_nonce * NONCE_BOUND + new_nonce + else: + if new_nonce is None or prev_nonce == new_nonce: + # The nonce was not changed. + packed_nonces = 0 + else: + packed_nonces = new_nonce + + is_n_updates_small = n_updates < N_UPDATES_SMALL_PACKING_BOUND + n_updates_bound = N_UPDATES_SMALL_PACKING_BOUND if is_n_updates_small else N_UPDATES_BOUND + + header_packed_word = packed_nonces + header_packed_word = header_packed_word * n_updates_bound + n_updates + header_packed_word = header_packed_word * FLAG_BOUND + int(is_n_updates_small) + header_packed_word = header_packed_word * FLAG_BOUND + int(class_updated) + return header_packed_word + + +@dataclasses.dataclass(frozen=True) +class OsStateDiff: + """ + Represents the state diff. + """ + + # Contracts that were changed. + contracts: List[ContractChanges] + # Classes that were declared. A map from class hash to previous (optional) and new + # compiled class hash. + classes: Dict[int, Tuple[Optional[int], int]] + + def encode(self, full_output: bool = False) -> List[int]: + """ + Returns the OS encoding of the state diff. 
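+ The result is compressed (see `compress`) unless `full_output` is True.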
+ """ + state_diff = [ + len(self.contracts), + *list(itertools.chain(*(contract.encode() for contract in self.contracts))), + len(self.classes), + *encode_key_value_pairs(self.classes), + ] + if not full_output: + return compress(data=state_diff) + + return state_diff + + +def encode_key_value_pairs(d: Dict[int, Tuple[Optional[int], int]]) -> List[int]: + """ + Encodes a dictionary of the following format: {key: (optional_prev_value, new_value)}. + """ + res = [] + for key, (prev_value, new_value) in sorted(d.items()): + res.append(key) + if prev_value is not None: + res.append(prev_value) + res.append(new_value) + + return res + @dataclasses.dataclass class OsOutput: @@ -55,11 +160,8 @@ class OsOutput: messages_to_l1: List[int] # Messages from L1 to L2. messages_to_l2: List[int] - # The list of contracts that were changed. - contracts: Optional[List[ContractChanges]] - # The list of classes that were declared. A map from class hash to previous (optional) and new - # compiled class hash. - classes: Optional[Dict[int, Tuple[Optional[int], int]]] + # The state diff. + state_diff: Optional[OsStateDiff] @dataclasses.dataclass @@ -127,26 +229,11 @@ def parse_os_output(output_iter: Iterator[int]) -> OsOutput: messages_to_l2_segment_size = next(output_iter) messages_to_l2 = list(itertools.islice(output_iter, messages_to_l2_segment_size)) - contracts: Optional[List[ContractChanges]] - if use_kzg_da == 0: - # Contract changes. - n_contracts = next(output_iter) - contracts = [] - for _ in range(n_contracts): - contracts.append( - parse_contract_changes(output_iter=output_iter, full_output=full_output) - ) - - # Class changes. - n_classes = next(output_iter) - classes = {} - for _ in range(n_classes): - class_hash = next(output_iter) - prev_compiled_class_hash = next(output_iter) if full_output else None - new_compiled_class_hash = next(output_iter) - classes[class_hash] = (prev_compiled_class_hash, new_compiled_class_hash) - else: - contracts = classes = None + state_diff = ( + parse_os_state_diff(output_iter=output_iter, full_output=full_output) + if use_kzg_da == 0 + else None + ) return OsOutput( initial_root=initial_root, @@ -161,24 +248,59 @@ def parse_os_output(output_iter: Iterator[int]) -> OsOutput: full_output=full_output_int, messages_to_l1=messages_to_l1, messages_to_l2=messages_to_l2, - contracts=contracts, - classes=classes, + state_diff=state_diff, ) +def parse_os_state_diff(output_iter: Iterator[int], full_output: bool) -> OsStateDiff: + """ + Parses the state diff. + """ + if not full_output: + state_diff = decompress(compressed=output_iter) + output_iter = itertools.chain(iter(state_diff), output_iter) + + # Contract changes. + n_contracts = next(output_iter) + contracts = [] + for _ in range(n_contracts): + contracts.append(parse_contract_changes(output_iter=output_iter, full_output=full_output)) + + # Class changes. + n_classes = next(output_iter) + classes = {} + for _ in range(n_classes): + class_hash = next(output_iter) + prev_compiled_class_hash = next(output_iter) if full_output else None + new_compiled_class_hash = next(output_iter) + classes[class_hash] = (prev_compiled_class_hash, new_compiled_class_hash) + + return OsStateDiff(contracts=contracts, classes=classes) + + def parse_contract_changes(output_iter: Iterator[int], full_output: bool) -> ContractChanges: """ Parses contract changes. 
""" addr = next(output_iter) - class_nonce_n_changes = next(output_iter) - class_nonce, n_changes = divmod(class_nonce_n_changes, 2**64) + nonce_n_changes_two_flags = next(output_iter) + + # Parse flags. + nonce_n_changes_one_flag, class_updated = divmod(nonce_n_changes_two_flags, FLAG_BOUND) + nonce_n_changes, is_n_updates_small = divmod(nonce_n_changes_one_flag, FLAG_BOUND) + + # Parse n_changes. + n_updates_bound = N_UPDATES_SMALL_PACKING_BOUND if is_n_updates_small else N_UPDATES_BOUND + nonce, n_changes = divmod(nonce_n_changes, n_updates_bound) + + # Parse nonces. + prev_nonce: Optional[int] + new_nonce: Optional[int] if full_output: - class_nonce, prev_nonce = divmod(class_nonce, 2**64) + prev_nonce, new_nonce = divmod(nonce, NONCE_BOUND) else: prev_nonce = None - class_updated, new_nonce = divmod(class_nonce, 2**64) - assert class_updated in [0, 1], f"Invalid contract header: {class_nonce_n_changes}" + new_nonce = None if nonce == 0 else nonce if full_output: prev_class_hash = next(output_iter) diff --git a/src/starkware/starknet/core/aggregator/program_hash.json b/src/starkware/starknet/core/aggregator/program_hash.json index d11b7351..9ea9329f 100644 --- a/src/starkware/starknet/core/aggregator/program_hash.json +++ b/src/starkware/starknet/core/aggregator/program_hash.json @@ -1,4 +1,4 @@ { - "program_hash": "0x52e9e4e95cedaf218b0a99bf0044c95335a4da09b2472ebb87d804804b55db9", - "program_hash_with_aggregator_prefix": "0x29134351e8694cf55b54addda8b66eb7614791c3f6e98098e3e37b8e8592926" + "program_hash": "0x51a9bd6e54ba7f11f86c28f9174da39b2652c13ba563665078f40b4ca3a27b", + "program_hash_with_aggregator_prefix": "0x8ef7e2afc1754c5a0a3ca5891c1b1b91db899670a1685c630b9715aee5cd0" } diff --git a/src/starkware/starknet/core/aggregator/utils.py b/src/starkware/starknet/core/aggregator/utils.py index 6c3e485e..ba9f3aec 100644 --- a/src/starkware/starknet/core/aggregator/utils.py +++ b/src/starkware/starknet/core/aggregator/utils.py @@ -56,10 +56,12 @@ def process_os_output( ptr=messages_to_l2_start, arg=os_output.messages_to_l2 ) + state_diff = os_output.state_diff + assert state_diff is not None, "Missing state diff information." + # Handle contract state changes. storage_dict: List[MaybeRelocatable] = [] - assert os_output.contracts is not None, "Missing contract changes information." - for contract in os_output.contracts: + for contract in state_diff.contracts: if contract.addr in self._inner_storage: state_entry = self._inner_storage[contract.addr] else: @@ -86,6 +88,7 @@ def process_os_output( storage_changes.append(new_value) assert contract.new_class_hash is not None, "Missing new class hash." + assert contract.new_nonce is not None, "Missing new nonce." state_entry.add_state_entry( segments=segments, class_hash=contract.new_class_hash, @@ -103,8 +106,7 @@ def process_os_output( # Handle compiled class changes. class_dict = [] - assert os_output.classes is not None, "Missing class changes information." - for class_hash, (prev_compiled_hash, new_compiled_hash) in os_output.classes.items(): + for class_hash, (prev_compiled_hash, new_compiled_hash) in state_diff.classes.items(): assert prev_compiled_hash is not None, "Missing previous compiled class hash." 
class_dict.append(class_hash)
class_dict.append(prev_compiled_hash)
diff --git a/src/starkware/starknet/core/os/BUILD b/src/starkware/starknet/core/os/BUILD
index a1661037..8b2ba4e4 100644
--- a/src/starkware/starknet/core/os/BUILD
+++ b/src/starkware/starknet/core/os/BUILD
@@ -29,13 +29,13 @@ cairo_library(
srcs = [
"block_context.cairo",
"output.cairo",
- "//src/starkware/starknet/core/os/data_availability:commitment.cairo",
],
hint_deps = [
"//src/starkware/python:starkware_python_utils_lib",
],
deps = [
- "//src/starkware/starknet/core/os/data_availability:bls_field",
+ "//src/starkware/starknet/core/os/data_availability:cairo_compression",
+ "//src/starkware/starknet/core/os/data_availability:commitment",
"//src/starkware/starknet/core/os/state:starknet_os_state_lib",
],
)
diff --git a/src/starkware/starknet/core/os/data_availability/BUILD b/src/starkware/starknet/core/os/data_availability/BUILD
index 10b480f4..c13e011a 100644
--- a/src/starkware/starknet/core/os/data_availability/BUILD
+++ b/src/starkware/starknet/core/os/data_availability/BUILD
@@ -61,3 +61,29 @@ cairo_library(
"//src/starkware/cairo/common:cairo_common_cairo_lib",
],
)
+
+py_library(
+ name = "compression",
+ srcs = [
+ "compression.py",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//src/starkware/cairo/lang:cairo_constants_lib",
+ "//src/starkware/python:starkware_python_utils_lib",
+ ],
+)
+
+cairo_library(
+ name = "cairo_compression",
+ srcs = [
+ "compression.cairo",
+ ],
+ hint_deps = [
+ ":compression",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//src/starkware/cairo/common:cairo_common_cairo_lib",
+ ],
+)
diff --git a/src/starkware/starknet/core/os/data_availability/compression.cairo b/src/starkware/starknet/core/os/data_availability/compression.cairo
new file mode 100644
index 00000000..4dd76646
--- /dev/null
+++ b/src/starkware/starknet/core/os/data_availability/compression.cairo
@@ -0,0 +1,432 @@
+from starkware.cairo.common.alloc import alloc
+from starkware.cairo.common.copy_indices import copy_indices
+from starkware.cairo.common.dict import dict_new, dict_read, dict_squash, dict_update, dict_write
+from starkware.cairo.common.dict_access import DictAccess
+from starkware.cairo.common.log2_ceil import log2_ceil
+from starkware.cairo.common.math import assert_in_range, unsigned_div_rem
+from starkware.cairo.common.memcpy import memcpy
+from starkware.cairo.common.pow import pow
+
+const COMPRESSION_VERSION = 0;
+
+// Holds the number of elements per unique value bucket.
+struct UniqueValueBucketLengths {
+ n_252_bit_elms: felt,
+ n_125_bit_elms: felt,
+ n_83_bit_elms: felt,
+ n_62_bit_elms: felt,
+ n_31_bit_elms: felt,
+ n_15_bit_elms: felt,
+}
+
+// Holds decoding info such as the length of each unique value bucket.
+struct Header {
+ version: felt,
+ // Total data length before compression.
+ data_len: felt,
+ unique_value_bucket_lengths: UniqueValueBucketLengths,
+ // Number of elements in the special bucket that holds pointers to repeating values.
+ n_repeating_values: felt,
+}
+
+// The number of buckets, which includes the unique value buckets and the repeating value bucket.
+const TOTAL_N_BUCKETS = UniqueValueBucketLengths.SIZE + 1;
+
+// Number of bits for each field of the header.
+const HEADER_ELM_N_BITS = 20;
+// Number of bits encoding each element (per bucket).
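+// The 252-bit bucket has no packing constant below: its elements each occupy a full felt
+// and are copied as-is (see `unpack_unique_values`).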
+const BUCKET_125_N_BITS = 125; +const BUCKET_83_N_BITS = 83; +const BUCKET_62_N_BITS = 62; +const BUCKET_31_N_BITS = 31; +const BUCKET_15_N_BITS = 15; + +// Maximum number of bits that can be packed in one felt. +const MAX_N_BITS_PER_FELT = 251; + +// (Max) Number of elements packed in one felt (per bucket). +const BUCKET_125_N_ELMS_PER_FELT = 2; +const BUCKET_83_N_ELMS_PER_FELT = 3; +const BUCKET_62_N_ELMS_PER_FELT = 4; +const BUCKET_31_N_ELMS_PER_FELT = 8; +const BUCKET_15_N_ELMS_PER_FELT = 16; + +// Compresses the given data into `compressed_dst`. +// Format (packed in felts - see `unpack` functions): +// - Header: felt containing decoding info such as the total data length and the length of each +// bucket (see Header). +// - Buckets: +// - Unique value buckets, concatenated by the order in the header. +// - Repeating value pointers. +// - Bucket indices: bucket index per element - each is the index of the bucket containing the +// corresponding uncompressed element. +// +// The buckets preserve the insertion order. +// Thus, to decompress the data: +// - Unpack the info (unique_values, repeating_value_pointers, bucket_index_per_elm). +// - Build the repeating_values bucket: +// `[unique_values[i] for i in repeating_value_pointers]` +// - Let: +// `all_values = unique_values + repeating_values` +// - Calculate the initial bucket offsets. +// - Reconstruct the data: +// `[all_values[next(bucket_offsets[bucket_index])] for bucket_index in bucket_index_per_elm]` +// where `next()` returns the current value and increments it by 1. +// +// Note: a malicious prover might use a non-optimal compression. +func compress{range_check_ptr, compressed_dst: felt*}(data_start: felt*, data_end: felt*) { + // Guess the compression. + %{ + from starkware.starknet.core.os.data_availability.compression import compress + data = memory.get_range_as_ints(addr=ids.data_start, size=ids.data_end - ids.data_start) + segments.write_arg(ids.compressed_dst, compress(data)) + %} + // Verify the guess by decompressing it onto the original data array. + let (decompressed_end) = decompress{compressed=compressed_dst}(decompressed_dst=data_start); + + // Ensure the entire data was reconstructed. + assert decompressed_end = data_end; + return (); +} + +// Decompresses `compressed` into `decompressed_dst`. +// Returns the decompressed array end. +func decompress{range_check_ptr, compressed: felt*}(decompressed_dst: felt*) -> ( + decompressed_end: felt* +) { + alloc_locals; + let header = unpack_header(); + with_attr error_message("Unsupported compression version.") { + assert header.version = COMPRESSION_VERSION; + } + // Unpack and build `all_values`, which is a concatenation of the unique and repeating values. + let (all_values) = alloc(); + let (n_unique_values) = unpack_unique_values(header=header, unique_values_dst=all_values); + unpack_repeating_values( + n_repeating_values=header.n_repeating_values, + n_unique_values=n_unique_values, + unique_values=all_values, + repeating_values_dst=&all_values[n_unique_values], + ); + + let bucket_index_per_elm = unpack_bucket_index_per_elm(header=header); + + // Reconstruct the data into `decompressed_dst`. + let data_dst = decompressed_dst; + with data_dst { + reconstruct_data( + header=header, all_values=all_values, bucket_index_per_elm=bucket_index_per_elm + ); + } + return (decompressed_end=data_dst); +} + +// Unpacks the first felt of `compressed` into a `Header` struct. 
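+// The header fields are packed in HEADER_ELM_N_BITS-bit limbs, with `version` in the least
+// significant bits (see `unpack_felt`).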
+func unpack_header{range_check_ptr, compressed: felt*}() -> Header* { + alloc_locals; + let (local header: Header*) = alloc(); + static_assert Header.SIZE * HEADER_ELM_N_BITS == 180; // <= 251 bits. + unpack_felt( + packed_felt=compressed[0], + elm_bound=2 ** HEADER_ELM_N_BITS, + n_elms=Header.SIZE, + decompressed_dst=cast(header, felt*), + ); + let compressed = &compressed[1]; + return header; +} + +// Unpacks the unique value buckets from `compressed` into `unique_values_dst`. +func unpack_unique_values{range_check_ptr, compressed: felt*}( + header: Header*, unique_values_dst: felt* +) -> (n_unique_values: felt) { + alloc_locals; + let decompressed_dst_start = unique_values_dst; + let decompressed_dst = decompressed_dst_start; + + let bucket_lengths = header.unique_value_bucket_lengths; + static_assert UniqueValueBucketLengths.SIZE == 6; + + // Unpack the 252-bit bucket using memcpy. + local n_252_bit_elms = bucket_lengths.n_252_bit_elms; + memcpy(dst=decompressed_dst, src=compressed, len=n_252_bit_elms); + let compressed = &compressed[n_252_bit_elms]; + let decompressed_dst = &decompressed_dst[n_252_bit_elms]; + + with decompressed_dst { + static_assert BUCKET_125_N_BITS * BUCKET_125_N_ELMS_PER_FELT == 250; // <= 251 bits. + unpack_felts( + n_elms=bucket_lengths.n_125_bit_elms, + elm_bound=2 ** BUCKET_125_N_BITS, + n_elms_per_felt=BUCKET_125_N_ELMS_PER_FELT, + ); + static_assert BUCKET_83_N_BITS * BUCKET_83_N_ELMS_PER_FELT == 249; // <= 251 bits. + unpack_felts( + n_elms=bucket_lengths.n_83_bit_elms, + elm_bound=2 ** BUCKET_83_N_BITS, + n_elms_per_felt=BUCKET_83_N_ELMS_PER_FELT, + ); + static_assert BUCKET_62_N_BITS * BUCKET_62_N_ELMS_PER_FELT == 248; // <= 251 bits. + unpack_felts( + n_elms=bucket_lengths.n_62_bit_elms, + elm_bound=2 ** BUCKET_62_N_BITS, + n_elms_per_felt=BUCKET_62_N_ELMS_PER_FELT, + ); + static_assert BUCKET_31_N_BITS * BUCKET_31_N_ELMS_PER_FELT == 248; // <= 251 bits. + unpack_felts( + n_elms=bucket_lengths.n_31_bit_elms, + elm_bound=2 ** BUCKET_31_N_BITS, + n_elms_per_felt=BUCKET_31_N_ELMS_PER_FELT, + ); + static_assert BUCKET_15_N_BITS * BUCKET_15_N_ELMS_PER_FELT == 240; // <= 251 bits. + unpack_felts( + n_elms=bucket_lengths.n_15_bit_elms, + elm_bound=2 ** BUCKET_15_N_BITS, + n_elms_per_felt=BUCKET_15_N_ELMS_PER_FELT, + ); + } + return (n_unique_values=decompressed_dst - decompressed_dst_start); +} + +// Unpacks the repeating value pointers from `compressed`, and writes the actual +// (repeating) values to `repeating_values_dst`. +func unpack_repeating_values{range_check_ptr, compressed: felt*}( + n_repeating_values: felt, + n_unique_values: felt, + unique_values: felt*, + repeating_values_dst: felt*, +) { + alloc_locals; + let pointers = unpack_repeating_value_pointers( + n_repeating_values=n_repeating_values, n_unique_values=n_unique_values + ); + // Reconstruct the repeating values. + // Note that `unpack_repeating_value_pointers` guarantees that each pointer is in the + // unique_values array range. + copy_indices( + dst=repeating_values_dst, src=unique_values, indices=pointers, len=n_repeating_values + ); + return (); +} + +// Unpacks the repeating value pointers from `compressed`. +// Each pointer points to a value in the unique_values and corresponds to the original data element +// at the same position. +// The function guarantees that: each pointer is in range [0, n_unique_values). 
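+// E.g., a pointer equal to 3 means that the corresponding repeating value is unique_values[3].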
+func unpack_repeating_value_pointers{range_check_ptr, compressed: felt*}(
+ n_repeating_values: felt, n_unique_values: felt
+) -> felt* {
+ alloc_locals;
+ let (local pointers: felt*) = alloc();
+
+ // The pointer bound (unlike the fixed bucket bounds) is dynamically set as the number of
+ // unique values.
+ let pointer_bound = n_unique_values;
+ let n_elms_per_felt = get_n_elms_per_felt(elm_bound=pointer_bound);
+
+ let decompressed_dst = pointers;
+ with decompressed_dst {
+ unpack_felts(
+ n_elms=n_repeating_values, elm_bound=pointer_bound, n_elms_per_felt=n_elms_per_felt
+ );
+ }
+ return pointers;
+}
+
+// Unpacks the uncompressed-data bucket indices from `compressed`.
+// Each index is of the bucket containing the corresponding data element.
+// The function guarantees that: each index is in range [0, TOTAL_N_BUCKETS).
+func unpack_bucket_index_per_elm{range_check_ptr, compressed: felt*}(header: Header*) -> felt* {
+ alloc_locals;
+ let (local bucket_index_per_elm: felt*) = alloc();
+ let n_elms_per_felt = get_n_elms_per_felt(elm_bound=TOTAL_N_BUCKETS);
+
+ let decompressed_dst = bucket_index_per_elm;
+ with decompressed_dst {
+ unpack_felts(
+ n_elms=header.data_len, elm_bound=TOTAL_N_BUCKETS, n_elms_per_felt=n_elms_per_felt
+ );
+ }
+ return bucket_index_per_elm;
+}
+
+// Reconstructs the data into `data_dst`.
+func reconstruct_data{range_check_ptr, data_dst: felt*}(
+ header: Header*, all_values: felt*, bucket_index_per_elm: felt*
+) {
+ alloc_locals;
+ // Calculate the initial offset (in `all_values`) of each bucket.
+ // Unique value buckets.
+ let bucket0_offset = 0;
+ local bucket1_offset = bucket0_offset + header.unique_value_bucket_lengths.n_252_bit_elms;
+ local bucket2_offset = bucket1_offset + header.unique_value_bucket_lengths.n_125_bit_elms;
+ local bucket3_offset = bucket2_offset + header.unique_value_bucket_lengths.n_83_bit_elms;
+ local bucket4_offset = bucket3_offset + header.unique_value_bucket_lengths.n_62_bit_elms;
+ local bucket5_offset = bucket4_offset + header.unique_value_bucket_lengths.n_31_bit_elms;
+ // Repeating values bucket.
+ local bucket6_offset = bucket5_offset + header.unique_value_bucket_lengths.n_15_bit_elms;
+
+ // Create a dictionary from bucket index (0, 1, ..., 6) to the current offset in `all_values`.
+ %{ initial_dict = {bucket_index: 0 for bucket_index in range(ids.TOTAL_N_BUCKETS)} %}
+ let (local dict_ptr_start) = dict_new();
+ let dict_ptr = dict_ptr_start;
+ with dict_ptr {
+ // Initialize the bucket offsets.
+ static_assert TOTAL_N_BUCKETS == 7;
+ dict_write(key=0, new_value=bucket0_offset);
+ dict_write(key=1, new_value=bucket1_offset);
+ dict_write(key=2, new_value=bucket2_offset);
+ dict_write(key=3, new_value=bucket3_offset);
+ dict_write(key=4, new_value=bucket4_offset);
+ dict_write(key=5, new_value=bucket5_offset);
+ dict_write(key=6, new_value=bucket6_offset);
+
+ // Reconstruct the data.
+ reconstruct_data_inner(
+ data_len=header.data_len,
+ all_values=all_values,
+ bucket_index_per_elm=bucket_index_per_elm,
+ );
+
+ // Verify there was no out-of-bound access to the `all_values` array by checking the
+ // bucket offset final values.
+ // E.g., bucket 0 starts at offset 0 and is advanced once per 252-bit element, so its
+ // final offset must equal bucket1_offset.
+ dict_update(key=0, prev_value=bucket1_offset, new_value=bucket1_offset);
+ dict_update(key=1, prev_value=bucket2_offset, new_value=bucket2_offset);
+ dict_update(key=2, prev_value=bucket3_offset, new_value=bucket3_offset);
+ dict_update(key=3, prev_value=bucket4_offset, new_value=bucket4_offset);
+ dict_update(key=4, prev_value=bucket5_offset, new_value=bucket5_offset);
+ dict_update(key=5, prev_value=bucket6_offset, new_value=bucket6_offset);
+ tempvar all_values_len = bucket6_offset + header.n_repeating_values;
+ dict_update(key=6, prev_value=all_values_len, new_value=all_values_len);
+ }
+ // Verify the dict reads by squashing the updates.
+ // Note that there is no need to verify the initial values:
+ // the dict keys are contained in [0, 1, ... TOTAL_N_BUCKETS - 1] since
+ // `unpack_bucket_index_per_elm` guarantees that each bucket index is in this range, and
+ // they were all set explicitly above.
+ dict_squash(dict_accesses_start=dict_ptr_start, dict_accesses_end=dict_ptr);
+ return ();
+}
+
+// A helper for `reconstruct_data`.
+// The given dict_ptr holds the bucket offsets.
+func reconstruct_data_inner{dict_ptr: DictAccess*, data_dst: felt*}(
+ data_len: felt, all_values: felt*, bucket_index_per_elm: felt*
+) {
+ if (data_len == 0) {
+ return ();
+ }
+
+ let bucket_index = bucket_index_per_elm[0];
+ // Guess the offset to the all_values array - it is validated by the `dict_update` below.
+ tempvar prev_offset;
+ %{
+ dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
+ ids.prev_offset = dict_tracker.data[ids.bucket_index]
+ %}
+
+ // Advance the bucket offset.
+ dict_update(key=bucket_index, prev_value=prev_offset, new_value=prev_offset + 1);
+
+ assert data_dst[0] = all_values[prev_offset];
+ let data_dst = &data_dst[1];
+
+ return reconstruct_data_inner(
+ data_len=data_len - 1, all_values=all_values, bucket_index_per_elm=&bucket_index_per_elm[1]
+ );
+}
+
+// Returns the number of elements (smaller than `elm_bound`) that can be packed in one felt.
+// The result will satisfy: max(log2(elm_bound), 1) * n_elms_per_felt <= 251.
+//
+// Note: this calculation may return a sub-optimal result when `elm_bound` is not a power of two:
+// it returns: 251 // max(log2_ceil(elm_bound), 1).
+func get_n_elms_per_felt{range_check_ptr}(elm_bound: felt) -> felt {
+ alloc_locals;
+ // If elm_bound is 0 or 1, return MAX_N_BITS_PER_FELT.
+ if (elm_bound * (elm_bound - 1) == 0) {
+ return MAX_N_BITS_PER_FELT;
+ }
+ let n_bits_per_elm = log2_ceil(value=elm_bound);
+ let (n_elms_per_felt, _) = unsigned_div_rem(value=MAX_N_BITS_PER_FELT, div=n_bits_per_elm);
+ return n_elms_per_felt;
+}
+
+// Unpacks an array of `n_elms` from `compressed`,
+// packed in `ceil(n_elms/n_elms_per_felt)` felts, into `decompressed_dst`.
+//
+// Assumptions:
+// - elm_bound is in range [0, 2**128).
+// - elm_bound ** n_elms_per_felt <= 2**251.
+func unpack_felts{range_check_ptr, compressed: felt*, decompressed_dst: felt*}(
+ n_elms: felt, elm_bound: felt, n_elms_per_felt: felt
+) {
+ alloc_locals;
+ let (n_full_felts, local n_remaining_elms) = unsigned_div_rem(
+ value=n_elms, div=n_elms_per_felt
+ );
+ unpack_felts_given_n_packed_felts(
+ n_packed_felts=n_full_felts, elm_bound=elm_bound, n_elms_per_felt=n_elms_per_felt
+ );
+ if (n_remaining_elms != 0) {
+ unpack_felts_given_n_packed_felts(
+ n_packed_felts=1, elm_bound=elm_bound, n_elms_per_felt=n_remaining_elms
+ );
+ return ();
+ }
+ return ();
+}
+
+// Unpacks `n_packed_felts` from `compressed` into `decompressed_dst`,
+// where each packed felt contains `n_elms_per_felt` elements.
+// Assumptions: see `unpack_felts`. +func unpack_felts_given_n_packed_felts{range_check_ptr, compressed: felt*, decompressed_dst: felt*}( + n_packed_felts: felt, elm_bound: felt, n_elms_per_felt: felt +) { + if (n_packed_felts == 0) { + return (); + } + + unpack_felt( + packed_felt=compressed[0], + elm_bound=elm_bound, + n_elms=n_elms_per_felt, + decompressed_dst=decompressed_dst, + ); + let compressed = &compressed[1]; + let decompressed_dst = &decompressed_dst[n_elms_per_felt]; + return unpack_felts_given_n_packed_felts( + n_packed_felts=n_packed_felts - 1, elm_bound=elm_bound, n_elms_per_felt=n_elms_per_felt + ); +} + +// Unpacks `n_elms` from the given felt into `decompressed_dst`. +// The first element is at the least significant bits. +// The function guarantees that: packed_felt < elm_bound ** n_elms. +// Assumptions: see `unpack_felts`. +func unpack_felt{range_check_ptr}( + packed_felt: felt, elm_bound: felt, n_elms: felt, decompressed_dst: felt* +) { + if (n_elms == 0) { + // Verify that there are no more elements to unpack. + // This check also ensures that the initial `packed_felt` is equal to + // current0 + current1 * bound + current2 * bound**2 + ... + current(n-1) * bound**(n-1). + assert packed_felt = 0; + return (); + } + + %{ memory[ids.decompressed_dst] = ids.packed_felt % ids.elm_bound %} + tempvar current = decompressed_dst[0]; + + // Verify element is in range [0, elm_bound). + assert [range_check_ptr] = current; + assert [range_check_ptr + 1] = elm_bound - current - 1; + let range_check_ptr = range_check_ptr + 2; + + let packed_suffix = (packed_felt - current) / elm_bound; + return unpack_felt( + packed_felt=packed_suffix, + elm_bound=elm_bound, + n_elms=n_elms - 1, + decompressed_dst=&decompressed_dst[1], + ); +} diff --git a/src/starkware/starknet/core/os/data_availability/compression.py b/src/starkware/starknet/core/os/data_availability/compression.py new file mode 100644 index 00000000..334043f3 --- /dev/null +++ b/src/starkware/starknet/core/os/data_availability/compression.py @@ -0,0 +1,288 @@ +import itertools +from itertools import count +from typing import Dict, Iterator, List, Tuple + +from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME +from starkware.python.math_utils import log2_ceil +from starkware.python.utils import div_ceil, iter_blockify, safe_zip + +COMPRESSION_VERSION = 0 + +# Max number of bits that can be packed in a single felt. +MAX_N_BITS = 251 + +# Number of bits encoding each element (per bucket). +N_BITS_PER_BUCKET = [252, 125, 83, 62, 31, 15] +N_UNIQUE_VALUE_BUCKETS = len(N_BITS_PER_BUCKET) +# Number of buckets, including the repeating values bucket. +TOTAL_N_BUCKETS = N_UNIQUE_VALUE_BUCKETS + 1 + +# Version, data length, bucket lengths. +HEADER_LEN = 1 + 1 + TOTAL_N_BUCKETS + +HEADER_ELM_N_BITS = 20 +HEADER_ELM_BOUND = 2**HEADER_ELM_N_BITS + + +class UniqueValueBucket: + """ + A set-like data structure that preserves the insertion order. + Holds values of `n_bits` bit length or less. + """ + + def __init__(self, n_bits: int): + self.n_bits = n_bits + # A mapping from value to its insertion order. 
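+        # The insertion index doubles as the value's position inside the packed bucket.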
+        self._value_to_index: Dict[int, int] = {}
+
+    def __contains__(self, value: int) -> bool:
+        return value in self._value_to_index
+
+    def __len__(self) -> int:
+        return len(self._value_to_index)
+
+    def add(self, value: int):
+        if value in self:
+            return
+        next_index = len(self._value_to_index)
+        self._value_to_index[value] = next_index
+
+    def get_index(self, value: int) -> int:
+        return self._value_to_index[value]
+
+    def pack_in_felts(self) -> List[int]:
+        # The values should be sorted by the insertion order, but this is given for free since
+        # Python dict preserves insertion order.
+        values = list(self._value_to_index.keys())
+        return pack_in_felts(elms=values, elm_bound=2**self.n_bits)
+
+
+class CompressionSet:
+    """
+    A utility class for compression.
+    Used to manage and store the unique values in separate buckets according to their bit length.
+    """
+
+    def __init__(self, n_bits_per_bucket: List[int]):
+        self._buckets = [UniqueValueBucket(n_bits=n_bits) for n_bits in n_bits_per_bucket]
+        # Index by the given order.
+        indexed_buckets = list(enumerate(self._buckets))
+        # Sort by the number of bits (cached for the `update` function).
+        self._sorted_buckets = sorted(
+            indexed_buckets, key=lambda indexed_bucket: indexed_bucket[1].n_bits
+        )
+
+        # A special bucket that holds locations of the unique values in the buckets, in the
+        # following form: (bucket_index, index_in_bucket).
+        # Each corresponds to a repeating value and is the location of the first (unique) copy.
+        self._repeating_value_locations: List[Tuple[int, int]] = []
+
+        # Maps each item to the bucket it was assigned, including the repeating values bucket.
+        self._bucket_index_per_elm: List[int] = []
+
+        self.finalized = False
+
+    @property
+    def repeating_values_bucket_index(self) -> int:
+        return len(self._buckets)
+
+    def update(self, values: List[int]):
+        assert not self.finalized, "Cannot add values after finalizing."
+        for value in values:
+            for bucket_index, bucket in self._sorted_buckets:
+                if value.bit_length() <= bucket.n_bits:
+                    if value in bucket:
+                        # Repeated value; add the location of the first added copy.
+                        self._repeating_value_locations.append(
+                            (bucket_index, bucket.get_index(value))
+                        )
+                        self._bucket_index_per_elm.append(self.repeating_values_bucket_index)
+                    else:
+                        # First appearance of this value.
+                        bucket.add(value)
+                        self._bucket_index_per_elm.append(bucket_index)
+                    break
+            else:
+                raise ValueError(f"{value} is too large.")
+
+    def get_unique_value_bucket_lengths(self) -> List[int]:
+        return [len(bucket) for bucket in self._buckets]
+
+    def get_repeating_value_bucket_length(self) -> int:
+        return len(self._repeating_value_locations)
+
+    def get_repeating_value_pointers(self) -> List[int]:
+        """
+        Returns a list of pointers corresponding to the repeating values.
+        The pointers point to the chained unique value buckets.
+        """
+        assert self.finalized, "Cannot get pointers before finalizing."
+        unique_value_bucket_lengths = self.get_unique_value_bucket_lengths()
+        bucket_offsets = get_bucket_offsets(bucket_lengths=unique_value_bucket_lengths)
+        return [
+            bucket_offsets[bucket_index] + index_in_bucket
+            for bucket_index, index_in_bucket in self._repeating_value_locations
+        ]
+
+    def get_bucket_index_per_elm(self) -> List[int]:
+        """
+        Returns the bucket indices of the added values.
+        """
+        assert self.finalized, "Cannot get bucket_index_per_elm before finalizing."
+        return self._bucket_index_per_elm
+
+    def pack_unique_values(self) -> List[int]:
+        """
+        Packs the unique value buckets and chains them.
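+        Returns the packed felts of bucket 0, then bucket 1, and so on, as a single list.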
+ """ + assert self.finalized, "Cannot pack before finalizing." + return list(itertools.chain(*(bucket.pack_in_felts() for bucket in self._buckets))) + + def finalize(self): + self.finalized = True + + +def compress(data: List[int]) -> List[int]: + """ + Compresses the given data. + The result is a list of felts. + """ + assert len(data) < HEADER_ELM_BOUND, "Data is too long." + compression_set = CompressionSet(n_bits_per_bucket=N_BITS_PER_BUCKET) + compression_set.update(data) + compression_set.finalize() + + bucket_index_per_elm = compression_set.get_bucket_index_per_elm() + + unique_value_bucket_lengths = compression_set.get_unique_value_bucket_lengths() + n_unique_values = sum(unique_value_bucket_lengths) + header = [ + COMPRESSION_VERSION, + len(data), + *unique_value_bucket_lengths, + compression_set.get_repeating_value_bucket_length(), + ] + packed_header = pack_in_felt(elms=header, elm_bound=HEADER_ELM_BOUND) + packed_repeating_value_pointers = pack_in_felts( + elms=compression_set.get_repeating_value_pointers(), elm_bound=n_unique_values + ) + packed_bucket_index_per_elm = pack_in_felts( + elms=bucket_index_per_elm, elm_bound=TOTAL_N_BUCKETS + ) + + return [ + packed_header, + *compression_set.pack_unique_values(), + *packed_repeating_value_pointers, + *packed_bucket_index_per_elm, + ] + + +def decompress(compressed: Iterator[int]) -> List[int]: + """ + Decompresses the given compressed data. + """ + + def unpack_chunk(n_elms: int, elm_bound: int) -> List[int]: + n_packed_felts = div_ceil(n_elms, get_n_elms_per_felt(elm_bound)) + compressed_chunk = list(itertools.islice(compressed, n_packed_felts)) + return unpack_felts(compressed=compressed_chunk, elm_bound=elm_bound, n_elms=n_elms) + + header = unpack_chunk(n_elms=HEADER_LEN, elm_bound=HEADER_ELM_BOUND) + # Unpack header. + version = header[0] + assert version == COMPRESSION_VERSION, f"Unsupported compression version {version}." + data_len = header[1] + unique_value_bucket_lengths = header[2 : 2 + N_UNIQUE_VALUE_BUCKETS] + (n_repeating_values,) = header[2 + N_UNIQUE_VALUE_BUCKETS :] + + # Unpack buckets: unique values and repeating values. + unique_values = list( + itertools.chain( + *( + unpack_chunk(n_elms=bucket_length, elm_bound=2**n_bits) + for bucket_length, n_bits in safe_zip( + unique_value_bucket_lengths, N_BITS_PER_BUCKET + ) + ) + ) + ) + repeating_value_pointers = unpack_chunk(n_elms=n_repeating_values, elm_bound=len(unique_values)) + repeating_values = [unique_values[i] for i in repeating_value_pointers] + all_values = unique_values + repeating_values + + # Unpack the bucket indices. + bucket_index_per_elm = unpack_chunk(n_elms=data_len, elm_bound=TOTAL_N_BUCKETS) + + # Get the starting position of each bucket. + all_bucket_lengths = [*unique_value_bucket_lengths, n_repeating_values] + bucket_offsets = get_bucket_offsets(bucket_lengths=all_bucket_lengths) + + # Reconstruct the data. + bucket_offset_iterators = [count(start=offset) for offset in bucket_offsets] + return [ + all_values[next(bucket_offset_iterators[bucket_index])] + for bucket_index in bucket_index_per_elm + ] + + +def pack_in_felts(elms: List[int], elm_bound: int) -> List[int]: + """ + Packs the given elements in multiple felts. + """ + assert all(elm < elm_bound for elm in elms), "Element out of bound." 
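+    # For example (illustrative): elms=[3, 1, 2] with elm_bound=4 fit in a single felt,
+    # packed little-endian as 3 + 1 * 4 + 2 * 4**2 = 39.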
+ return [ + pack_in_felt(elms=chunk, elm_bound=elm_bound) + for chunk in iter_blockify(elms, chunk_size=get_n_elms_per_felt(elm_bound)) + ] + + +def unpack_felts(compressed: List[int], elm_bound: int, n_elms: int) -> List[int]: + """ + Unpacks the given packed felts into an array of `n_elms` elements. + """ + n_elms_per_felt = get_n_elms_per_felt(elm_bound) + res = itertools.chain( + *( + unpack_felt(packed_felt=packed_felt, elm_bound=elm_bound, n_elms=n_elms_per_felt) + for packed_felt in compressed + ) + ) + # Remove trailing zeros. + return list(res)[:n_elms] + + +def pack_in_felt(elms: List[int], elm_bound: int) -> int: + """ + Packs the given elements in a single felt. + The first element is at the least significant bits. + """ + res = sum(elm * (elm_bound**i) for i, elm in enumerate(elms)) + assert res < DEFAULT_PRIME, "Out of bound packing." + return res + + +def unpack_felt(packed_felt: int, elm_bound: int, n_elms: int) -> List[int]: + res = [] + for _ in range(n_elms): + packed_felt, current = divmod(packed_felt, elm_bound) + res.append(current) + + assert packed_felt == 0 + return res + + +def get_n_elms_per_felt(elm_bound: int) -> int: + if elm_bound <= 1: + return MAX_N_BITS + if elm_bound > 2**MAX_N_BITS: + return 1 + + return MAX_N_BITS // log2_ceil(elm_bound) + + +def get_bucket_offsets(bucket_lengths: List[int]) -> List[int]: + """ + Returns the starting position of each bucket given their lengths. + """ + return [sum(bucket_lengths[:i]) for i in range(len(bucket_lengths))] diff --git a/src/starkware/starknet/core/os/output.cairo b/src/starkware/starknet/core/os/output.cairo index 2a399685..4427700f 100644 --- a/src/starkware/starknet/core/os/output.cairo +++ b/src/starkware/starknet/core/os/output.cairo @@ -8,6 +8,7 @@ from starkware.starknet.core.os.data_availability.commitment import ( Uint384, compute_os_kzg_commitment_info, ) +from starkware.starknet.core.os.data_availability.compression import compress from starkware.starknet.core.os.state.commitment import CommitmentUpdate from starkware.starknet.core.os.state.output import ( output_contract_class_da_changes, @@ -74,12 +75,15 @@ func serialize_os_output{range_check_ptr, poseidon_ptr: PoseidonBuiltin*, output local use_kzg_da = os_output.header.use_kzg_da; local full_output = os_output.header.full_output; + let compress_state_updates = 1 - full_output; // Compute the data availability segment. local state_updates_start: felt*; let state_updates_ptr = state_updates_start; %{ - if ids.use_kzg_da: + # `use_kzg_da` is used in a hint in `process_data_availability`. + use_kzg_da = ids.use_kzg_da + if use_kzg_da or ids.compress_state_updates: ids.state_updates_start = segments.add() else: # Assign a temporary segment, to be relocated into the output segment. 
             ids.state_updates_start = segments.add_temp_segment()
     %}
@@ -104,9 +108,15 @@ func serialize_os_output{range_check_ptr, poseidon_ptr: PoseidonBuiltin*, output serialize_output_header(os_output_header=os_output.header); + let (local da_start, local da_end) = process_data_availability( + state_updates_start=state_updates_start, + state_updates_end=state_updates_ptr, + compress_state_updates=compress_state_updates, + ); + if (use_kzg_da != 0) { let os_kzg_commitment_info = compute_os_kzg_commitment_info( - state_updates_start=state_updates_start, state_updates_end=state_updates_ptr + state_updates_start=da_start, state_updates_end=da_end ); serialize_os_kzg_commitment_info(os_kzg_commitment_info=os_kzg_commitment_info); tempvar poseidon_ptr = poseidon_ptr; @@ -126,9 +136,7 @@ func serialize_os_output{range_check_ptr, poseidon_ptr: PoseidonBuiltin*, output ); if (use_kzg_da == 0) { - serialize_data_availability( - state_updates_start=state_updates_start, state_updates_end=state_updates_ptr - ); + serialize_data_availability(da_start=da_start, da_end=da_end); } return (); @@ -211,14 +219,36 @@ func serialize_os_kzg_commitment_info{output_ptr: felt*}( return (); } -func serialize_data_availability{output_ptr: felt*}( - state_updates_start: felt*, state_updates_end: felt* -) { - let da_start = output_ptr; +// Returns the final data-availability to output. +func process_data_availability{range_check_ptr}( + state_updates_start: felt*, state_updates_end: felt*, compress_state_updates: felt +) -> (da_start: felt*, da_end: felt*) { + if (compress_state_updates == 0) { + return (da_start=state_updates_start, da_end=state_updates_end); + } + + alloc_locals; + + // Output a compression of the state updates. + local compressed_start: felt*; + %{ + if use_kzg_da: + ids.compressed_start = segments.add() + else: + # Assign a temporary segment, to be relocated into the output segment. + ids.compressed_start = segments.add_temp_segment() + %} + let compressed_dst = compressed_start; + with compressed_dst { + compress(data_start=state_updates_start, data_end=state_updates_end); + } + return (da_start=compressed_start, da_end=compressed_dst); +} - // Relocate 'state_updates_segment' to the correct place in the output segment. - relocate_segment(src_ptr=state_updates_start, dest_ptr=output_ptr); - let output_ptr = state_updates_end; +func serialize_data_availability{output_ptr: felt*}(da_start: felt*, da_end: felt*) { + // Relocate data availability segment to the correct place in the output segment. + relocate_segment(src_ptr=da_start, dest_ptr=output_ptr); + let output_ptr = da_end; %{ from starkware.python.math_utils import div_ceil diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index 8c3a1dbc..4b29234a 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x1e324682835e60c4779a683b32713504aed894fd73842f7d05b18e7bd29cd70" + "program_hash": "0x54d3603ed14fb897d0925c48f26330ea9950bd4ca95746dad4f7f09febffe0d" } diff --git a/src/starkware/starknet/core/os/state/output.cairo b/src/starkware/starknet/core/os/state/output.cairo index f1df5cb3..af4e31b0 100644 --- a/src/starkware/starknet/core/os/state/output.cairo +++ b/src/starkware/starknet/core/os/state/output.cairo @@ -11,18 +11,24 @@ from starkware.starknet.core.os.state.commitment import StateEntry // * The contract address (1 word). 
 // * 1 word with the following info:
 //   * A flag indicating whether the class hash was updated,
-//   * the number of entry updates,
+//   * A flag indicating whether the number of updates is small (< 256),
+//   * the number of entry updates (packed according to the previous flag),
+//   * the new nonce (if `full_output` is used or if it was updated),
 //   * the old nonce (if `full_output` is used),
-//   * and the new nonce:
-//     +-------+-----------+-----------+ LSB
-//     | flag  | new nonce | n_updates |
-//     | 1 bit | 64 bits   | 64 bits   |
-//     +-------+-----------+-----------+
+//     +--------------+----------------+------------+ LSB
+//     | n_updates    | n_updates_flag | class_flag |
+//     | 8 or 64 bits | 1 bit          | 1 bit      |
+//     +--------------+----------------+------------+
+//     OR (if the nonce was updated)
+//     +-----------+--------------+----------------+------------+ LSB
+//     | new_nonce | n_updates    | n_updates_flag | class_flag |
+//     | 64 bits   | 8 or 64 bits | 1 bit          | 1 bit      |
+//     +-----------+--------------+----------------+------------+
 //     OR (if `full_output` is used)
-//     +-------+-----------+-----------+-----------+ LSB
-//     | flag  | new nonce | old nonce | n_updates |
-//     | 1 bit | 64 bits   | 64 bits   | 64 bits   |
-//     +-------+-----------+-----------+-----------+
+//     +-----------+-----------+--------------+----------------+------------+ LSB
+//     | old_nonce | new_nonce | n_updates    | n_updates_flag | class_flag |
+//     | 64 bits   | 64 bits   | 8 or 64 bits | 1 bit          | 1 bit      |
+//     +-----------+-----------+--------------+----------------+------------+
 //
 // * The old class hash for this contract (1 word, if `full_output` is used).
 // * The new class hash for this contract (1 word, if it was updated or `full_output` is used).
@@ -40,6 +46,8 @@ from starkware.starknet.core.os.state.commitment import StateEntry
 
 // A bound on the number of contract state entry updates in a contract.
 const N_UPDATES_BOUND = 2 ** 64;
+// A number of updates lower than this bound is packed more efficiently in the header.
+const N_UPDATES_SMALL_PACKING_BOUND = 2 ** 8;
 
 // A bound on the nonce of a contract.
 const NONCE_BOUND = 2 ** 64;
@@ -139,28 +147,28 @@ func output_contract_state{range_check_ptr, state_updates_ptr: felt*}(
 ) {
     alloc_locals;
 
-    // Make room for number of state updates.
-    let output_n_updates = [state_updates_ptr];
+    // Make room for number of modified contracts.
+    let output_n_modified_contracts = [state_updates_ptr];
     let state_updates_ptr = state_updates_ptr + 1;
-    let n_actual_state_changes = 0;
+    let n_modified_contracts = 0;
 
-    with n_actual_state_changes {
+    with n_modified_contracts {
         output_contract_state_inner(
             n_contract_state_changes=n_contract_state_changes,
             state_changes=contract_state_changes_start,
             full_output=full_output,
         );
     }
-    // Write number of state updates.
-    assert output_n_updates = n_actual_state_changes;
+    // Write number of modified contracts.
+    assert output_n_modified_contracts = n_modified_contracts;
 
     return ();
 }
 
 // Helper function for `output_contract_state()`.
 //
-// Increases `n_actual_state_changes` by the number of contracts with state changes.
+// Increases `n_modified_contracts` by the number of contracts with state changes.
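+// For example (illustrative): with `full_output` off, a nonce that changed to 7, 5 entry
+// updates (small packing) and an updated class hash, the packed info word is
+// ((7 * 256 + 5) * 2 + 1) * 2 + 1 = 7191.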
+func output_contract_state_inner{range_check_ptr, state_updates_ptr: felt*, n_modified_contracts}( n_contract_state_changes: felt, state_changes: DictAccess*, full_output: felt ) { if (n_contract_state_changes == 0) { @@ -221,27 +229,45 @@ func output_contract_state_inner{range_check_ptr, state_updates_ptr: felt*, n_ac assert contract_header[0] = state_changes.key; // Write the second word of the header. - // Write 'was class update' flag. - let value = was_class_updated; - // Write the new nonce. + // Handle the nonce. assert_nn_le(new_state_nonce, NONCE_BOUND - 1); - let value = value * NONCE_BOUND + new_state_nonce; - // Write the old nonce (if `full_output` is used). if (full_output == 0) { - tempvar value = value; + if (prev_state_nonce != new_state_nonce) { + tempvar value = new_state_nonce; + } else { + tempvar value = 0; + } tempvar range_check_ptr = range_check_ptr; } else { + // Full output - write the new and old nonces. assert_nn_le(prev_state_nonce, NONCE_BOUND - 1); - tempvar value = value * NONCE_BOUND + prev_state_nonce; + tempvar value = prev_state_nonce * NONCE_BOUND + new_state_nonce; tempvar range_check_ptr = range_check_ptr; } + // Write the number of updates. - assert_nn_le(n_actual_updates, N_UPDATES_BOUND - 1); - let value = value * N_UPDATES_BOUND + n_actual_updates; + local is_n_updates_small; + %{ ids.is_n_updates_small = ids.n_actual_updates < ids.N_UPDATES_SMALL_PACKING_BOUND %} + // Verify that the guessed value is 0 or 1. + assert is_n_updates_small * is_n_updates_small = is_n_updates_small; + if (is_n_updates_small != 0) { + tempvar n_updates_bound = N_UPDATES_SMALL_PACKING_BOUND; + } else { + tempvar n_updates_bound = N_UPDATES_BOUND; + } + assert_nn_le(n_actual_updates, n_updates_bound - 1); + let value = value * n_updates_bound + n_actual_updates; + + // Write 'is_n_updates_small' flag. + let value = value * 2 + is_n_updates_small; + + // Write 'was class updated' flag. 
+ let value = value * 2 + was_class_updated; + assert contract_header[1] = value; let state_updates_ptr = cast(storage_diff, felt*); - let n_actual_state_changes = n_actual_state_changes + 1; + let n_modified_contracts = n_modified_contracts + 1; return output_contract_state_inner( n_contract_state_changes=n_contract_state_changes - 1, diff --git a/src/starkware/starknet/definitions/BUILD b/src/starkware/starknet/definitions/BUILD index 4dc9002e..e8734136 100644 --- a/src/starkware/starknet/definitions/BUILD +++ b/src/starkware/starknet/definitions/BUILD @@ -41,6 +41,7 @@ py_library( srcs = [ "chain_ids.py", "general_config.py", + "overridable_versioned_constants.py", ], data = [ "general_config.yml", @@ -53,7 +54,6 @@ py_library( "//src/starkware/python:starkware_python_utils_lib", "//src/starkware/starkware_utils:starkware_config_utils_lib", "//src/starkware/starkware_utils:starkware_dataclasses_utils_lib", - requirement("marshmallow"), requirement("marshmallow_dataclass"), ], ) diff --git a/src/starkware/starknet/definitions/chain_ids.py b/src/starkware/starknet/definitions/chain_ids.py index c5b5678d..3f413bed 100644 --- a/src/starkware/starknet/definitions/chain_ids.py +++ b/src/starkware/starknet/definitions/chain_ids.py @@ -1,7 +1,10 @@ from enum import Enum -from typing import Set +from typing import Dict, Set from starkware.python.utils import from_bytes +from starkware.starknet.definitions.overridable_versioned_constants import ( + OverridableVersionedConstants, +) FEE_TOKEN_ADDRESS = 0x04718F5A0FC34CC1AF16A1CDEE98FFB20C31F5CD61D6AB07201858F4287C938D DEPRECATED_FEE_TOKEN_ADDRESS = 0x49D36570D4E46F48E99674BD3FCC84644DDD6B96F7C741B1562B82F9E004DC7 @@ -25,3 +28,5 @@ def is_private(self) -> bool: CHAIN_ID_TO_DEPRECATED_FEE_TOKEN_ADDRESS = { chain_enum: DEPRECATED_FEE_TOKEN_ADDRESS for chain_enum in StarknetChainId } + +CHAIN_ID_TO_PRIVATE_VERSIONED_CONSTANTS: Dict[StarknetChainId, OverridableVersionedConstants] = {} diff --git a/src/starkware/starknet/definitions/constants.py b/src/starkware/starknet/definitions/constants.py index e127d159..f83c5c9e 100644 --- a/src/starkware/starknet/definitions/constants.py +++ b/src/starkware/starknet/definitions/constants.py @@ -110,6 +110,8 @@ L1_TO_L2_MSG_HEADER_SIZE = 5 L2_TO_L1_MSG_HEADER_SIZE = 3 CLASS_UPDATE_SIZE = 1 +# Header, unique values (at least one felt), pointers (at least one felt). +COMPRESSED_DA_SEGMENT_MIN_LENGTH = 3 # OS reserved contract addresses. 
ORIGIN_ADDRESS = 0 diff --git a/src/starkware/starknet/definitions/error_codes.py b/src/starkware/starknet/definitions/error_codes.py index 371bf5e6..7b3ed02c 100644 --- a/src/starkware/starknet/definitions/error_codes.py +++ b/src/starkware/starknet/definitions/error_codes.py @@ -78,6 +78,7 @@ class StarknetErrorCode(ErrorCode): TRANSACTION_LIMIT_EXCEEDED = auto() TRANSACTION_NOT_FOUND = auto() UNAUTHORIZED_ACTION_ON_VALIDATE = auto() + UNAUTHORIZED_DECLARE = auto() UNAUTHORIZED_ENTRY_POINT_FOR_INVOKE = auto() UNDECLARED_CLASS = auto() UNEXPECTED_FAILURE = auto() diff --git a/src/starkware/starknet/definitions/fields.py b/src/starkware/starknet/definitions/fields.py index 73900b33..a2ac8ea5 100644 --- a/src/starkware/starknet/definitions/fields.py +++ b/src/starkware/starknet/definitions/fields.py @@ -30,6 +30,7 @@ BytesAsHex, EnumField, FrozenDictField, + StrictOptionalInteger, StrictRequiredInteger, VariadicLengthTupleField, ) @@ -530,7 +531,7 @@ def validate_resource_bounds(resource_bounds: ResourceBoundsMapping) -> bool: ) -optional_state_diff_hash_metadata = dict( +optional_state_diff_key_metadata = dict( marshmallow_field=BytesAsHex(required=False, load_default=None) ) @@ -571,12 +572,10 @@ def validate_resource_bounds(resource_bounds: ResourceBoundsMapping) -> bool: # General config. -invoke_tx_n_steps_metadata = dict( - marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("invoke_tx_n_steps")) -) - -validate_n_steps_metadata = dict( - marshmallow_field=StrictRequiredInteger(validate=validate_non_negative("validate_n_steps")) +validate_max_n_steps_override_metadata = dict( + marshmallow_field=StrictOptionalInteger( + validate=validate_non_negative("validate_max_n_steps_override", allow_none=True) + ) ) gas_price = dict( diff --git a/src/starkware/starknet/definitions/general_config.py b/src/starkware/starknet/definitions/general_config.py index cc8fa344..b8ea5417 100644 --- a/src/starkware/starknet/definitions/general_config.py +++ b/src/starkware/starknet/definitions/general_config.py @@ -1,17 +1,21 @@ -import json import os from dataclasses import field from typing import Optional -import marshmallow.fields as mfields import marshmallow_dataclass from services.everest.definitions.general_config import EverestGeneralConfig from starkware.cairo.lang.instances import dynamic_instance -from starkware.python.utils import from_bytes, get_build_dir_path +from starkware.python.utils import from_bytes from starkware.starknet.definitions import fields -from starkware.starknet.definitions.chain_ids import StarknetChainId +from starkware.starknet.definitions.chain_ids import ( + CHAIN_ID_TO_PRIVATE_VERSIONED_CONSTANTS, + StarknetChainId, +) from starkware.starknet.definitions.constants import VERSIONED_CONSTANTS +from starkware.starknet.definitions.overridable_versioned_constants import ( + OverridableVersionedConstants, +) from starkware.starkware_utils.config_base import Config, load_config from starkware.starkware_utils.marshmallow_dataclass_fields import ( RequiredBoolean, @@ -19,7 +23,6 @@ load_int_value, ) -PRIVATE_VERSIONED_CONSTANTS_DIR = "src/starkware/starknet/definitions/private_versioned_constants" GENERAL_CONFIG_FILE_NAME = "general_config.yml" DOCKER_GENERAL_CONFIG_PATH = os.path.join("/", GENERAL_CONFIG_FILE_NAME) GENERAL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), GENERAL_CONFIG_FILE_NAME) @@ -65,14 +68,14 @@ DEFAULT_USE_KZG_DA = True # Given in units of wei. 
-DEFAULT_DEPRECATED_L1_GAS_PRICE = 10**11 -DEFAULT_DEPRECATED_L1_DATA_GAS_PRICE = 10**5 +DEFAULT_DEPRECATED_L1_GAS_PRICE = 10**9 +DEFAULT_DEPRECATED_L1_DATA_GAS_PRICE = 1 DEFAULT_ETH_IN_FRI = 10**21 DEFAULT_MIN_FRI_L1_GAS_PRICE = 10**6 -DEFAULT_MAX_FRI_L1_GAS_PRICE = 10**18 -DEFAULT_MIN_FRI_L1_DATA_GAS_PRICE = 10**5 -DEFAULT_MAX_FRI_L1_DATA_GAS_PRICE = 10**17 +DEFAULT_MAX_FRI_L1_GAS_PRICE = 10**21 +DEFAULT_MIN_FRI_L1_DATA_GAS_PRICE = 1 +DEFAULT_MAX_FRI_L1_DATA_GAS_PRICE = 10**21 # Configuration schema definition. @@ -130,28 +133,12 @@ class StarknetGeneralConfig(EverestGeneralConfig): gas_price_bounds: GasPriceBounds = field(default_factory=GasPriceBounds) - invoke_tx_max_n_steps: int = field( - metadata=fields.invoke_tx_n_steps_metadata, - default=VERSIONED_CONSTANTS.invoke_tx_max_n_steps, - ) - # IMPORTANT: when editing this in production, make sure to only decrease the value. # Increasing it in production may cause issue to nodes during execution, so only increase it # during a new release. - validate_max_n_steps: int = field( - metadata=fields.validate_n_steps_metadata, default=VERSIONED_CONSTANTS.validate_max_n_steps - ) - - private_versioned_constants_path: Optional[str] = field( - metadata=additional_metadata( - marshmallow_field=mfields.String(allow_none=True), - description=( - "If not None, overrides the default versioned constants. Must be a name of a valid " - "versioned constants file, in the blockifier format. The file must be located in " - "the private versioned constants directory." - ), - ), - default=None, + # This value should not be used directly, Use `get_validate_max_n_steps`. + validate_max_n_steps_override: Optional[int] = field( + metadata=fields.validate_max_n_steps_override_metadata, default=None ) # The default price of one ETH (10**18 Wei) in STRK units. Used in case of oracle failure. @@ -209,33 +196,24 @@ def min_fri_l1_data_gas_price(self) -> int: def max_fri_l1_data_gas_price(self) -> int: return self.gas_price_bounds.max_fri_l1_data_gas_price - def get_optional_private_versioned_constants( - self, - ) -> Optional[str]: - """ - Returns the private versioned constants file's contents, if a filename is configured. - No need to parse the contents, as the file is already in the blockifier format. - """ - path = self.private_versioned_constants_path - if path is None: - return None - - # Assert the file exists. - relative_path = os.path.join(PRIVATE_VERSIONED_CONSTANTS_DIR, path) - absolute_path = get_build_dir_path(rel_path=relative_path) - assert os.path.isfile( - absolute_path - ), f"Invalid path to private versioned constants. {absolute_path=}." - - with open(absolute_path) as f: - private_versioned_constants = f.read() - - # Assert the string is in valid JSON format. - try: - json.loads(private_versioned_constants) - except json.JSONDecodeError as e: - raise ValueError( - f"private versioned constants file is not a valid JSON. file path: {absolute_path}." 
- ) from e - - return private_versioned_constants + def get_private_versioned_constants(self) -> Optional[OverridableVersionedConstants]: + return CHAIN_ID_TO_PRIVATE_VERSIONED_CONSTANTS.get(self.chain_id) + + def get_validate_max_n_steps(self) -> int: + if self.validate_max_n_steps_override is not None: + return self.validate_max_n_steps_override + + private_versioned_constants = self.get_private_versioned_constants() + if private_versioned_constants is not None: + if private_versioned_constants.validate_max_n_steps is not None: + return private_versioned_constants.validate_max_n_steps + + return VERSIONED_CONSTANTS.validate_max_n_steps + + def get_invoke_tx_max_n_steps(self) -> int: + private_versioned_constants = self.get_private_versioned_constants() + if private_versioned_constants is not None: + if private_versioned_constants.invoke_tx_max_n_steps is not None: + return private_versioned_constants.invoke_tx_max_n_steps + + return VERSIONED_CONSTANTS.invoke_tx_max_n_steps diff --git a/src/starkware/starknet/definitions/general_config.yml b/src/starkware/starknet/definitions/general_config.yml index 8d94f060..4ecc624f 100644 --- a/src/starkware/starknet/definitions/general_config.yml +++ b/src/starkware/starknet/definitions/general_config.yml @@ -7,10 +7,9 @@ gas_price_bounds: min_fri_l1_gas_price: 30000000000 min_wei_l1_data_gas_price: 100000 min_wei_l1_gas_price: 10000000000 -invoke_tx_max_n_steps: 1000000 -sequencer_address: '0x795488c127693ffb36733cc054f9e2be39241a794a4877dc8fc1dbe52750488' +sequencer_address: '0x31c641e041f8d25997985b0efe68d0c5ce89d418ca9a127ae043aebed6851c5' starknet_os_config: chain_id: 1536727068981429685321 - deprecated_fee_token_address: '0x4201b1ca8320dd248a9c18aeae742db64b0ea0f6f1f5a4c72ddb6b725e16316' - fee_token_address: '0x3c0178a04b4d297c884e366bb0a2ec0f682188ba604d71955134a203d4f2adc' -validate_max_n_steps: 1000000 + deprecated_fee_token_address: '0x5195ba458d98a8d5a390afa87e199566e473d1124c07a3c57bf19813255ac41' + fee_token_address: '0x7ce4aa542d72a82662cda96b147da9b041ecf8c61f67ef657f3bbb852fc698f' +validate_max_n_steps_override: 1000000 diff --git a/src/starkware/starknet/definitions/overridable_versioned_constants.py b/src/starkware/starknet/definitions/overridable_versioned_constants.py new file mode 100644 index 00000000..71739ace --- /dev/null +++ b/src/starkware/starknet/definitions/overridable_versioned_constants.py @@ -0,0 +1,11 @@ +import dataclasses +from typing import Optional + + +# Should only include versioned constants which are both overridable AND have a need for an +# override. 
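+# For example (illustrative): registering
+#     CHAIN_ID_TO_PRIVATE_VERSIONED_CONSTANTS[chain_id] = OverridableVersionedConstants(
+#         validate_max_n_steps=5_000_000
+#     )
+# makes `StarknetGeneralConfig.get_validate_max_n_steps()` return 5_000_000 for that chain,
+# unless `validate_max_n_steps_override` is set on the config itself.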
+@dataclasses.dataclass +class OverridableVersionedConstants: + max_calldata_length: Optional[int] = None + invoke_tx_max_n_steps: Optional[int] = None + validate_max_n_steps: Optional[int] = None diff --git a/src/starkware/starknet/definitions/versioned_constants.json b/src/starkware/starknet/definitions/versioned_constants.json index e362db48..82194eeb 100644 --- a/src/starkware/starknet/definitions/versioned_constants.json +++ b/src/starkware/starknet/definitions/versioned_constants.json @@ -11,7 +11,7 @@ 1 ], "gas_per_code_byte": [ - 875, + 32, 1000 ], "gas_per_data_felt": [ diff --git a/src/starkware/starknet/solidity/IStarknetMessaging.sol b/src/starkware/starknet/solidity/IStarknetMessaging.sol index ec41e58b..9226ecb9 100644 --- a/src/starkware/starknet/solidity/IStarknetMessaging.sol +++ b/src/starkware/starknet/solidity/IStarknetMessaging.sol @@ -15,6 +15,26 @@ interface IStarknetMessaging is IStarknetMessagingEvents { */ function l1ToL2Messages(bytes32 msgHash) external view returns (uint256); + /** + Returns the hash of an L1 -> L2 message. + */ + function l1ToL2MsgHash( + address fromAddress, + uint256 toAddress, + uint256 selector, + uint256[] calldata payload, + uint256 nonce + ) external pure returns (bytes32); + + /** + Returns the hash of an L2 -> L1 message. + */ + function l2ToL1MsgHash( + uint256 fromAddress, + address toAddress, + uint256[] calldata payload + ) external pure returns (bytes32); + /** Sends a message to an L2 contract. This function is payable, the payed amount is the message fee. diff --git a/src/starkware/starknet/solidity/Starknet.sol b/src/starkware/starknet/solidity/Starknet.sol index 859c4c5d..9ecbd94e 100644 --- a/src/starkware/starknet/solidity/Starknet.sol +++ b/src/starkware/starknet/solidity/Starknet.sol @@ -157,7 +157,7 @@ contract Starknet is return 0; } - function validateInitData(bytes calldata data) internal view override { + function validateInitData(bytes calldata data) internal pure override { require(data.length == 7 * 32, "ILLEGAL_INIT_DATA_SIZE"); uint256 programHash_ = abi.decode(data[:32], (uint256)); require(programHash_ != 0, "BAD_INITIALIZATION"); @@ -189,6 +189,7 @@ contract Starknet is */ function verifyKzgProofs(uint256[] calldata programOutputSlice, bytes[] calldata kzgProofs) internal + view { require(programOutputSlice.length >= 2, "KZG_SEGMENT_TOO_SHORT"); bytes32 z = bytes32(programOutputSlice[StarknetOutput.KZG_Z_OFFSET]); diff --git a/src/starkware/starknet/solidity/StarknetGovernance.sol b/src/starkware/starknet/solidity/StarknetGovernance.sol index 9e807197..3ba93003 100644 --- a/src/starkware/starknet/solidity/StarknetGovernance.sol +++ b/src/starkware/starknet/solidity/StarknetGovernance.sol @@ -9,7 +9,7 @@ contract StarknetGovernance is Governance { /* Returns the GovernanceInfoStruct associated with the governance tag. 
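+      The function is `pure`: it derives a fixed storage slot from a constant tag without
+      reading state.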
*/ - function getGovernanceInfo() internal view override returns (GovernanceInfoStruct storage gub) { + function getGovernanceInfo() internal pure override returns (GovernanceInfoStruct storage gub) { bytes32 location = keccak256(abi.encodePacked(STARKNET_GOVERNANCE_INFO_TAG)); assembly { gub.slot := location diff --git a/src/starkware/starknet/solidity/StarknetMessaging.sol b/src/starkware/starknet/solidity/StarknetMessaging.sol index ebf353a7..6a120231 100644 --- a/src/starkware/starknet/solidity/StarknetMessaging.sol +++ b/src/starkware/starknet/solidity/StarknetMessaging.sol @@ -83,18 +83,19 @@ contract StarknetMessaging is IStarknetMessaging { } /** - Returns the hash of an L1 -> L2 message from msg.sender. + Returns the hash of an L1 -> L2 message. */ - function getL1ToL2MsgHash( + function l1ToL2MsgHash( + address fromAddress, uint256 toAddress, uint256 selector, uint256[] calldata payload, uint256 nonce - ) internal view returns (bytes32) { + ) public pure returns (bytes32) { return keccak256( abi.encodePacked( - uint256(uint160(msg.sender)), + uint256(uint160(fromAddress)), toAddress, nonce, selector, @@ -104,6 +105,20 @@ contract StarknetMessaging is IStarknetMessaging { ); } + /** + Returns the hash of an L2 -> L1 message. + */ + function l2ToL1MsgHash( + uint256 fromAddress, + address toAddress, + uint256[] calldata payload + ) public pure returns (bytes32) { + return + keccak256( + abi.encodePacked(fromAddress, uint256(uint160(toAddress)), payload.length, payload) + ); + } + /** Sends a message to an L2 contract. */ @@ -117,7 +132,7 @@ contract StarknetMessaging is IStarknetMessaging { uint256 nonce = l1ToL2MessageNonce(); NamedStorage.setUintValue(L1L2_MESSAGE_NONCE_TAG, nonce + 1); emit LogMessageToL2(msg.sender, toAddress, selector, payload, nonce, msg.value); - bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); + bytes32 msgHash = l1ToL2MsgHash(msg.sender, toAddress, selector, payload, nonce); // Note that the inclusion of the unique nonce in the message hash implies that // l1ToL2Messages()[msgHash] was not accessed before. l1ToL2Messages()[msgHash] = msg.value + 1; @@ -134,10 +149,7 @@ contract StarknetMessaging is IStarknetMessaging { override returns (bytes32) { - bytes32 msgHash = keccak256( - abi.encodePacked(fromAddress, uint256(uint160(msg.sender)), payload.length, payload) - ); - + bytes32 msgHash = l2ToL1MsgHash(fromAddress, msg.sender, payload); require(l2ToL1Messages()[msgHash] > 0, "INVALID_MESSAGE_TO_CONSUME"); emit ConsumedMessageToL1(fromAddress, msg.sender, payload); l2ToL1Messages()[msgHash] -= 1; @@ -151,7 +163,7 @@ contract StarknetMessaging is IStarknetMessaging { uint256 nonce ) external override returns (bytes32) { emit MessageToL2CancellationStarted(msg.sender, toAddress, selector, payload, nonce); - bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); + bytes32 msgHash = l1ToL2MsgHash(msg.sender, toAddress, selector, payload, nonce); uint256 msgFeePlusOne = l1ToL2Messages()[msgHash]; require(msgFeePlusOne > 0, "NO_MESSAGE_TO_CANCEL"); l1ToL2MessageCancellations()[msgHash] = block.timestamp; @@ -168,7 +180,7 @@ contract StarknetMessaging is IStarknetMessaging { // Note that the message hash depends on msg.sender, which prevents one contract from // cancelling another contract's message. // Trying to do so will result in NO_MESSAGE_TO_CANCEL. 
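+        // The message hash binds msg.sender, which is passed explicitly as `fromAddress` below.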
- bytes32 msgHash = getL1ToL2MsgHash(toAddress, selector, payload, nonce); + bytes32 msgHash = l1ToL2MsgHash(msg.sender, toAddress, selector, payload, nonce); uint256 msgFeePlusOne = l1ToL2Messages()[msgHash]; require(msgFeePlusOne != 0, "NO_MESSAGE_TO_CANCEL"); diff --git a/src/starkware/starknet/solidity/StarknetOperator.sol b/src/starkware/starknet/solidity/StarknetOperator.sol index 3ad59aab..898993ac 100644 --- a/src/starkware/starknet/solidity/StarknetOperator.sol +++ b/src/starkware/starknet/solidity/StarknetOperator.sol @@ -7,7 +7,7 @@ import "starkware/solidity/libraries/NamedStorage8.sol"; abstract contract StarknetOperator is Operator { string constant OPERATORS_MAPPING_TAG = "STARKNET_1.0_ROLES_OPERATORS_MAPPING_TAG"; - function getOperators() internal view override returns (mapping(address => bool) storage) { + function getOperators() internal pure override returns (mapping(address => bool) storage) { return NamedStorage.addressToBoolMapping(OPERATORS_MAPPING_TAG); } } diff --git a/src/starkware/starknet/solidity/StarknetState.sol b/src/starkware/starknet/solidity/StarknetState.sol index d328ae49..714ee5b2 100644 --- a/src/starkware/starknet/solidity/StarknetState.sol +++ b/src/starkware/starknet/solidity/StarknetState.sol @@ -24,7 +24,10 @@ library StarknetState { and validate that we have the expected block number at the end. This function must be called at the beginning of the updateState transaction. */ - function checkPrevBlockNumber(State storage state, uint256[] calldata starknetOutput) internal { + function checkPrevBlockNumber(State storage state, uint256[] calldata starknetOutput) + internal + view + { uint256 expectedPrevBlockNumber; if (state.blockNumber == -1) { expectedPrevBlockNumber = 0x800000000000011000000000000000000000000000000000000000000000000; @@ -41,7 +44,10 @@ library StarknetState { Validates that the current block number is the new block number. This is used to protect against re-entrancy attacks. */ - function checkNewBlockNumber(State storage state, uint256[] calldata starknetOutput) internal { + function checkNewBlockNumber(State storage state, uint256[] calldata starknetOutput) + internal + view + { require( uint256(state.blockNumber) == starknetOutput[StarknetOutput.NEW_BLOCK_NUMBER_OFFSET], "REENTRANCY_FAILURE"