Skip to content

Commit

Permalink
feat: increase leaf size
Browse files Browse the repository at this point in the history
  • Loading branch information
GCdePaula committed Jul 30, 2024
1 parent 4f9ee32 commit 4ff6d68
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 44 deletions.
23 changes: 9 additions & 14 deletions src/AccessLogs.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import "./UArchConstants.sol";
library AccessLogs {
using Buffer for Buffer.Context;
using Memory for Memory.AlignedSize;
using Memory for Memory.PhysicalAddress;

struct Context {
bytes32 currentRootHash;
Expand Down Expand Up @@ -85,7 +86,7 @@ library AccessLogs {
returns (bytes32)
{
Memory.Region memory r =
Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(2));
Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(0));
return readRegion(a, r);
}

Expand All @@ -95,16 +96,13 @@ library AccessLogs {
) internal pure returns (uint64) {
bytes32 readData = a.buffer.consumeBytes32();
bytes32 valHash = keccak256(abi.encodePacked(readData));
Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(
Memory.PhysicalAddress.unwrap(readAddress) & ~uint64(31)
);
uint64 offset = Memory.PhysicalAddress.unwrap(readAddress)
- Memory.PhysicalAddress.unwrap(leafAddress);

Memory.PhysicalAddress leafAddress = readAddress.truncateToLeaf();
uint64 offset = readAddress.minus(leafAddress);
bytes8 readValue = bytes8(readData << (offset << 3));

bytes32 expectedValHash =
readLeaf(a, Memory.strideFromWordAddress(leafAddress));
readLeaf(a, Memory.strideFromLeafAddress(leafAddress));

require(valHash == expectedValHash, "Read value doesn't match");
return machineWordToSolidityUint64(readValue);
Expand Down Expand Up @@ -136,7 +134,7 @@ library AccessLogs {
bytes32 newHash
) internal pure {
Memory.Region memory r =
Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(2));
Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(0));
writeRegion(a, r, newHash);
}

Expand All @@ -146,12 +144,9 @@ library AccessLogs {
uint64 newValue
) internal pure {
bytes32 writtenData = a.buffer.consumeBytes32();
Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(
Memory.PhysicalAddress.unwrap(writeAddress) & ~uint64(31)
);
uint64 offset = Memory.PhysicalAddress.unwrap(writeAddress)
- Memory.PhysicalAddress.unwrap(leafAddress);

Memory.PhysicalAddress leafAddress = writeAddress.truncateToLeaf();
uint64 offset = writeAddress.minus(leafAddress);
uint64 expectedNewValue =
machineWordToSolidityUint64(bytes8(writtenData << (offset << 3)));

Expand All @@ -162,7 +157,7 @@ library AccessLogs {

writeLeaf(
a,
Memory.strideFromWordAddress(leafAddress),
Memory.strideFromLeafAddress(leafAddress),
keccak256(abi.encodePacked(writtenData))
);
}
Expand Down
9 changes: 5 additions & 4 deletions src/Buffer.sol
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,13 @@ library Buffer {
) internal pure returns (bytes32, uint8) {
// require that multiplier makes sense!
uint8 logOfSize = region.alignedSize.log2();
require(logOfSize <= LOG2RANGE, "Cannot be bigger than the tree itself");
require(
logOfSize <= Memory.LOG2_MAX_SIZE,
"Cannot be bigger than the tree itself"
);

uint64 stride = Memory.Stride.unwrap(region.stride);
uint8 nodesCount = LOG2RANGE - logOfSize;
uint8 nodesCount = Memory.LOG2_MAX_SIZE - logOfSize;

for (uint64 i = 0; i < nodesCount; i++) {
Buffer.Context memory siblings =
Expand Down Expand Up @@ -113,8 +116,6 @@ library Buffer {
return root;
}

uint8 constant LOG2RANGE = 61;

function isEven(uint64 x) private pure returns (bool) {
return x % 2 == 0;
}
Expand Down
44 changes: 33 additions & 11 deletions src/Memory.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ pragma solidity ^0.8.0;
library Memory {
// Specifies a memory region and it's merkle hash.
// The size is given in the number of leaves in the tree,
// and therefore are word-sized.
// This means a `alignedSize` specifies a region the size of a word.
// and therefore are leaf-sized.
// This means a `alignedSize` specifies a region the size of a leaf.
// The address has to be aligned to a power-of-two.
// By using an stride, we can guarantee that the address is aligned.
// The address is given by `stride * (1 << log2s)`.
Expand All @@ -45,15 +45,14 @@ library Memory {
return regionFromStride(stride, alignedSize);
}

function regionFromWordAddress(PhysicalAddress startAddress)
function regionFromLeafAddress(PhysicalAddress startAddress)
internal
pure
returns (Region memory)
{
return regionFromPhysicalAddress(startAddress, alignedSizeFromLog2(0));
}

//
// Stride and PhysicalAddress
//
// When using memory address and a size in merkle trees to refer to a memory region,
Expand All @@ -74,23 +73,26 @@ library Memory {
) internal pure returns (Stride) {
uint64 s = alignedSize.size();
uint64 addr = PhysicalAddress.unwrap(startAddress);
// assert memory address is word-aligned (8-byte long)
assert(addr & 7 == 0);
uint64 position = PhysicalAddress.unwrap(startAddress) >> 3;

// assert memory address is leaf-aligned (32-byte long)
assert(addr & LEAF_MASK == 0);
uint64 position = PhysicalAddress.unwrap(startAddress) >> LOG2_LEAF;

// assert position and size are aligned
// position has to be a multiple of size
// equivalent to: size = 2^a, position = 2^b, position = size * 2^c, where c >= 0
assert(((s - 1) & position) == 0);
uint64 stride = position / s;

return Stride.wrap(stride);
}

function strideFromWordAddress(PhysicalAddress startAddress)
function strideFromLeafAddress(PhysicalAddress startAddress)
internal
pure
returns (Stride)
{
return strideFromPhysicalAddress(startAddress, alignedSizeFromLog2(2));
return strideFromPhysicalAddress(startAddress, alignedSizeFromLog2(0));
}

function validateStrideLength(Stride stride, AlignedSize alignedSize)
Expand All @@ -101,15 +103,17 @@ library Memory {
assert(Stride.unwrap(stride) * s < MAX_STRIDE);
}

//
// AlignedSize
//
// The size is given in the number of leaves in the tree,
// and therefore are word-sized; a size of one (or log2size zero) means one word long.

type AlignedSize is uint8;

uint64 constant MAX_SIZE = (1 << 61);
uint8 constant LOG2_LEAF = 5;
uint8 constant LOG2_MAX_SIZE = 64 - LOG2_LEAF;
uint64 constant LEAF_MASK = uint64(1 << LOG2_LEAF) - 1;
uint64 constant MAX_SIZE = uint64(1 << LOG2_MAX_SIZE);

using Memory for AlignedSize;

Expand All @@ -134,4 +138,22 @@ library Memory {
{
return PhysicalAddress.wrap(uint64Address);
}

function truncateToLeaf(PhysicalAddress addr)
internal
pure
returns (PhysicalAddress)
{
uint64 r = Memory.PhysicalAddress.unwrap(addr) & ~LEAF_MASK;
return Memory.PhysicalAddress.wrap(r);
}

function minus(PhysicalAddress lhs, PhysicalAddress rhs)
internal
pure
returns (uint64)
{
return Memory.PhysicalAddress.unwrap(lhs)
- Memory.PhysicalAddress.unwrap(rhs);
}
}
2 changes: 1 addition & 1 deletion src/UArchCompat.sol
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ library UArchCompat {
Memory.regionFromPhysicalAddress(
UArchConstants.RESET_POSITION.toPhysicalAddress(),
Memory.alignedSizeFromLog2(
UArchConstants.RESET_ALIGNED_SIZE - 3
UArchConstants.RESET_ALIGNED_SIZE - Memory.LOG2_LEAF
)
),
UArchConstants.PRESTINE_STATE
Expand Down
21 changes: 11 additions & 10 deletions templates/AccessLogs.sol.template
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import "./UArchConstants.sol";
library AccessLogs {
using Buffer for Buffer.Context;
using Memory for Memory.AlignedSize;
using Memory for Memory.PhysicalAddress;

struct Context {
bytes32 currentRootHash;
Expand Down Expand Up @@ -98,7 +99,7 @@ library AccessLogs {
returns (bytes32)
{
Memory.Region memory r =
Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(2));
Memory.regionFromStride(readStride, Memory.alignedSizeFromLog2(0));
return readRegion(a, r);
}

Expand All @@ -108,13 +109,13 @@ library AccessLogs {
) internal pure returns (uint64) {
bytes32 readData = a.buffer.consumeBytes32();
bytes32 valHash = keccak256(abi.encodePacked(readData));
Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(Memory.PhysicalAddress.unwrap(readAddress) & ~uint64(31));
uint64 offset = Memory.PhysicalAddress.unwrap(readAddress) - Memory.PhysicalAddress.unwrap(leafAddress);


Memory.PhysicalAddress leafAddress = readAddress.truncateToLeaf();
uint64 offset = readAddress.minus(leafAddress);
bytes8 readValue = bytes8(readData << (offset << 3));

bytes32 expectedValHash =
readLeaf(a, Memory.strideFromWordAddress(leafAddress));
readLeaf(a, Memory.strideFromLeafAddress(leafAddress));

require(valHash == expectedValHash, "Read value doesn't match");
return machineWordToSolidityUint64(readValue);
Expand Down Expand Up @@ -146,7 +147,7 @@ library AccessLogs {
bytes32 newHash
) internal pure {
Memory.Region memory r =
Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(2));
Memory.regionFromStride(writeStride, Memory.alignedSizeFromLog2(0));
writeRegion(a, r, newHash);
}

Expand All @@ -156,16 +157,16 @@ library AccessLogs {
uint64 newValue
) internal pure {
bytes32 writtenData = a.buffer.consumeBytes32();
Memory.PhysicalAddress leafAddress = Memory.PhysicalAddress.wrap(Memory.PhysicalAddress.unwrap(writeAddress) & ~uint64(31));
uint64 offset = Memory.PhysicalAddress.unwrap(writeAddress) - Memory.PhysicalAddress.unwrap(leafAddress);


Memory.PhysicalAddress leafAddress = writeAddress.truncateToLeaf();
uint64 offset = writeAddress.minus(leafAddress);
uint64 expectedNewValue = machineWordToSolidityUint64(bytes8(writtenData << (offset << 3)));

require(newValue == expectedNewValue, "Access log value does not contain the expected written value");

writeLeaf(
a,
Memory.strideFromWordAddress(leafAddress),
Memory.strideFromLeafAddress(leafAddress),
keccak256(abi.encodePacked(writtenData))
);
}
Expand Down
4 changes: 2 additions & 2 deletions test/AccessLogs.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ contract AccessLogsTest is Test {
buffer.offset = 0;
(bytes32 root,) = buffer.peekRoot(
Memory.regionFromStride(
Memory.strideFromWordAddress(position.toPhysicalAddress()),
Memory.alignedSizeFromLog2(2)
Memory.strideFromLeafAddress(position.toPhysicalAddress()),
Memory.alignedSizeFromLog2(0)
),
drive
);
Expand Down
4 changes: 2 additions & 2 deletions test/Memory.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ contract MemoryTest is Test {

function testStrideAlignment() public {
for (uint128 paddr = 8; paddr <= (1 << 63); paddr *= 2) {
for (uint8 l = 0; ((1 << l) <= (paddr >> 3)); ++l) {
for (uint8 l = 0; ((1 << l) <= (paddr >> Memory.LOG2_LEAF)); ++l) {
uint64(paddr).toPhysicalAddress().strideFromPhysicalAddress(
Memory.alignedSizeFromLog2(l)
);
Expand All @@ -39,7 +39,7 @@ contract MemoryTest is Test {
Memory.alignedSizeFromLog2(l)
);

if ((1 << l) == (paddr >> 3)) {
if ((1 << l) == (paddr >> Memory.LOG2_LEAF)) {
// address has to be aligned with stride size
vm.expectRevert();
uint64(paddr + paddr / 2).toPhysicalAddress()
Expand Down

0 comments on commit 4ff6d68

Please sign in to comment.