Skip to content

Commit

Permalink
feat(prt-client): improve snapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenctw committed Dec 11, 2024
1 parent a426e52 commit 69f724c
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 59 deletions.
23 changes: 6 additions & 17 deletions prt/client-lua/computation/commitment.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ local consts = require "computation.constants"

local ulte = arithmetic.ulte

local save_snapshot = true
local handle_rollups = false


Expand Down Expand Up @@ -48,12 +47,7 @@ local function run_uarch_span(machine)
return builder:build(), machine_state
end

local function build_small_machine_commitment(base_cycle, log2_stride_count, machine, initial_state, snapshot_dir)
if save_snapshot then
-- taking snapshot for leafs to save time in next level
machine:take_snapshot(snapshot_dir, base_cycle, handle_rollups)
end

local function build_small_machine_commitment(log2_stride_count, machine, initial_state, snapshot_dir)
local builder = MerkleBuilder:new()
local instruction_count = arithmetic.max_uint(log2_stride_count - consts.log2_uarch_span)
local instruction = 0
Expand Down Expand Up @@ -81,13 +75,7 @@ local function build_small_machine_commitment(base_cycle, log2_stride_count, mac
return initial_state, builder:build(initial_state)
end

local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride_count, machine, initial_state,
snapshot_dir)
if save_snapshot then
-- taking snapshot for leafs to save time in next level
machine:take_snapshot(snapshot_dir, base_cycle, handle_rollups)
end

local function build_big_machine_commitment(base_cycle, log2_stride, log2_stride_count, machine, initial_state)
local builder = MerkleBuilder:new()
local instruction_count = arithmetic.max_uint(log2_stride_count)
local instruction = 0
Expand Down Expand Up @@ -125,11 +113,12 @@ local function build_commitment(base_cycle, log2_stride, log2_stride_count, mach
-- the base_cycle may be the cycle to receive input,
-- we need to take the initial state before feeding input to the machine
handle_rollups = true
initial_state = machine:run_with_inputs(base_cycle, inputs).root_hash
initial_state = machine:run_with_inputs(base_cycle, inputs, snapshot_dir).root_hash
else
-- treat it as compute
handle_rollups = false
initial_state = machine:run(base_cycle).root_hash
initial_state = machine:run(base_cycle).root_hash -- taking snapshot for leafs to save time in next level
machine:take_snapshot(snapshot_dir, base_cycle, handle_rollups)
end

if log2_stride >= consts.log2_uarch_span then
Expand All @@ -141,7 +130,7 @@ local function build_commitment(base_cycle, log2_stride, log2_stride_count, mach
snapshot_dir)
else
assert(log2_stride == 0)
return build_small_machine_commitment(base_cycle, log2_stride_count, machine, initial_state, snapshot_dir)
return build_small_machine_commitment(log2_stride_count, machine, initial_state, snapshot_dir)
end
end

Expand Down
10 changes: 7 additions & 3 deletions prt/client-lua/computation/machine.lua
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ end

function Machine:take_snapshot(snapshot_dir, cycle, handle_rollups)
local input_mask = arithmetic.max_uint(consts.log2_emulator_span)
if handle_rollups and cycle & input_mask == 0 and not self.yielded then
if handle_rollups and cycle & input_mask == 0 then
-- dont snapshot a machine state that's freshly fed with input without advance
return
assert(not self.yielded, "don't snapshot a machine state that's freshly fed with input without advance")
end

if helper.exists(snapshot_dir) then
Expand Down Expand Up @@ -178,7 +178,7 @@ function Machine:run_uarch(ucycle)
self.ucycle = ucycle
end

function Machine:run_with_inputs(cycle, inputs)
function Machine:run_with_inputs(cycle, inputs, snapshot_dir)
local input_mask = arithmetic.max_uint(consts.log2_emulator_span)
local current_input_index = self.cycle >> consts.log2_emulator_span

Expand All @@ -194,6 +194,9 @@ function Machine:run_with_inputs(cycle, inputs)

while next_input_cycle <= cycle do
machine_state_without_input = self:run(next_input_cycle)
if next_input_cycle == cycle then
self:take_snapshot(snapshot_dir, next_input_cycle, true)
end
local input = inputs[next_input_index + 1]
if input then
local h = assert(input:match("0x(%x+)"), input)
Expand All @@ -209,6 +212,7 @@ function Machine:run_with_inputs(cycle, inputs)

if cycle > self.cycle then
machine_state_without_input = self:run(cycle)
self:take_snapshot(snapshot_dir, next_input_cycle, true)
end

return machine_state_without_input
Expand Down
20 changes: 0 additions & 20 deletions prt/client-rs/src/machine/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ pub fn build_big_machine_commitment(
initial_state: Digest,
db: &ComputeStateAccess,
) -> Result<MachineCommitment> {
snapshot_base_cycle(machine, base_cycle, db)?;

let mut builder = MerkleBuilder::default();
let mut leafs = Vec::new();
let instruction_count = arithmetic::max_uint(log2_stride_count);
Expand Down Expand Up @@ -159,8 +157,6 @@ pub fn build_small_machine_commitment(
initial_state: Digest,
db: &ComputeStateAccess,
) -> Result<MachineCommitment> {
snapshot_base_cycle(machine, base_cycle, db)?;

let mut builder = MerkleBuilder::default();
let mut leafs = Vec::new();
let mut uarch_span_and_leafs = Vec::new();
Expand Down Expand Up @@ -208,22 +204,6 @@ pub fn build_small_machine_commitment(
})
}

fn snapshot_base_cycle(
machine: &mut MachineInstance,
base_cycle: u64,
db: &ComputeStateAccess,
) -> Result<()> {
let mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN);
if db.handle_rollups && base_cycle & mask == 0 && !machine.machine_state()?.yielded {
// don't snapshot a machine state that's freshly fed with input without advance
return Ok(());
}

let snapshot_path = db.work_path.join(format!("{}", base_cycle));
machine.snapshot(&snapshot_path)?;
Ok(())
}

fn run_uarch_span(
machine: &mut MachineInstance,
) -> Result<(Arc<MerkleTree>, MachineState, Vec<Leaf>)> {
Expand Down
8 changes: 4 additions & 4 deletions prt/client-rs/src/machine/commitment_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ impl CachingMachineCommitmentBuilder {
let initial_state = {
if db.handle_rollups {
// treat it as rollups
machine
.run_with_inputs(base_cycle, &db.inputs()?)?
.root_hash
machine.run_with_inputs(base_cycle, &db)?.root_hash
} else {
// treat it as compute
machine.run(base_cycle)?.root_hash
let root_hash = machine.run(base_cycle)?.root_hash;
machine.take_snapshot(base_cycle, &db);
root_hash
}
};
trace!("initial state for commitment: {}", initial_state);
Expand Down
43 changes: 30 additions & 13 deletions prt/client-rs/src/machine/instance.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::db::compute_state_access::Input;
use crate::db::compute_state_access::{ComputeStateAccess, Input};
use crate::machine::constants;
use cartesi_dave_arithmetic as arithmetic;
use cartesi_dave_merkle::Digest;
Expand Down Expand Up @@ -65,6 +65,22 @@ impl MachineInstance {
ucycle: 0,
})
}
pub fn take_snapshot(&mut self, base_cycle: u64, db: &ComputeStateAccess) -> Result<()> {
let mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN);
if db.handle_rollups && base_cycle & mask == 0 {
// don't snapshot a machine state that's freshly fed with input without advance
assert!(
self.machine_state()?.yielded,
"don't snapshot a machine state that's freshly fed with input without advance",
);
}

let snapshot_path = db.work_path.join(format!("{}", base_cycle));
if !snapshot_path.exists() {
self.machine.store(&snapshot_path)?;
}
Ok(())
}

// load inner machine with snapshot, update cycle, keep everything else the same
pub fn load_snapshot(&mut self, snapshot_path: &Path, snapshot_cycle: u64) -> Result<()> {
Expand All @@ -86,13 +102,6 @@ impl MachineInstance {
Ok(())
}

pub fn snapshot(&self, snapshot_path: &Path) -> Result<()> {
if !snapshot_path.exists() {
self.machine.store(snapshot_path)?;
}
Ok(())
}

pub fn root_hash(&self) -> Digest {
self.root_hash
}
Expand All @@ -101,8 +110,7 @@ impl MachineInstance {
&mut self,
cycle: u64,
ucycle: u64,
inputs: Vec<Vec<u8>>,
handle_rollups: bool,
db: &ComputeStateAccess,
) -> Result<MachineProof> {
let log_type = AccessLogType {
annotations: true,
Expand All @@ -111,18 +119,19 @@ impl MachineInstance {
};

let mut logs = Vec::new();
if handle_rollups {
if db.handle_rollups {
// treat it as rollups
// the cycle may be the cycle to receive input,
// we need to include the process of feeding input to the machine in the log
if cycle == 0 {
self.run(cycle)?;
} else {
self.run_with_inputs(cycle - 1, &inputs)?;
self.run_with_inputs(cycle - 1, db)?;
self.run(cycle)?;
}

let mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN);
let inputs = &db.inputs()?;
let input = inputs.get((cycle >> constants::LOG2_EMULATOR_SPAN) as usize);
if cycle & mask == 0 && input.is_some() {
// need to process input
Expand Down Expand Up @@ -225,13 +234,14 @@ impl MachineInstance {
// One exception is that if `cycle` is supposed to receive an input, in this case
// the machine state would be `without` input included in the machine,
// this is useful when we need the initial state to compute the commitments
pub fn run_with_inputs(&mut self, cycle: u64, inputs: &Vec<Vec<u8>>) -> Result<MachineState> {
pub fn run_with_inputs(&mut self, cycle: u64, db: &ComputeStateAccess) -> Result<MachineState> {
trace!(
"run_with_inputs self cycle: {}, target cycle: {}",
self.cycle,
cycle
);

let inputs = &db.inputs()?;
let mut machine_state_without_input = self.machine_state()?;
let input_mask = arithmetic::max_uint(constants::LOG2_EMULATOR_SPAN);
let current_input_index = self.cycle >> constants::LOG2_EMULATOR_SPAN;
Expand All @@ -249,15 +259,21 @@ impl MachineInstance {
trace!("next input index: {}", next_input_index);
trace!("run to next input cycle: {}", next_input_cycle);
machine_state_without_input = self.run(next_input_cycle)?;
if next_input_cycle == cycle {
self.take_snapshot(next_input_cycle, &db)?;
}

let input = inputs.get(next_input_index as usize);
if let Some(data) = input {
trace!(
"before input, machine state: {}",
self.machine_state()?.root_hash
);
trace!("input: 0x{}", data.encode_hex());

self.machine
.send_cmio_response(htif::fromhost::ADVANCE_STATE, data)?;

trace!(
"after input, machine state: {}",
self.machine_state()?.root_hash
Expand All @@ -269,6 +285,7 @@ impl MachineInstance {
}
if cycle > self.cycle {
machine_state_without_input = self.run(cycle)?;
self.take_snapshot(next_input_cycle, &db)?;
}
Ok(machine_state_without_input)
}
Expand Down
3 changes: 1 addition & 2 deletions prt/client-rs/src/strategy/player.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,7 @@ impl Player {
if let Some(snapshot) = self.db.closest_snapshot(cycle)? {
machine.load_snapshot(&snapshot.1, snapshot.0)?;
};
let inputs = self.db.inputs()?;
machine.get_logs(cycle, ucycle, inputs, self.db.handle_rollups)?
machine.get_logs(cycle, ucycle, &self.db)?
};

info!(
Expand Down

0 comments on commit 69f724c

Please sign in to comment.