From 901da712cad49545c43f4eaea61eba42c354cddc Mon Sep 17 00:00:00 2001
From: Stephen Chen <20940639+stephenctw@users.noreply.github.com>
Date: Thu, 31 Oct 2024 13:53:48 +0800
Subject: [PATCH] feat(prt): complete rollups meta step
---
.../node/compute-runner/src/lib.rs | 2 +-
.../node/machine-runner/src/lib.rs | 17 +--
prt/client-lua/computation/commitment.lua | 21 ++-
prt/client-lua/computation/machine.lua | 107 ++++++++++-----
prt/client-lua/player/strategy.lua | 10 +-
prt/client-rs/src/db/compute_state_access.rs | 70 +++++++---
prt/client-rs/src/db/sql/compute_data.rs | 72 ++++++++++
prt/client-rs/src/db/sql/migrations.sql | 5 +
prt/client-rs/src/machine/commitment.rs | 4 -
.../src/machine/commitment_builder.rs | 9 +-
prt/client-rs/src/machine/instance.rs | 125 +++++++++++------
prt/client-rs/src/strategy/player.rs | 6 +-
prt/contracts/src/IDataProvider.sol | 1 +
.../tournament/abstracts/LeafTournament.sol | 126 +++++++-----------
.../src/tournament/libs/Commitment.sol | 4 +-
prt/contracts/test/Util.sol | 29 +++-
prt/tests/compute-rs/src/main.rs | 4 +-
prt/tests/compute/prt_compute.lua | 3 +-
.../runners/helpers/fake_commitment.lua | 9 +-
prt/tests/compute/runners/hero_runner.lua | 2 +-
prt/tests/compute/runners/sybil_runner.lua | 2 +-
21 files changed, 418 insertions(+), 210 deletions(-)
diff --git a/cartesi-rollups/node/compute-runner/src/lib.rs b/cartesi-rollups/node/compute-runner/src/lib.rs
index 5f037ffe..b9a167ae 100644
--- a/cartesi-rollups/node/compute-runner/src/lib.rs
+++ b/cartesi-rollups/node/compute-runner/src/lib.rs
@@ -43,7 +43,7 @@ where
.state_manager
.machine_state_hashes(last_sealed_epoch.epoch_number)?;
let mut player = Player::new(
- inputs.into_iter().map(|i| Input(i)).collect(),
+ Some(inputs.into_iter().map(|i| Input(i)).collect()),
leafs
.into_iter()
.map(|l| {
diff --git a/cartesi-rollups/node/machine-runner/src/lib.rs b/cartesi-rollups/node/machine-runner/src/lib.rs
index 8404deff..96024233 100644
--- a/cartesi-rollups/node/machine-runner/src/lib.rs
+++ b/cartesi-rollups/node/machine-runner/src/lib.rs
@@ -168,18 +168,11 @@ where
fn process_input(&mut self, data: &[u8]) -> Result<(), SM> {
// TODO: review caclulations
- let big_steps_in_stride = max_uint(LOG2_STRIDE - LOG2_UARCH_SPAN);
- let stride_count_in_input = max_uint(LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE);
+ let big_steps_in_stride = 1 << (LOG2_STRIDE - LOG2_UARCH_SPAN);
+ let stride_count_in_input = 1 << (LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE);
- // take snapshot and make it available to the compute client
- // the snapshot taken before input insersion is for log/proof generation
- self.snapshot(0)?;
self.feed_input(data)?;
- self.run_machine(1)?;
- // take snapshot and make it available to the compute client
- // the snapshot taken after insersion and step is for commitment builder
- self.snapshot(1)?;
- self.run_machine(big_steps_in_stride - 1)?;
+ self.run_machine(big_steps_in_stride)?;
let mut i: u64 = 0;
while !self.machine.read_iflags_y()? {
@@ -228,12 +221,12 @@ where
Ok(())
}
- fn snapshot(&self, offset: u64) -> Result<(), SM> {
+ fn take_snapshot(&self) -> Result<(), SM> {
// TODO: make sure "/rollups_data/{epoch_number}" exists
let snapshot_path = PathBuf::from(format!(
"/rollups_data/{}/{}",
self.epoch_number,
- self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN + offset
+ self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN
));
if !snapshot_path.exists() {
self.machine.store(&snapshot_path)?;
diff --git a/prt/client-lua/computation/commitment.lua b/prt/client-lua/computation/commitment.lua
index 0ef7825e..8bd1574f 100644
--- a/prt/client-lua/computation/commitment.lua
+++ b/prt/client-lua/computation/commitment.lua
@@ -35,10 +35,10 @@ local function run_uarch_span(machine)
end
local function build_small_machine_commitment(base_cycle, log2_stride_count, machine, snapshot_dir)
- local machine_state = machine:run(base_cycle)
+ local machine_state = machine:state()
if save_snapshot then
-- taking snapshot for leafs to save time in next level
- machine:snapshot(snapshot_dir, base_cycle)
+ machine:take_snapshot(snapshot_dir, base_cycle)
end
local initial_state = machine_state.root_hash
@@ -60,10 +60,10 @@ local function build_small_machine_commitment(base_cycle, log2_stride_count, mac
end
local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride_count, machine, snapshot_dir)
- local machine_state = machine:run(base_cycle)
+ local machine_state = machine:state()
if save_snapshot then
-- taking snapshot for leafs to save time in next level
- machine:snapshot(snapshot_dir, base_cycle)
+ machine:take_snapshot(snapshot_dir, base_cycle)
end
local initial_state = machine_state.root_hash
@@ -88,9 +88,16 @@ local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride
return initial_state, builder:build(initial_state)
end
-local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir)
+local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir, inputs)
local machine = Machine:new_from_path(machine_path)
machine:load_snapshot(snapshot_dir, base_cycle)
+ if inputs then
+ -- treat it as rollups
+ machine:run_with_inputs(base_cycle, inputs)
+ else
+ -- treat it as compute
+ machine:run(base_cycle)
+ end
if log2_stride >= consts.log2_uarch_span then
assert(
@@ -120,7 +127,7 @@ function CommitmentBuilder:new(machine_path, snapshot_dir, root_commitment)
return c
end
-function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count)
+function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count, inputs)
if not self.commitments[level] then
self.commitments[level] = {}
elseif self.commitments[level][base_cycle] then
@@ -128,7 +135,7 @@ function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_cou
end
local _, commitment = build_commitment(base_cycle, log2_stride, log2_stride_count, self.machine_path,
- self.snapshot_dir)
+ self.snapshot_dir, inputs)
self.commitments[level][base_cycle] = commitment
return commitment
end
diff --git a/prt/client-lua/computation/machine.lua b/prt/client-lua/computation/machine.lua
index 889f79ae..efc37a28 100644
--- a/prt/client-lua/computation/machine.lua
+++ b/prt/client-lua/computation/machine.lua
@@ -110,7 +110,7 @@ local function find_closest_snapshot(path, current_cycle, cycle)
return closest_dir
end
-function Machine:snapshot(snapshot_dir, cycle)
+function Machine:take_snapshot(snapshot_dir, cycle)
if helper.exists(snapshot_dir) then
local snapshot_path = snapshot_dir .. "/" .. tostring(cycle)
@@ -163,6 +163,34 @@ function Machine:run_uarch(ucycle)
self.ucycle = ucycle
end
+function Machine:run_with_inputs(cycle, inputs)
+ local input_mask = arithmetic.max_uint(consts.log2_emulator_span)
+ local current_input_index = self.cycle >> consts.log2_emulator_span
+
+ local next_input_index
+
+ if self.cycle & input_mask == 0 then
+ next_input_index = current_input_index
+ else
+ next_input_index = current_input_index + 1
+ end
+ local next_input_cycle = next_input_index << consts.log2_emulator_span
+
+ while next_input_cycle < cycle do
+ self:run(next_input_cycle)
+ local input = inputs[next_input_index]
+ if input then
+ self.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input);
+ end
+
+ next_input_index = next_input_index + 1
+ next_input_cycle = next_input_index << consts.log2_emulator_span
+ end
+ self:run(cycle)
+
+ return self:state()
+end
+
function Machine:increment_uarch()
self.machine:run_uarch(self.ucycle + 1)
self.ucycle = self.ucycle + 1
@@ -200,18 +228,28 @@ local function ver(t, p, s)
return t
end
-local function encode_access_log(logs)
+local bint = require 'utils.bint' (256) -- use 256 bits integers
+
+local function encode_access_logs(logs, input)
local encoded = {}
- for _, a in ipairs(logs.accesses) do
- if a.log2_size == 3 then
- table.insert(encoded, a.read)
- else
- table.insert(encoded, a.read_hash)
- end
+ if input then
+ -- TODO: check #input is encoded as uint256
+ table.insert(encoded, bint(#input))
+ table.insert(encoded, input)
+ end
- for _, h in ipairs(a.sibling_hashes) do
- table.insert(encoded, h)
+ for _, log in ipairs(logs) do
+ for _, a in ipairs(log.accesses) do
+ if a.log2_size == 3 then
+ table.insert(encoded, a.read)
+ else
+ table.insert(encoded, a.read_hash)
+ end
+
+ for _, h in ipairs(a.sibling_hashes) do
+ table.insert(encoded, h)
+ end
end
end
@@ -223,38 +261,45 @@ local function encode_access_log(logs)
return '"' .. hex_data .. '"'
end
-function Machine.get_logs(path, snapshot_dir, cycle, ucycle, input)
+function Machine.get_logs(path, snapshot_dir, cycle, ucycle, inputs)
local machine = Machine:new_from_path(path)
machine:load_snapshot(snapshot_dir, cycle)
- local logs
+ local logs = {}
local log_type = { annotations = true, proofs = true }
- machine:run(cycle)
-
- local mask = 1 << consts.log2_emulator_span - 1;
- if cycle & mask == 0 and input then
- -- need to process input
- if ucycle == 0 then
- logs = machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input,
- log_type
- )
- local step_logs = machine.machine:log_uarch_step(log_type)
- -- append step logs to cmio logs
- for _, log in ipairs(step_logs) do
- table.insert(logs, log)
+ local input = Hash.zero
+ if inputs then
+ -- treat it as rollups
+ machine:run_with_inputs(cycle, inputs)
+
+ local mask = arithmetic.max_uint(consts.log2_emulator_span);
+ local try_input = inputs[cycle >> consts.log2_emulator_span]
+ if cycle & mask == 0 and try_input then
+ input = try_input
+ -- need to process input
+ if ucycle == 0 then
+ -- need to log cmio
+ table.insert(logs,
+ machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input,
+ log_type
+ ))
+ table.insert(logs, machine.machine:log_uarch_step(log_type))
+ return encode_access_logs(logs, input)
+ else
+ machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input)
end
- return encode_access_log(logs)
- else
- machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input)
end
+ else
+ -- treat it as compute
+ machine:run(cycle)
end
machine:run_uarch(ucycle)
if ucycle == consts.uarch_span then
- logs = machine.machine:log_uarch_reset(log_type)
+ table.insert(logs, machine.machine:log_uarch_reset(log_type))
else
- logs = machine.machine:log_uarch_step(log_type)
+ table.insert(logs, machine.machine:log_uarch_step(log_type))
end
- return encode_access_log(logs)
+ return encode_access_logs(logs, nil)
end
return Machine
diff --git a/prt/client-lua/player/strategy.lua b/prt/client-lua/player/strategy.lua
index 28bf4667..a86079bb 100644
--- a/prt/client-lua/player/strategy.lua
+++ b/prt/client-lua/player/strategy.lua
@@ -127,8 +127,8 @@ function HonestStrategy:_react_match(match, commitment, log)
local cycle = match.base_big_cycle
local ucycle = (match.leaf_cycle & constants.uarch_span):touinteger()
- local input = self.inputs[cycle >> constants.log2_emulator_span]
- local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle, input)
+ local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle,
+ self.inputs)
helper.log_full(self.sender.index, string.format(
"win leaf match in tournament %s of level %d for commitment %s",
@@ -281,7 +281,8 @@ function HonestStrategy:_react_tournament(tournament, log)
tournament.base_big_cycle,
tournament.level,
tournament.log2_stride,
- tournament.log2_stride_count
+ tournament.log2_stride_count,
+ self.inputs
)
table.insert(log.tournaments, tournament)
@@ -299,7 +300,8 @@ function HonestStrategy:_react_tournament(tournament, log)
tournament.parent.base_big_cycle,
tournament.parent.level,
tournament.parent.log2_stride,
- tournament.parent.log2_stride_count
+ tournament.parent.log2_stride_count,
+ self.inputs
)
if tournament_winner.commitment ~= old_commitment then
helper.log_full(self.sender.index, "player lost tournament")
diff --git a/prt/client-rs/src/db/compute_state_access.rs b/prt/client-rs/src/db/compute_state_access.rs
index 3cfb81dc..2d4be861 100644
--- a/prt/client-rs/src/db/compute_state_access.rs
+++ b/prt/client-rs/src/db/compute_state_access.rs
@@ -16,12 +16,11 @@ use std::{
#[derive(Debug, Serialize, Deserialize)]
pub struct InputsAndLeafs {
- #[serde(default)]
- inputs: Vec,
+ inputs: Option>,
leafs: Vec,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Input(#[serde(with = "alloy_hex::serde")] pub Vec);
#[derive(Debug, Serialize, Deserialize)]
@@ -30,6 +29,7 @@ pub struct Leaf(#[serde(with = "alloy_hex::serde")] pub [u8; 32], pub u64);
#[derive(Debug)]
pub struct ComputeStateAccess {
connection: Mutex,
+ pub handle_rollups: bool,
pub work_path: PathBuf,
}
@@ -46,7 +46,7 @@ fn read_json_file(file_path: &Path) -> Result {
impl ComputeStateAccess {
pub fn new(
- inputs: Vec,
+ inputs: Option>,
leafs: Vec,
root_tournament: String,
compute_data_path: &str,
@@ -59,13 +59,16 @@ impl ComputeStateAccess {
let work_path = PathBuf::from(work_dir);
let db_path = work_path.join("db");
let no_create_flags = OpenFlags::default() & !OpenFlags::SQLITE_OPEN_CREATE;
+ let handle_rollups;
match Connection::open_with_flags(&db_path, no_create_flags) {
// database already exists, return it
Ok(connection) => {
+ handle_rollups = compute_data::handle_rollups(&connection)?;
return Ok(Self {
connection: Mutex::new(connection),
+ handle_rollups,
work_path,
- })
+ });
}
Err(_) => {
// create new database
@@ -77,17 +80,21 @@ impl ComputeStateAccess {
// prioritize json file over parameters
match read_json_file(&json_path) {
Ok(inputs_and_leafs) => {
+ handle_rollups = inputs_and_leafs.inputs.is_some();
+ compute_data::insert_handle_rollups(&connection, handle_rollups)?;
compute_data::insert_compute_data(
&connection,
- inputs_and_leafs.inputs.iter(),
+ inputs_and_leafs.inputs.unwrap_or_default().iter(),
inputs_and_leafs.leafs.iter(),
)?;
}
Err(_) => {
info!("load inputs and leafs from parameters");
+ handle_rollups = inputs.is_some();
+ compute_data::insert_handle_rollups(&connection, handle_rollups)?;
compute_data::insert_compute_data(
&connection,
- inputs.iter(),
+ inputs.unwrap_or_default().iter(),
leafs.iter(),
)?;
}
@@ -95,6 +102,7 @@ impl ComputeStateAccess {
Ok(Self {
connection: Mutex::new(connection),
+ handle_rollups,
work_path,
})
}
@@ -106,6 +114,11 @@ impl ComputeStateAccess {
compute_data::input(&conn, id)
}
+ pub fn inputs(&self) -> Result>> {
+ let conn = self.connection.lock().unwrap();
+ compute_data::inputs(&conn)
+ }
+
pub fn insert_compute_leafs<'a>(
&self,
level: u64,
@@ -197,6 +210,8 @@ mod compute_state_access_tests {
fn test_access_sequentially() {
test_compute_tree();
test_closest_snapshot();
+ test_compute_or_rollups_true();
+ test_compute_or_rollups_false();
}
fn test_closest_snapshot() {
@@ -205,7 +220,7 @@ mod compute_state_access_tests {
create_directory(&work_dir).unwrap();
{
let access =
- ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp")
+ ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp")
.unwrap();
assert_eq!(access.closest_snapshot(0).unwrap(), None);
@@ -264,8 +279,7 @@ mod compute_state_access_tests {
remove_directory(&work_dir).unwrap();
create_directory(&work_dir).unwrap();
let access =
- ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp")
- .unwrap();
+ ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp").unwrap();
let root = [
1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1,
@@ -284,11 +298,36 @@ mod compute_state_access_tests {
assert!(tree.0.subtrees().is_some());
}
+ fn test_compute_or_rollups_true() {
+ let work_dir = PathBuf::from("/tmp/0x12345678");
+ remove_directory(&work_dir).unwrap();
+ create_directory(&work_dir).unwrap();
+ let access = ComputeStateAccess::new(
+ Some(Vec::new()),
+ Vec::new(),
+ String::from("0x12345678"),
+ "/tmp",
+ )
+ .unwrap();
+
+ assert!(matches!(access.handle_rollups, true));
+ }
+
+ fn test_compute_or_rollups_false() {
+ let work_dir = PathBuf::from("/tmp/0x12345678");
+ remove_directory(&work_dir).unwrap();
+ create_directory(&work_dir).unwrap();
+ let access =
+ ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp").unwrap();
+
+ assert!(matches!(access.handle_rollups, false));
+ }
+
#[test]
fn test_deserialize() {
let json_str_1 = r#"{"leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_1: InputsAndLeafs = serde_json::from_str(json_str_1).unwrap();
- assert_eq!(inputs_and_leafs_1.inputs.len(), 0);
+ assert_eq!(inputs_and_leafs_1.inputs.unwrap_or_default().len(), 0);
assert_eq!(inputs_and_leafs_1.leafs.len(), 2);
assert_eq!(
inputs_and_leafs_1.leafs[0].0,
@@ -307,14 +346,15 @@ mod compute_state_access_tests {
let json_str_2 = r#"{"inputs": [], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_2: InputsAndLeafs = serde_json::from_str(json_str_2).unwrap();
- assert_eq!(inputs_and_leafs_2.inputs.len(), 0);
+ assert_eq!(inputs_and_leafs_2.inputs.unwrap_or_default().len(), 0);
assert_eq!(inputs_and_leafs_2.leafs.len(), 2);
let json_str_3 = r#"{"inputs": ["0x12345678", "0x22345678"], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#;
let inputs_and_leafs_3: InputsAndLeafs = serde_json::from_str(json_str_3).unwrap();
- assert_eq!(inputs_and_leafs_3.inputs.len(), 2);
+ let inputs_3 = inputs_and_leafs_3.inputs.unwrap();
+ assert_eq!(inputs_3.len(), 2);
assert_eq!(inputs_and_leafs_3.leafs.len(), 2);
- assert_eq!(inputs_and_leafs_3.inputs[0].0, [18, 52, 86, 120]);
- assert_eq!(inputs_and_leafs_3.inputs[1].0, [34, 52, 86, 120]);
+ assert_eq!(inputs_3[0].0, [18, 52, 86, 120]);
+ assert_eq!(inputs_3[1].0, [34, 52, 86, 120]);
}
}
diff --git a/prt/client-rs/src/db/sql/compute_data.rs b/prt/client-rs/src/db/sql/compute_data.rs
index 57097438..1cc6a4b4 100644
--- a/prt/client-rs/src/db/sql/compute_data.rs
+++ b/prt/client-rs/src/db/sql/compute_data.rs
@@ -45,6 +45,24 @@ pub fn input(conn: &rusqlite::Connection, id: u64) -> Result