From 4ff6d68f863a011d3b2ae5c0087ee4cce7a463db Mon Sep 17 00:00:00 2001 From: Gabriel Coutinho de Paula Date: Thu, 25 Jul 2024 22:04:34 -0300 Subject: [PATCH] feat: increase leaf size --- src/AccessLogs.sol | 23 +++++++--------- src/Buffer.sol | 9 ++++--- src/Memory.sol | 44 +++++++++++++++++++++++-------- src/UArchCompat.sol | 2 +- templates/AccessLogs.sol.template | 21 ++++++++------- test/AccessLogs.t.sol | 4 +-- test/Memory.t.sol | 4 +-- 7 files changed, 63 insertions(+), 44 deletions(-) diff --git a/src/AccessLogs.sol b/src/AccessLogs.sol index 50d7de34..19c6e887 100644 --- a/src/AccessLogs.sol +++ b/src/AccessLogs.sol @@ -24,6 +24,7 @@ import "./UArchConstants.sol"; library AccessLogs { using Buffer for Buffer.Context; using Memory for Memory.AlignedSize; + using Memory for Memory.PhysicalAddress; struct Context { bytes32 currentRootHash; @@ -85,7 +86,7 @@ library AccessLogs { returns (bytes32) { Memory.Region memory r = - Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(2)); + Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(0)); return readRegion(a, r); } @@ -95,16 +96,13 @@ library AccessLogs { ) internal pure returns (uint64) { bytes32 readData = a.buffer.consumeBytes32(); bytes32 valHash = keccak256(abi.encodePacked(readData)); - Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap( - Memory.PhysicalAddress.unwrap(readAddress) & ~uint64(31) - ); - uint64 offset = Memory.PhysicalAddress.unwrap(readAddress) - - Memory.PhysicalAddress.unwrap(leafAddress); + Memory.PhysicalAddress leafAddress = readAddress.truncateToLeaf(); + uint64 offset = readAddress.minus(leafAddress); bytes8 readValue = bytes8(readData << (offset << 3)); bytes32 expectedValHash = - readLeaf(a, Memory.strideFromWordAddress(leafAddress)); + readLeaf(a, Memory.strideFromLeafAddress(leafAddress)); require(valHash == expectedValHash, "Read value doesn't match"); return machineWordToSolidityUint64(readValue); @@ -136,7 +134,7 @@ library AccessLogs { bytes32 newHash ) internal pure { Memory.Region memory r = - Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(2)); + Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(0)); writeRegion(a, r, newHash); } @@ -146,12 +144,9 @@ library AccessLogs { uint64 newValue ) internal pure { bytes32 writtenData = a.buffer.consumeBytes32(); - Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap( - Memory.PhysicalAddress.unwrap(writeAddress) & ~uint64(31) - ); - uint64 offset = Memory.PhysicalAddress.unwrap(writeAddress) - - Memory.PhysicalAddress.unwrap(leafAddress); + Memory.PhysicalAddress leafAddress = writeAddress.truncateToLeaf(); + uint64 offset = writeAddress.minus(leafAddress); uint64 expectedNewValue = machineWordToSolidityUint64(bytes8(writtenData << (offset << 3))); @@ -162,7 +157,7 @@ library AccessLogs { writeLeaf( a, - Memory.strideFromWordAddress(leafAddress), + Memory.strideFromLeafAddress(leafAddress), keccak256(abi.encodePacked(writtenData)) ); } diff --git a/src/Buffer.sol b/src/Buffer.sol index 9b212418..adeb73e8 100644 --- a/src/Buffer.sol +++ b/src/Buffer.sol @@ -81,10 +81,13 @@ library Buffer { ) internal pure returns (bytes32, uint8) { // require that multiplier makes sense! uint8 logOfSize = region.alignedSize.log2(); - require(logOfSize <= LOG2RANGE, "Cannot be bigger than the tree itself"); + require( + logOfSize <= Memory.LOG2_MAX_SIZE, + "Cannot be bigger than the tree itself" + ); uint64 stride = Memory.Stride.unwrap(region.stride); - uint8 nodesCount = LOG2RANGE - logOfSize; + uint8 nodesCount = Memory.LOG2_MAX_SIZE - logOfSize; for (uint64 i = 0; i < nodesCount; i++) { Buffer.Context memory siblings = @@ -113,8 +116,6 @@ library Buffer { return root; } - uint8 constant LOG2RANGE = 61; - function isEven(uint64 x) private pure returns (bool) { return x % 2 == 0; } diff --git a/src/Memory.sol b/src/Memory.sol index 1597ef62..17a40f21 100644 --- a/src/Memory.sol +++ b/src/Memory.sol @@ -18,8 +18,8 @@ pragma solidity ^0.8.0; library Memory { // Specifies a memory region and it's merkle hash. // The size is given in the number of leaves in the tree, - // and therefore are word-sized. - // This means a `alignedSize` specifies a region the size of a word. + // and therefore are leaf-sized. + // This means a `alignedSize` specifies a region the size of a leaf. // The address has to be aligned to a power-of-two. // By using an stride, we can guarantee that the address is aligned. // The address is given by `stride * (1 << log2s)`. @@ -45,7 +45,7 @@ library Memory { return regionFromStride(stride, alignedSize); } - function regionFromWordAddress(PhysicalAddress startAddress) + function regionFromLeafAddress(PhysicalAddress startAddress) internal pure returns (Region memory) @@ -53,7 +53,6 @@ library Memory { return regionFromPhysicalAddress(startAddress, alignedSizeFromLog2(0)); } - // // Stride and PhysicalAddress // // When using memory address and a size in merkle trees to refer to a memory region, @@ -74,23 +73,26 @@ library Memory { ) internal pure returns (Stride) { uint64 s = alignedSize.size(); uint64 addr = PhysicalAddress.unwrap(startAddress); - // assert memory address is word-aligned (8-byte long) - assert(addr & 7 == 0); - uint64 position = PhysicalAddress.unwrap(startAddress) >> 3; + + // assert memory address is leaf-aligned (32-byte long) + assert(addr & LEAF_MASK == 0); + uint64 position = PhysicalAddress.unwrap(startAddress) >> LOG2_LEAF; + // assert position and size are aligned // position has to be a multiple of size // equivalent to: size = 2^a, position = 2^b, position = size * 2^c, where c >= 0 assert(((s - 1) & position) == 0); uint64 stride = position / s; + return Stride.wrap(stride); } - function strideFromWordAddress(PhysicalAddress startAddress) + function strideFromLeafAddress(PhysicalAddress startAddress) internal pure returns (Stride) { - return strideFromPhysicalAddress(startAddress, alignedSizeFromLog2(2)); + return strideFromPhysicalAddress(startAddress, alignedSizeFromLog2(0)); } function validateStrideLength(Stride stride, AlignedSize alignedSize) @@ -101,7 +103,6 @@ library Memory { assert(Stride.unwrap(stride) * s < MAX_STRIDE); } - // // AlignedSize // // The size is given in the number of leaves in the tree, @@ -109,7 +110,10 @@ library Memory { type AlignedSize is uint8; - uint64 constant MAX_SIZE = (1 << 61); + uint8 constant LOG2_LEAF = 5; + uint8 constant LOG2_MAX_SIZE = 64 - LOG2_LEAF; + uint64 constant LEAF_MASK = uint64(1 << LOG2_LEAF) - 1; + uint64 constant MAX_SIZE = uint64(1 << LOG2_MAX_SIZE); using Memory for AlignedSize; @@ -134,4 +138,22 @@ library Memory { { return PhysicalAddress.wrap(uint64Address); } + + function truncateToLeaf(PhysicalAddress addr) + internal + pure + returns (PhysicalAddress) + { + uint64 r = Memory.PhysicalAddress.unwrap(addr) & ~LEAF_MASK; + return Memory.PhysicalAddress.wrap(r); + } + + function minus(PhysicalAddress lhs, PhysicalAddress rhs) + internal + pure + returns (uint64) + { + return Memory.PhysicalAddress.unwrap(lhs) + - Memory.PhysicalAddress.unwrap(rhs); + } } diff --git a/src/UArchCompat.sol b/src/UArchCompat.sol index a6c45479..94697042 100644 --- a/src/UArchCompat.sol +++ b/src/UArchCompat.sol @@ -104,7 +104,7 @@ library UArchCompat { Memory.regionFromPhysicalAddress( UArchConstants.RESET_POSITION.toPhysicalAddress(), Memory.alignedSizeFromLog2( - UArchConstants.RESET_ALIGNED_SIZE - 3 + UArchConstants.RESET_ALIGNED_SIZE - Memory.LOG2_LEAF ) ), UArchConstants.PRESTINE_STATE diff --git a/templates/AccessLogs.sol.template b/templates/AccessLogs.sol.template index dc2742df..2c152a6a 100644 --- a/templates/AccessLogs.sol.template +++ b/templates/AccessLogs.sol.template @@ -35,6 +35,7 @@ import "./UArchConstants.sol"; library AccessLogs { using Buffer for Buffer.Context; using Memory for Memory.AlignedSize; + using Memory for Memory.PhysicalAddress; struct Context { bytes32 currentRootHash; @@ -98,7 +99,7 @@ library AccessLogs { returns (bytes32) { Memory.Region memory r = - Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(2)); + Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(0)); return readRegion(a, r); } @@ -108,13 +109,13 @@ library AccessLogs { ) internal pure returns (uint64) { bytes32 readData = a.buffer.consumeBytes32(); bytes32 valHash = keccak256(abi.encodePacked(readData)); - Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(Memory.PhysicalAddress.unwrap(readAddress) & ~uint64(31)); - uint64 offset = Memory.PhysicalAddress.unwrap(readAddress) - Memory.PhysicalAddress.unwrap(leafAddress); - + + Memory.PhysicalAddress leafAddress = readAddress.truncateToLeaf(); + uint64 offset = readAddress.minus(leafAddress); bytes8 readValue = bytes8(readData << (offset << 3)); bytes32 expectedValHash = - readLeaf(a, Memory.strideFromWordAddress(leafAddress)); + readLeaf(a, Memory.strideFromLeafAddress(leafAddress)); require(valHash == expectedValHash, "Read value doesn't match"); return machineWordToSolidityUint64(readValue); @@ -146,7 +147,7 @@ library AccessLogs { bytes32 newHash ) internal pure { Memory.Region memory r = - Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(2)); + Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(0)); writeRegion(a, r, newHash); } @@ -156,16 +157,16 @@ library AccessLogs { uint64 newValue ) internal pure { bytes32 writtenData = a.buffer.consumeBytes32(); - Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(Memory.PhysicalAddress.unwrap(writeAddress) & ~uint64(31)); - uint64 offset = Memory.PhysicalAddress.unwrap(writeAddress) - Memory.PhysicalAddress.unwrap(leafAddress); - + + Memory.PhysicalAddress leafAddress = writeAddress.truncateToLeaf(); + uint64 offset = writeAddress.minus(leafAddress); uint64 expectedNewValue = machineWordToSolidityUint64(bytes8(writtenData << (offset << 3))); require(newValue == expectedNewValue, "Access log value does not contain the expected written value"); writeLeaf( a, - Memory.strideFromWordAddress(leafAddress), + Memory.strideFromLeafAddress(leafAddress), keccak256(abi.encodePacked(writtenData)) ); } diff --git a/test/AccessLogs.t.sol b/test/AccessLogs.t.sol index cd0c1e93..66c278fd 100644 --- a/test/AccessLogs.t.sol +++ b/test/AccessLogs.t.sol @@ -180,8 +180,8 @@ contract AccessLogsTest is Test { buffer.offset = 0; (bytes32 root,) = buffer.peekRoot( Memory.regionFromStride( - Memory.strideFromWordAddress(position.toPhysicalAddress()), - Memory.alignedSizeFromLog2(2) + Memory.strideFromLeafAddress(position.toPhysicalAddress()), + Memory.alignedSizeFromLog2(0) ), drive ); diff --git a/test/Memory.t.sol b/test/Memory.t.sol index 8503c99c..360f58ec 100644 --- a/test/Memory.t.sol +++ b/test/Memory.t.sol @@ -28,7 +28,7 @@ contract MemoryTest is Test { function testStrideAlignment() public { for (uint128 paddr = 8; paddr <= (1 << 63); paddr *= 2) { - for (uint8 l = 0; ((1 << l) <= (paddr >> 3)); ++l) { + for (uint8 l = 0; ((1 << l) <= (paddr >> Memory.LOG2_LEAF)); ++l) { uint64(paddr).toPhysicalAddress().strideFromPhysicalAddress( Memory.alignedSizeFromLog2(l) ); @@ -39,7 +39,7 @@ contract MemoryTest is Test { Memory.alignedSizeFromLog2(l) ); - if ((1 << l) == (paddr >> 3)) { + if ((1 << l) == (paddr >> Memory.LOG2_LEAF)) { // address has to be aligned with stride size vm.expectRevert(); uint64(paddr + paddr / 2).toPhysicalAddress()