From 901da712cad49545c43f4eaea61eba42c354cddc Mon Sep 17 00:00:00 2001 From: Stephen Chen <20940639+stephenctw@users.noreply.github.com> Date: Thu, 31 Oct 2024 13:53:48 +0800 Subject: [PATCH] feat(prt): complete rollups meta step --- .../node/compute-runner/src/lib.rs | 2 +- .../node/machine-runner/src/lib.rs | 17 +-- prt/client-lua/computation/commitment.lua | 21 ++- prt/client-lua/computation/machine.lua | 107 ++++++++++----- prt/client-lua/player/strategy.lua | 10 +- prt/client-rs/src/db/compute_state_access.rs | 70 +++++++--- prt/client-rs/src/db/sql/compute_data.rs | 72 ++++++++++ prt/client-rs/src/db/sql/migrations.sql | 5 + prt/client-rs/src/machine/commitment.rs | 4 - .../src/machine/commitment_builder.rs | 9 +- prt/client-rs/src/machine/instance.rs | 125 +++++++++++------ prt/client-rs/src/strategy/player.rs | 6 +- prt/contracts/src/IDataProvider.sol | 1 + .../tournament/abstracts/LeafTournament.sol | 126 +++++++----------- .../src/tournament/libs/Commitment.sol | 4 +- prt/contracts/test/Util.sol | 29 +++- prt/tests/compute-rs/src/main.rs | 4 +- prt/tests/compute/prt_compute.lua | 3 +- .../runners/helpers/fake_commitment.lua | 9 +- prt/tests/compute/runners/hero_runner.lua | 2 +- prt/tests/compute/runners/sybil_runner.lua | 2 +- 21 files changed, 418 insertions(+), 210 deletions(-) diff --git a/cartesi-rollups/node/compute-runner/src/lib.rs b/cartesi-rollups/node/compute-runner/src/lib.rs index 5f037ffe..b9a167ae 100644 --- a/cartesi-rollups/node/compute-runner/src/lib.rs +++ b/cartesi-rollups/node/compute-runner/src/lib.rs @@ -43,7 +43,7 @@ where .state_manager .machine_state_hashes(last_sealed_epoch.epoch_number)?; let mut player = Player::new( - inputs.into_iter().map(|i| Input(i)).collect(), + Some(inputs.into_iter().map(|i| Input(i)).collect()), leafs .into_iter() .map(|l| { diff --git a/cartesi-rollups/node/machine-runner/src/lib.rs b/cartesi-rollups/node/machine-runner/src/lib.rs index 8404deff..96024233 100644 --- a/cartesi-rollups/node/machine-runner/src/lib.rs +++ b/cartesi-rollups/node/machine-runner/src/lib.rs @@ -168,18 +168,11 @@ where fn process_input(&mut self, data: &[u8]) -> Result<(), SM> { // TODO: review caclulations - let big_steps_in_stride = max_uint(LOG2_STRIDE - LOG2_UARCH_SPAN); - let stride_count_in_input = max_uint(LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE); + let big_steps_in_stride = 1 << (LOG2_STRIDE - LOG2_UARCH_SPAN); + let stride_count_in_input = 1 << (LOG2_EMULATOR_SPAN + LOG2_UARCH_SPAN - LOG2_STRIDE); - // take snapshot and make it available to the compute client - // the snapshot taken before input insersion is for log/proof generation - self.snapshot(0)?; self.feed_input(data)?; - self.run_machine(1)?; - // take snapshot and make it available to the compute client - // the snapshot taken after insersion and step is for commitment builder - self.snapshot(1)?; - self.run_machine(big_steps_in_stride - 1)?; + self.run_machine(big_steps_in_stride)?; let mut i: u64 = 0; while !self.machine.read_iflags_y()? { @@ -228,12 +221,12 @@ where Ok(()) } - fn snapshot(&self, offset: u64) -> Result<(), SM> { + fn take_snapshot(&self) -> Result<(), SM> { // TODO: make sure "/rollups_data/{epoch_number}" exists let snapshot_path = PathBuf::from(format!( "/rollups_data/{}/{}", self.epoch_number, - self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN + offset + self.next_input_index_in_epoch << LOG2_EMULATOR_SPAN )); if !snapshot_path.exists() { self.machine.store(&snapshot_path)?; diff --git a/prt/client-lua/computation/commitment.lua b/prt/client-lua/computation/commitment.lua index 0ef7825e..8bd1574f 100644 --- a/prt/client-lua/computation/commitment.lua +++ b/prt/client-lua/computation/commitment.lua @@ -35,10 +35,10 @@ local function run_uarch_span(machine) end local function build_small_machine_commitment(base_cycle, log2_stride_count, machine, snapshot_dir) - local machine_state = machine:run(base_cycle) + local machine_state = machine:state() if save_snapshot then -- taking snapshot for leafs to save time in next level - machine:snapshot(snapshot_dir, base_cycle) + machine:take_snapshot(snapshot_dir, base_cycle) end local initial_state = machine_state.root_hash @@ -60,10 +60,10 @@ local function build_small_machine_commitment(base_cycle, log2_stride_count, mac end local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride_count, machine, snapshot_dir) - local machine_state = machine:run(base_cycle) + local machine_state = machine:state() if save_snapshot then -- taking snapshot for leafs to save time in next level - machine:snapshot(snapshot_dir, base_cycle) + machine:take_snapshot(snapshot_dir, base_cycle) end local initial_state = machine_state.root_hash @@ -88,9 +88,16 @@ local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride return initial_state, builder:build(initial_state) end -local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir) +local function build_commitment(base_cycle, log2_stride, log2_stride_count, machine_path, snapshot_dir, inputs) local machine = Machine:new_from_path(machine_path) machine:load_snapshot(snapshot_dir, base_cycle) + if inputs then + -- treat it as rollups + machine:run_with_inputs(base_cycle, inputs) + else + -- treat it as compute + machine:run(base_cycle) + end if log2_stride >= consts.log2_uarch_span then assert( @@ -120,7 +127,7 @@ function CommitmentBuilder:new(machine_path, snapshot_dir, root_commitment) return c end -function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count) +function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count, inputs) if not self.commitments[level] then self.commitments[level] = {} elseif self.commitments[level][base_cycle] then @@ -128,7 +135,7 @@ function CommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_cou end local _, commitment = build_commitment(base_cycle, log2_stride, log2_stride_count, self.machine_path, - self.snapshot_dir) + self.snapshot_dir, inputs) self.commitments[level][base_cycle] = commitment return commitment end diff --git a/prt/client-lua/computation/machine.lua b/prt/client-lua/computation/machine.lua index 889f79ae..efc37a28 100644 --- a/prt/client-lua/computation/machine.lua +++ b/prt/client-lua/computation/machine.lua @@ -110,7 +110,7 @@ local function find_closest_snapshot(path, current_cycle, cycle) return closest_dir end -function Machine:snapshot(snapshot_dir, cycle) +function Machine:take_snapshot(snapshot_dir, cycle) if helper.exists(snapshot_dir) then local snapshot_path = snapshot_dir .. "/" .. tostring(cycle) @@ -163,6 +163,34 @@ function Machine:run_uarch(ucycle) self.ucycle = ucycle end +function Machine:run_with_inputs(cycle, inputs) + local input_mask = arithmetic.max_uint(consts.log2_emulator_span) + local current_input_index = self.cycle >> consts.log2_emulator_span + + local next_input_index + + if self.cycle & input_mask == 0 then + next_input_index = current_input_index + else + next_input_index = current_input_index + 1 + end + local next_input_cycle = next_input_index << consts.log2_emulator_span + + while next_input_cycle < cycle do + self:run(next_input_cycle) + local input = inputs[next_input_index] + if input then + self.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input); + end + + next_input_index = next_input_index + 1 + next_input_cycle = next_input_index << consts.log2_emulator_span + end + self:run(cycle) + + return self:state() +end + function Machine:increment_uarch() self.machine:run_uarch(self.ucycle + 1) self.ucycle = self.ucycle + 1 @@ -200,18 +228,28 @@ local function ver(t, p, s) return t end -local function encode_access_log(logs) +local bint = require 'utils.bint' (256) -- use 256 bits integers + +local function encode_access_logs(logs, input) local encoded = {} - for _, a in ipairs(logs.accesses) do - if a.log2_size == 3 then - table.insert(encoded, a.read) - else - table.insert(encoded, a.read_hash) - end + if input then + -- TODO: check #input is encoded as uint256 + table.insert(encoded, bint(#input)) + table.insert(encoded, input) + end - for _, h in ipairs(a.sibling_hashes) do - table.insert(encoded, h) + for _, log in ipairs(logs) do + for _, a in ipairs(log.accesses) do + if a.log2_size == 3 then + table.insert(encoded, a.read) + else + table.insert(encoded, a.read_hash) + end + + for _, h in ipairs(a.sibling_hashes) do + table.insert(encoded, h) + end end end @@ -223,38 +261,45 @@ local function encode_access_log(logs) return '"' .. hex_data .. '"' end -function Machine.get_logs(path, snapshot_dir, cycle, ucycle, input) +function Machine.get_logs(path, snapshot_dir, cycle, ucycle, inputs) local machine = Machine:new_from_path(path) machine:load_snapshot(snapshot_dir, cycle) - local logs + local logs = {} local log_type = { annotations = true, proofs = true } - machine:run(cycle) - - local mask = 1 << consts.log2_emulator_span - 1; - if cycle & mask == 0 and input then - -- need to process input - if ucycle == 0 then - logs = machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input, - log_type - ) - local step_logs = machine.machine:log_uarch_step(log_type) - -- append step logs to cmio logs - for _, log in ipairs(step_logs) do - table.insert(logs, log) + local input = Hash.zero + if inputs then + -- treat it as rollups + machine:run_with_inputs(cycle, inputs) + + local mask = arithmetic.max_uint(consts.log2_emulator_span); + local try_input = inputs[cycle >> consts.log2_emulator_span] + if cycle & mask == 0 and try_input then + input = try_input + -- need to process input + if ucycle == 0 then + -- need to log cmio + table.insert(logs, + machine.machine:log_send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input, + log_type + )) + table.insert(logs, machine.machine:log_uarch_step(log_type)) + return encode_access_logs(logs, input) + else + machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input) end - return encode_access_log(logs) - else - machine.machine:send_cmio_response(cartesi.machine.HTIF_YIELD_REASON_ADVANCE_STATE, input) end + else + -- treat it as compute + machine:run(cycle) end machine:run_uarch(ucycle) if ucycle == consts.uarch_span then - logs = machine.machine:log_uarch_reset(log_type) + table.insert(logs, machine.machine:log_uarch_reset(log_type)) else - logs = machine.machine:log_uarch_step(log_type) + table.insert(logs, machine.machine:log_uarch_step(log_type)) end - return encode_access_log(logs) + return encode_access_logs(logs, nil) end return Machine diff --git a/prt/client-lua/player/strategy.lua b/prt/client-lua/player/strategy.lua index 28bf4667..a86079bb 100644 --- a/prt/client-lua/player/strategy.lua +++ b/prt/client-lua/player/strategy.lua @@ -127,8 +127,8 @@ function HonestStrategy:_react_match(match, commitment, log) local cycle = match.base_big_cycle local ucycle = (match.leaf_cycle & constants.uarch_span):touinteger() - local input = self.inputs[cycle >> constants.log2_emulator_span] - local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle, input) + local logs = Machine.get_logs(self.machine_path, self.commitment_builder.snapshot_dir, cycle, ucycle, + self.inputs) helper.log_full(self.sender.index, string.format( "win leaf match in tournament %s of level %d for commitment %s", @@ -281,7 +281,8 @@ function HonestStrategy:_react_tournament(tournament, log) tournament.base_big_cycle, tournament.level, tournament.log2_stride, - tournament.log2_stride_count + tournament.log2_stride_count, + self.inputs ) table.insert(log.tournaments, tournament) @@ -299,7 +300,8 @@ function HonestStrategy:_react_tournament(tournament, log) tournament.parent.base_big_cycle, tournament.parent.level, tournament.parent.log2_stride, - tournament.parent.log2_stride_count + tournament.parent.log2_stride_count, + self.inputs ) if tournament_winner.commitment ~= old_commitment then helper.log_full(self.sender.index, "player lost tournament") diff --git a/prt/client-rs/src/db/compute_state_access.rs b/prt/client-rs/src/db/compute_state_access.rs index 3cfb81dc..2d4be861 100644 --- a/prt/client-rs/src/db/compute_state_access.rs +++ b/prt/client-rs/src/db/compute_state_access.rs @@ -16,12 +16,11 @@ use std::{ #[derive(Debug, Serialize, Deserialize)] pub struct InputsAndLeafs { - #[serde(default)] - inputs: Vec, + inputs: Option>, leafs: Vec, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Default)] pub struct Input(#[serde(with = "alloy_hex::serde")] pub Vec); #[derive(Debug, Serialize, Deserialize)] @@ -30,6 +29,7 @@ pub struct Leaf(#[serde(with = "alloy_hex::serde")] pub [u8; 32], pub u64); #[derive(Debug)] pub struct ComputeStateAccess { connection: Mutex, + pub handle_rollups: bool, pub work_path: PathBuf, } @@ -46,7 +46,7 @@ fn read_json_file(file_path: &Path) -> Result { impl ComputeStateAccess { pub fn new( - inputs: Vec, + inputs: Option>, leafs: Vec, root_tournament: String, compute_data_path: &str, @@ -59,13 +59,16 @@ impl ComputeStateAccess { let work_path = PathBuf::from(work_dir); let db_path = work_path.join("db"); let no_create_flags = OpenFlags::default() & !OpenFlags::SQLITE_OPEN_CREATE; + let handle_rollups; match Connection::open_with_flags(&db_path, no_create_flags) { // database already exists, return it Ok(connection) => { + handle_rollups = compute_data::handle_rollups(&connection)?; return Ok(Self { connection: Mutex::new(connection), + handle_rollups, work_path, - }) + }); } Err(_) => { // create new database @@ -77,17 +80,21 @@ impl ComputeStateAccess { // prioritize json file over parameters match read_json_file(&json_path) { Ok(inputs_and_leafs) => { + handle_rollups = inputs_and_leafs.inputs.is_some(); + compute_data::insert_handle_rollups(&connection, handle_rollups)?; compute_data::insert_compute_data( &connection, - inputs_and_leafs.inputs.iter(), + inputs_and_leafs.inputs.unwrap_or_default().iter(), inputs_and_leafs.leafs.iter(), )?; } Err(_) => { info!("load inputs and leafs from parameters"); + handle_rollups = inputs.is_some(); + compute_data::insert_handle_rollups(&connection, handle_rollups)?; compute_data::insert_compute_data( &connection, - inputs.iter(), + inputs.unwrap_or_default().iter(), leafs.iter(), )?; } @@ -95,6 +102,7 @@ impl ComputeStateAccess { Ok(Self { connection: Mutex::new(connection), + handle_rollups, work_path, }) } @@ -106,6 +114,11 @@ impl ComputeStateAccess { compute_data::input(&conn, id) } + pub fn inputs(&self) -> Result>> { + let conn = self.connection.lock().unwrap(); + compute_data::inputs(&conn) + } + pub fn insert_compute_leafs<'a>( &self, level: u64, @@ -197,6 +210,8 @@ mod compute_state_access_tests { fn test_access_sequentially() { test_compute_tree(); test_closest_snapshot(); + test_compute_or_rollups_true(); + test_compute_or_rollups_false(); } fn test_closest_snapshot() { @@ -205,7 +220,7 @@ mod compute_state_access_tests { create_directory(&work_dir).unwrap(); { let access = - ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp") + ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp") .unwrap(); assert_eq!(access.closest_snapshot(0).unwrap(), None); @@ -264,8 +279,7 @@ mod compute_state_access_tests { remove_directory(&work_dir).unwrap(); create_directory(&work_dir).unwrap(); let access = - ComputeStateAccess::new(Vec::new(), Vec::new(), String::from("0x12345678"), "/tmp") - .unwrap(); + ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp").unwrap(); let root = [ 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, @@ -284,11 +298,36 @@ mod compute_state_access_tests { assert!(tree.0.subtrees().is_some()); } + fn test_compute_or_rollups_true() { + let work_dir = PathBuf::from("/tmp/0x12345678"); + remove_directory(&work_dir).unwrap(); + create_directory(&work_dir).unwrap(); + let access = ComputeStateAccess::new( + Some(Vec::new()), + Vec::new(), + String::from("0x12345678"), + "/tmp", + ) + .unwrap(); + + assert!(matches!(access.handle_rollups, true)); + } + + fn test_compute_or_rollups_false() { + let work_dir = PathBuf::from("/tmp/0x12345678"); + remove_directory(&work_dir).unwrap(); + create_directory(&work_dir).unwrap(); + let access = + ComputeStateAccess::new(None, Vec::new(), String::from("0x12345678"), "/tmp").unwrap(); + + assert!(matches!(access.handle_rollups, false)); + } + #[test] fn test_deserialize() { let json_str_1 = r#"{"leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#; let inputs_and_leafs_1: InputsAndLeafs = serde_json::from_str(json_str_1).unwrap(); - assert_eq!(inputs_and_leafs_1.inputs.len(), 0); + assert_eq!(inputs_and_leafs_1.inputs.unwrap_or_default().len(), 0); assert_eq!(inputs_and_leafs_1.leafs.len(), 2); assert_eq!( inputs_and_leafs_1.leafs[0].0, @@ -307,14 +346,15 @@ mod compute_state_access_tests { let json_str_2 = r#"{"inputs": [], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#; let inputs_and_leafs_2: InputsAndLeafs = serde_json::from_str(json_str_2).unwrap(); - assert_eq!(inputs_and_leafs_2.inputs.len(), 0); + assert_eq!(inputs_and_leafs_2.inputs.unwrap_or_default().len(), 0); assert_eq!(inputs_and_leafs_2.leafs.len(), 2); let json_str_3 = r#"{"inputs": ["0x12345678", "0x22345678"], "leafs": [["0x01020304050607abcdef01020304050607abcdef01020304050607abcdef0102", 20], ["0x01020304050607fedcba01020304050607fedcba01020304050607fedcba0102", 13]]}"#; let inputs_and_leafs_3: InputsAndLeafs = serde_json::from_str(json_str_3).unwrap(); - assert_eq!(inputs_and_leafs_3.inputs.len(), 2); + let inputs_3 = inputs_and_leafs_3.inputs.unwrap(); + assert_eq!(inputs_3.len(), 2); assert_eq!(inputs_and_leafs_3.leafs.len(), 2); - assert_eq!(inputs_and_leafs_3.inputs[0].0, [18, 52, 86, 120]); - assert_eq!(inputs_and_leafs_3.inputs[1].0, [34, 52, 86, 120]); + assert_eq!(inputs_3[0].0, [18, 52, 86, 120]); + assert_eq!(inputs_3[1].0, [34, 52, 86, 120]); } } diff --git a/prt/client-rs/src/db/sql/compute_data.rs b/prt/client-rs/src/db/sql/compute_data.rs index 57097438..1cc6a4b4 100644 --- a/prt/client-rs/src/db/sql/compute_data.rs +++ b/prt/client-rs/src/db/sql/compute_data.rs @@ -45,6 +45,24 @@ pub fn input(conn: &rusqlite::Connection, id: u64) -> Result>> { Ok(i) } +pub fn inputs(conn: &rusqlite::Connection) -> Result>> { + let mut stmt = conn.prepare( + "\ + SELECT * FROM inputs + ORDER BY input_index ASC + ", + )?; + + let query = stmt.query_map([], |r| Ok(r.get("input")?))?; + + let mut res = vec![]; + for row in query { + res.push(row?); + } + + Ok(res) +} + // // Compute leafs // @@ -161,6 +179,38 @@ pub fn compute_tree_count(conn: &rusqlite::Connection, tree_root: &[u8]) -> Resu )?) } +// +// Handle rollups +// + +fn insert_handle_rollups_statement(conn: &rusqlite::Connection) -> Result { + Ok(conn.prepare( + "\ + INSERT INTO compute_or_rollups (id, handle_rollups) VALUES (0, ?1) + ", + )?) +} + +pub fn insert_handle_rollups(conn: &rusqlite::Connection, handle_rollups: bool) -> Result<()> { + let tx = conn.unchecked_transaction()?; + let mut stmt = insert_handle_rollups_statement(&conn)?; + stmt.execute(params![handle_rollups])?; + tx.commit()?; + + Ok(()) +} + +pub fn handle_rollups(conn: &rusqlite::Connection) -> Result { + Ok(conn.query_row( + "\ + SELECT handle_rollups FROM compute_or_rollups + WHERE id = 0 + ", + [], + |row| row.get(0), + )?) +} + pub fn insert_compute_data<'a>( conn: &rusqlite::Connection, inputs: impl Iterator, @@ -318,3 +368,25 @@ mod trees_tests { assert!(matches!(compute_tree(&conn, &root).unwrap().len(), 2)); } } + +#[cfg(test)] +mod compute_or_rollups_tests { + use super::*; + + #[test] + fn test_empty() { + let conn = test_helper::setup_db(); + assert!(matches!(handle_rollups(&conn), Err(_))); + } + + #[test] + fn test_insert() { + let conn = test_helper::setup_db(); + + assert!(matches!(insert_handle_rollups(&conn, true), Ok(()))); + assert!(matches!(handle_rollups(&conn), Ok(true))); + // compute_or_rollups can only be set once + assert!(matches!(insert_handle_rollups(&conn, true), Err(_))); + assert!(matches!(handle_rollups(&conn), Ok(true))); + } +} diff --git a/prt/client-rs/src/db/sql/migrations.sql b/prt/client-rs/src/db/sql/migrations.sql index 6ef90062..eac1b48a 100644 --- a/prt/client-rs/src/db/sql/migrations.sql +++ b/prt/client-rs/src/db/sql/migrations.sql @@ -19,3 +19,8 @@ CREATE TABLE compute_trees ( tree_leaf BLOB NOT NULL, PRIMARY KEY (tree_root, tree_leaf_index) ); + +CREATE TABLE compute_or_rollups ( + id INTEGER NOT NULL PRIMARY KEY, + handle_rollups INTEGER NOT NULL +); diff --git a/prt/client-rs/src/machine/commitment.rs b/prt/client-rs/src/machine/commitment.rs index c94df7ed..9056e29d 100644 --- a/prt/client-rs/src/machine/commitment.rs +++ b/prt/client-rs/src/machine/commitment.rs @@ -22,13 +22,11 @@ pub struct MachineCommitment { /// Builds a [MachineCommitment] from a [MachineInstance] and a base cycle and leafs. pub fn build_machine_commitment_from_leafs( machine: &mut MachineInstance, - base_cycle: u64, leafs: Vec<(L, u64)>, ) -> Result where L: Into>, { - machine.run(base_cycle)?; let initial_state = machine.machine_state()?; let mut builder = MerkleBuilder::default(); for leaf in leafs { @@ -81,7 +79,6 @@ pub fn build_big_machine_commitment( log2_stride_count: u64, db: &ComputeStateAccess, ) -> Result { - machine.run(base_cycle)?; snapshot_base_cycle(machine, base_cycle, db)?; let initial_state = machine.machine_state()?; @@ -145,7 +142,6 @@ pub fn build_small_machine_commitment( log2_stride_count: u64, db: &ComputeStateAccess, ) -> Result { - machine.run(base_cycle)?; snapshot_base_cycle(machine, base_cycle, db)?; let initial_state = machine.machine_state()?; diff --git a/prt/client-rs/src/machine/commitment_builder.rs b/prt/client-rs/src/machine/commitment_builder.rs index a9f190d1..44e0b839 100644 --- a/prt/client-rs/src/machine/commitment_builder.rs +++ b/prt/client-rs/src/machine/commitment_builder.rs @@ -46,12 +46,19 @@ impl CachingMachineCommitmentBuilder { if let Some(snapshot_path) = db.closest_snapshot(base_cycle)? { machine.load_snapshot(&PathBuf::from(snapshot_path))?; }; + if db.handle_rollups { + // treat it as rollups + machine.run_with_inputs(base_cycle, &db.inputs()?)?; + } else { + // treat it as compute + machine.run(base_cycle)?; + } let commitment = { let leafs = db.compute_leafs(level, base_cycle)?; // leafs are cached in database, use it to calculate merkle if leafs.len() > 0 { - build_machine_commitment_from_leafs(&mut machine, base_cycle, leafs)? + build_machine_commitment_from_leafs(&mut machine, leafs)? } else { // leafs are not cached, build merkle by running the machine build_machine_commitment( diff --git a/prt/client-rs/src/machine/instance.rs b/prt/client-rs/src/machine/instance.rs index 775a1be8..c457ad7e 100644 --- a/prt/client-rs/src/machine/instance.rs +++ b/prt/client-rs/src/machine/instance.rs @@ -1,3 +1,4 @@ +use crate::db::compute_state_access::Input; use crate::machine::constants; use cartesi_dave_arithmetic as arithmetic; use cartesi_dave_merkle::Digest; @@ -9,6 +10,7 @@ use cartesi_machine::{ }; use anyhow::Result; +use ruint::aliases::U256; use std::path::Path; #[derive(Debug)] @@ -90,46 +92,56 @@ impl MachineInstance { &mut self, cycle: u64, ucycle: u64, - input: Option>, + inputs: Vec>, + handle_rollups: bool, ) -> Result { - self.run(cycle)?; - let log_type = AccessLogType { annotations: true, proofs: true, large_data: false, }; - let mask = 1 << constants::LOG2_EMULATOR_SPAN - 1; - if cycle & mask == 0 && input.is_some() { - // need to process input - let data = input.unwrap(); - if ucycle == 0 { - let cmio_logs = self.machine.log_send_cmio_response( - htif::fromhost::ADVANCE_STATE, - &data, - log_type, - false, - )?; - // append step logs to cmio logs - let step_logs = self.machine.log_uarch_step(log_type, false)?; - let mut logs_encoded = encode_access_log(&cmio_logs); - let mut step_logs_encoded = encode_access_log(&step_logs); - logs_encoded.append(&mut step_logs_encoded); - return Ok(logs_encoded); - } else { - self.machine - .send_cmio_response(htif::fromhost::ADVANCE_STATE, &data)?; + let mut logs = Vec::new(); + if handle_rollups { + // treat it as rollups + self.run_with_inputs(cycle, &inputs)?; + + let mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN); + let input = inputs.get((cycle >> constants::LOG2_EMULATOR_SPAN) as usize); + if cycle & mask == 0 && input.is_some() { + // need to process input + let data = input.unwrap(); + if ucycle == 0 { + let cmio_logs = self.machine.log_send_cmio_response( + htif::fromhost::ADVANCE_STATE, + &data, + log_type, + false, + )?; + // append step logs to cmio logs + let step_logs = self.machine.log_uarch_step(log_type, false)?; + logs.push(&cmio_logs); + logs.push(&step_logs); + return Ok(encode_access_logs(logs, Some(Input { 0: data.clone() }))); + } else { + self.machine + .send_cmio_response(htif::fromhost::ADVANCE_STATE, &data)?; + } } + } else { + // treat it as compute + self.run(cycle)?; } self.run_uarch(ucycle)?; if ucycle == constants::UARCH_SPAN { let reset_logs = self.machine.log_uarch_reset(log_type, false)?; - Ok(encode_access_log(&reset_logs)) + logs.push(&reset_logs); + Ok(encode_access_logs(logs, None)) } else { let step_logs = self.machine.log_uarch_step(log_type, false)?; - Ok(encode_access_log(&step_logs)) + logs.push(&step_logs); + Ok(encode_access_logs(logs, None)) } } @@ -170,6 +182,36 @@ impl MachineInstance { Ok(()) } + pub fn run_with_inputs(&mut self, cycle: u64, inputs: &Vec>) -> Result<()> { + let input_mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN); + let current_input_index = self.cycle >> constants::LOG2_EMULATOR_SPAN; + + let mut next_input_index; + + if self.cycle & input_mask == 0 { + next_input_index = current_input_index; + } else { + next_input_index = current_input_index + 1; + } + + let mut next_input_cycle = next_input_index << constants::LOG2_EMULATOR_SPAN; + + while next_input_cycle < cycle { + self.run(next_input_cycle)?; + let input = inputs.get(next_input_index as usize); + if let Some(data) = input { + self.machine + .send_cmio_response(htif::fromhost::ADVANCE_STATE, data)?; + } + + next_input_index += 1; + next_input_cycle = next_input_index << constants::LOG2_EMULATOR_SPAN; + } + self.run(cycle)?; + + Ok(()) + } + pub fn increment_uarch(&mut self) -> Result<()> { self.machine.run_uarch(self.ucycle + 1)?; self.ucycle += 1; @@ -206,22 +248,29 @@ impl MachineInstance { } } -fn encode_access_log(log: &AccessLog) -> Vec { +fn encode_access_logs(logs: Vec<&AccessLog>, input: Option) -> Vec { let mut encoded: Vec> = Vec::new(); - for a in log.accesses().iter() { - if a.log2_size() == 3 { - encoded.push(a.read_data().to_vec()); - } else { - encoded.push(a.read_hash().as_bytes().to_vec()); - } + if let Some(i) = input { + encoded.push(U256::from(i.0.len()).to_be_bytes_vec()); + encoded.push(i.0); + } + + for log in logs.iter() { + for a in log.accesses().iter() { + if a.log2_size() == 3 { + encoded.push(a.read_data().to_vec()); + } else { + encoded.push(a.read_hash().as_bytes().to_vec()); + } - let decoded_siblings: Vec> = a - .sibling_hashes() - .iter() - .map(|h| h.as_bytes().to_vec()) - .collect(); - encoded.extend_from_slice(&decoded_siblings); + let decoded_siblings: Vec> = a + .sibling_hashes() + .iter() + .map(|h| h.as_bytes().to_vec()) + .collect(); + encoded.extend_from_slice(&decoded_siblings); + } } encoded.iter().flatten().cloned().collect() diff --git a/prt/client-rs/src/strategy/player.rs b/prt/client-rs/src/strategy/player.rs index 3639857c..bd16aeec 100644 --- a/prt/client-rs/src/strategy/player.rs +++ b/prt/client-rs/src/strategy/player.rs @@ -35,7 +35,7 @@ pub struct Player { impl Player { pub fn new( - inputs: Vec, + inputs: Option>, leafs: Vec, blockchain_config: &BlockchainConfig, machine_path: String, @@ -362,8 +362,8 @@ impl Player { if let Some(snapshot_path) = self.db.closest_snapshot(cycle)? { machine.load_snapshot(&PathBuf::from(snapshot_path))?; }; - let input = self.db.input(cycle >> constants::LOG2_EMULATOR_SPAN)?; - machine.get_logs(cycle, ucycle, input)? + let inputs = self.db.inputs()?; + machine.get_logs(cycle, ucycle, inputs, self.db.handle_rollups)? }; info!( diff --git a/prt/contracts/src/IDataProvider.sol b/prt/contracts/src/IDataProvider.sol index 2eac6c03..367c0b12 100644 --- a/prt/contracts/src/IDataProvider.sol +++ b/prt/contracts/src/IDataProvider.sol @@ -12,5 +12,6 @@ interface IDataProvider { /// @return Size of the response (in bytes) function gio(uint16 namespace, bytes calldata id, bytes calldata extra) external + view returns (bytes32, uint256); } diff --git a/prt/contracts/src/tournament/abstracts/LeafTournament.sol b/prt/contracts/src/tournament/abstracts/LeafTournament.sol index a8f14079..98bff7c1 100644 --- a/prt/contracts/src/tournament/abstracts/LeafTournament.sol +++ b/prt/contracts/src/tournament/abstracts/LeafTournament.sol @@ -61,15 +61,15 @@ abstract contract LeafTournament is Tournament { Tree.Node _rightNode, bytes calldata proofs ) external tournamentNotFinished { - Match.State storage _matchState = matches[_matchId.hashFromId()]; - _matchState.requireExist(); - _matchState.requireIsFinished(); - Clock.State storage _clockOne = clocks[_matchId.commitmentOne]; Clock.State storage _clockTwo = clocks[_matchId.commitmentTwo]; _clockOne.requireInitialized(); _clockTwo.requireInitialized(); + Match.State storage _matchState = matches[_matchId.hashFromId()]; + _matchState.requireExist(); + _matchState.requireIsFinished(); + ( Machine.Hash _agreeHash, uint256 _agreeCycle, @@ -77,81 +77,38 @@ abstract contract LeafTournament is Tournament { Machine.Hash _finalStateTwo ) = _matchState.getDivergence(startCycle); - Machine.Hash _finalState = runMetaStep(_agreeHash, _agreeCycle, proofs); + Machine.Hash _finalState = Machine.Hash.wrap( + metaStep(Machine.Hash.unwrap(_agreeHash), _agreeCycle, proofs) + ); if (_leftNode.join(_rightNode).eq(_matchId.commitmentOne)) { - require( - _finalState.eq(_finalStateOne), "final state one doesn't match" - ); + require(_finalState.eq(_finalStateOne), "final state one mismatch"); _clockOne.setPaused(); pairCommitment( _matchId.commitmentOne, _clockOne, _leftNode, _rightNode ); } else if (_leftNode.join(_rightNode).eq(_matchId.commitmentTwo)) { - require( - _finalState.eq(_finalStateTwo), "final state two doesn't match" - ); + require(_finalState.eq(_finalStateTwo), "final state two mismatch"); _clockTwo.setPaused(); pairCommitment( _matchId.commitmentTwo, _clockTwo, _leftNode, _rightNode ); } else { - revert("wrong left/right nodes for step"); + revert("wrong nodes for step"); } // delete storage deleteMatch(_matchId.hashFromId()); } - function runMetaStep( - Machine.Hash machineState, - uint256 counter, - bytes memory proofs - ) internal view returns (Machine.Hash) { - if (address(provider) == address(0)) { - return Machine.Hash.wrap( - computeMetaStep( - Machine.Hash.unwrap(machineState), counter, proofs - ) - ); - } else { - return Machine.Hash.wrap( - rollupsMetaStep( - Machine.Hash.unwrap(machineState), counter, proofs - ) - ); - } - } - - // this is a inputless version of the meta step implementation primarily used for testing - function computeMetaStep( - bytes32 machineState, - uint256 counter, - bytes memory proofs - ) internal pure returns (bytes32 newMachineState) { - // TODO: create a more convinient constructor. - AccessLogs.Context memory accessLogs = - AccessLogs.Context(machineState, Buffer.Context(proofs, 0)); - - uint256 uarch_step_mask = - (1 << ArbitrationConstants.LOG2_UARCH_SPAN) - 1; - - if ((counter + 1) & uarch_step_mask == 0) { - UArchReset.reset(accessLogs); - } else { - UArchStep.step(accessLogs); - } - newMachineState = accessLogs.currentRootHash; - } - // TODO: move to step repo - function rollupsMetaStep( + function metaStep( bytes32 machineState, uint256 counter, - bytes memory proofs - ) internal pure returns (bytes32 newMachineState) { + bytes calldata proofs + ) internal view returns (bytes32 newMachineState) { // TODO: create a more convinient constructor. AccessLogs.Context memory accessLogs = AccessLogs.Context(machineState, Buffer.Context(proofs, 0)); @@ -160,27 +117,46 @@ abstract contract LeafTournament is Tournament { (1 << ArbitrationConstants.LOG2_UARCH_SPAN) - 1; uint256 big_step_mask = ( 1 - << ( - ArbitrationConstants.LOG2_EMULATOR_SPAN - + ArbitrationConstants.LOG2_UARCH_SPAN - ) - 1 - ); + << ArbitrationConstants.LOG2_EMULATOR_SPAN + + ArbitrationConstants.LOG2_UARCH_SPAN + ) - 1; - if (counter & big_step_mask == 0) { - // TODO: add inputs - (bytes32 inputMerkleRoot, uint64 inputLength) = - provider.gio(namespace, id, extra); - SendCmioResponse.sendCmioResponse( - EmulatorConstants.HTIF_YIELD_REASON_ADVANCE_STATE, - inputMerkleRoot, - inputLength, - accessLogs - ); - UArchStep.step(accessLogs); - } else if ((counter + 1) & uarch_step_mask == 0) { - UArchReset.reset(accessLogs); + if (address(provider) == address(0)) { + // this is a inputless version of the meta step implementation primarily used for testing + if ((counter + 1) & uarch_step_mask == 0) { + UArchReset.reset(accessLogs); + } else { + UArchStep.step(accessLogs); + } } else { - UArchStep.step(accessLogs); + // rollups meta step handles input + if (counter & big_step_mask == 0) { + (uint256 inputLength,) = abi.decode(proofs, (uint256, bytes)); + bytes calldata input = proofs[32:32 + inputLength]; + uint256 inputIndex = counter + >> ( + ArbitrationConstants.LOG2_EMULATOR_SPAN + + ArbitrationConstants.LOG2_UARCH_SPAN + ); // TODO: add input index offset of the epoch + + (bytes32 inputMerkleRoot,) = + provider.gio(0, abi.encode(inputIndex), input); + accessLogs = AccessLogs.Context( + machineState, Buffer.Context(proofs, 32 + inputLength) + ); + // TODO: contract size too big... + // SendCmioResponse.sendCmioResponse( + // accessLogs, + // EmulatorConstants.HTIF_YIELD_REASON_ADVANCE_STATE, + // inputMerkleRoot, + // uint32(inputLength) + // ); + UArchStep.step(accessLogs); + } else if ((counter + 1) & uarch_step_mask == 0) { + UArchReset.reset(accessLogs); + } else { + UArchStep.step(accessLogs); + } } newMachineState = accessLogs.currentRootHash; } diff --git a/prt/contracts/src/tournament/libs/Commitment.sol b/prt/contracts/src/tournament/libs/Commitment.sol index b9a31f4a..4a09bbae 100644 --- a/prt/contracts/src/tournament/libs/Commitment.sol +++ b/prt/contracts/src/tournament/libs/Commitment.sol @@ -24,9 +24,7 @@ library Commitment { Tree.Node expectedCommitment = getRoot(Machine.Hash.unwrap(state), treeHeight, position, hashProof); - require( - commitment.eq(expectedCommitment), "commitment state doesn't match" - ); + require(commitment.eq(expectedCommitment), "commitment state mismatch"); } function isEven(uint256 x) private pure returns (bool) { diff --git a/prt/contracts/test/Util.sol b/prt/contracts/test/Util.sol index 220a34f1..1b7765fa 100644 --- a/prt/contracts/test/Util.sol +++ b/prt/contracts/test/Util.sol @@ -66,7 +66,18 @@ contract Util { } } - function generateProof(uint256 _player, uint64 _height) + function generateDivergenceProof(uint256 _player, uint64 _height) + internal + view + returns (bytes32[] memory) + { + bytes32[] memory _proof = generateFinalStateProof(_player, _height); + _proof[0] = Tree.Node.unwrap(playerNodes[_player][0]); + + return _proof; + } + + function generateFinalStateProof(uint256 _player, uint64 _height) internal view returns (bytes32[] memory) @@ -171,21 +182,27 @@ contract Util { if (_player == 0) { _tournament.joinTournament( ONE_STATE, - generateProof(_player, ArbitrationConstants.height(_level)), + generateFinalStateProof( + _player, ArbitrationConstants.height(_level) + ), playerNodes[0][ArbitrationConstants.height(_level) - 1], playerNodes[0][ArbitrationConstants.height(_level) - 1] ); } else if (_player == 1) { _tournament.joinTournament( TWO_STATE, - generateProof(_player, ArbitrationConstants.height(_level)), + generateFinalStateProof( + _player, ArbitrationConstants.height(_level) + ), playerNodes[1][ArbitrationConstants.height(_level) - 1], playerNodes[1][ArbitrationConstants.height(_level) - 1] ); } else if (_player == 2) { _tournament.joinTournament( TWO_STATE, - generateProof(_player, ArbitrationConstants.height(_level)), + generateFinalStateProof( + _player, ArbitrationConstants.height(_level) + ), playerNodes[0][ArbitrationConstants.height(_level) - 1], playerNodes[2][ArbitrationConstants.height(_level) - 1] ); @@ -206,7 +223,7 @@ contract Util { _left, _right, ONE_STATE, - generateProof(_player, ArbitrationConstants.height(0)) + generateDivergenceProof(_player, ArbitrationConstants.height(0)) ); } @@ -236,7 +253,7 @@ contract Util { _left, _right, ONE_STATE, - generateProof(_player, ArbitrationConstants.height(0)) + generateDivergenceProof(_player, ArbitrationConstants.height(0)) ); } diff --git a/prt/tests/compute-rs/src/main.rs b/prt/tests/compute-rs/src/main.rs index 1a10def2..8412e7ae 100644 --- a/prt/tests/compute-rs/src/main.rs +++ b/prt/tests/compute-rs/src/main.rs @@ -23,10 +23,10 @@ async fn main() -> Result<()> { let config = ComputeConfig::parse(); let blockchain_config = config.blockchain_config; - let sender = EthArenaSender::new(&blockchain_config)?; + let mut player = Player::new( - Vec::new(), + None, Vec::new(), &blockchain_config, config.machine_path, diff --git a/prt/tests/compute/prt_compute.lua b/prt/tests/compute/prt_compute.lua index 41a9e369..6b21c1c5 100755 --- a/prt/tests/compute/prt_compute.lua +++ b/prt/tests/compute/prt_compute.lua @@ -49,7 +49,7 @@ local function setup_players(use_lua_node, extra_data, root_constants, root_tour print("Calculating root commitment...") local snapshot_dir = string.format("/compute_data/%s", root_tournament) local builder = CommitmentBuilder:new(machine_path, snapshot_dir) - local root_commitment = builder:build(0, 0, root_constants.log2_step, root_constants.height) + local root_commitment = builder:build(0, 0, root_constants.log2_step, root_constants.height, nil) if use_lua_node then -- use Lua node to defend @@ -59,7 +59,6 @@ local function setup_players(use_lua_node, extra_data, root_constants, root_tour extra_data) else -- use Rust node to defend - print("Setting up Rust honest player") local rust_hero_runner = require "runners.rust_hero_runner" player_coroutines[player_index] = rust_hero_runner.create_react_once_runner(player_index, machine_path) diff --git a/prt/tests/compute/runners/helpers/fake_commitment.lua b/prt/tests/compute/runners/helpers/fake_commitment.lua index 24b583a0..49eaeaf4 100644 --- a/prt/tests/compute/runners/helpers/fake_commitment.lua +++ b/prt/tests/compute/runners/helpers/fake_commitment.lua @@ -65,7 +65,7 @@ local function rebuild_nested_trees(leafs) end local function build_commitment(cached_commitments, machine_path, snapshot_dir, base_cycle, level, log2_stride, - log2_stride_count) + log2_stride_count, inputs) -- the honest commitment builder should be operated in an isolated env -- to avoid side effects to the strategy behavior @@ -80,7 +80,7 @@ local function build_commitment(cached_commitments, machine_path, snapshot_dir, local CommitmentBuilder = scoped_require "computation.commitment" local builder = CommitmentBuilder:new(machine_path, snapshot_dir) - local commitment = builder:build(base_cycle, level, log2_stride, log2_stride_count) + local commitment = builder:build(base_cycle, level, log2_stride, log2_stride_count, inputs) coroutine.yield(commitment) end) @@ -124,7 +124,7 @@ function FakeCommitmentBuilder:new(machine_path, root_commitment, snapshot_dir) return c end -function FakeCommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count) +function FakeCommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride_count, inputs) -- function caller should set `self.fake_index` properly before calling this function -- the fake commitments are not guaranteed to be unique if there are not many leafs (short computation) -- `self.fake_index` is reset and the end of a successful call to ensure the next caller must set it again. @@ -141,7 +141,8 @@ function FakeCommitmentBuilder:build(base_cycle, level, log2_stride, log2_stride local commitment = build_commitment(self.commitments, self.machine_path, self.snapshot_dir, base_cycle, level, log2_stride, - log2_stride_count) + log2_stride_count, + inputs) local fake_commitment = build_fake_commitment(commitment, self.fake_index, log2_stride) self.fake_commitments[level][base_cycle][self.fake_index] = fake_commitment diff --git a/prt/tests/compute/runners/hero_runner.lua b/prt/tests/compute/runners/hero_runner.lua index ae38eeb7..9b50c476 100755 --- a/prt/tests/compute/runners/hero_runner.lua +++ b/prt/tests/compute/runners/hero_runner.lua @@ -18,7 +18,7 @@ local function hero_runner(player_id, machine_path, root_commitment, root_tourna local snapshot_dir = string.format("/compute_data/%s", root_tournament) local strategy = HonestStrategy:new( CommitmentBuilder:new(machine_path, snapshot_dir, root_commitment), - {}, + nil, machine_path, Sender:new(blockchain_consts.pks[player_id], player_id, blockchain_consts.endpoint) ) diff --git a/prt/tests/compute/runners/sybil_runner.lua b/prt/tests/compute/runners/sybil_runner.lua index 69fd4680..e16fea9f 100755 --- a/prt/tests/compute/runners/sybil_runner.lua +++ b/prt/tests/compute/runners/sybil_runner.lua @@ -39,7 +39,7 @@ local function sybil_runner(player_id, machine_path, root_commitment, root_tourn local snapshot_dir = string.format("/compute_data/%s", root_tournament) local strategy = HonestStrategy:new( FakeCommitmentBuilder:new(machine_path, root_commitment, snapshot_dir), - {}, + nil, machine_path, Sender:new(blockchain_consts.pks[player_id], player_id, blockchain_consts.endpoint) )