From e2decbfc92dd047c15a9f875b3d7bf7dee1d56fd Mon Sep 17 00:00:00 2001 From: Chris T Date: Mon, 13 May 2024 18:02:38 -0700 Subject: [PATCH] feat: reduce network prover (#687) --- core/src/runtime/mod.rs | 17 ++- core/src/runtime/syscall.rs | 62 +++++----- prover/src/lib.rs | 140 ++++++++++++++-------- prover/src/types.rs | 6 +- prover/src/utils.rs | 15 +++ recursion/program/src/machine/compress.rs | 3 +- sdk/src/lib.rs | 4 +- sdk/src/provers/network.rs | 16 +-- 8 files changed, 158 insertions(+), 105 deletions(-) diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index 282ce03656..09b21a88aa 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -25,7 +25,6 @@ use std::collections::HashMap; use std::fs::File; use std::io::BufWriter; use std::io::Write; -use std::rc::Rc; use std::sync::Arc; use thiserror::Error; @@ -76,7 +75,7 @@ pub struct Runtime { pub(crate) unconstrained_state: ForkState, - pub syscall_map: HashMap>, + pub syscall_map: HashMap>, pub max_syscall_cycles: u32, @@ -132,7 +131,7 @@ impl Runtime { program, memory_accesses: MemoryAccessRecord::default(), shard_size: shard_size * 4, - shard_batch_size: env::shard_batch_size() as u32 * shard_size, + shard_batch_size: env::shard_batch_size() as u32, cycle_tracker: HashMap::new(), io_buf: HashMap::new(), trace_buf, @@ -937,6 +936,12 @@ impl Runtime { tracing::info!("starting execution"); } + pub fn run_untraced(&mut self) -> Result<(), ExecutionError> { + self.emit_events = false; + while !self.execute()? {} + Ok(()) + } + pub fn run(&mut self) -> Result<(), ExecutionError> { self.emit_events = true; while !self.execute()? {} @@ -965,10 +970,10 @@ impl Runtime { break; } - if env::shard_batch_size() > 0 && current_shard != self.state.current_shard { + if self.shard_batch_size > 0 && current_shard != self.state.current_shard { num_shards_executed += 1; current_shard = self.state.current_shard; - if num_shards_executed == env::shard_batch_size() { + if num_shards_executed == self.shard_batch_size { break; } } @@ -1040,7 +1045,7 @@ impl Runtime { } } - fn get_syscall(&mut self, code: SyscallCode) -> Option<&Rc> { + fn get_syscall(&mut self, code: SyscallCode) -> Option<&Arc> { self.syscall_map.get(&code) } } diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index 4e6304b429..934c94738f 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::rc::Rc; +use std::sync::Arc; use strum_macros::EnumIter; @@ -149,7 +149,7 @@ impl SyscallCode { } } -pub trait Syscall { +pub trait Syscall: Send + Sync { /// Execute the syscall and return the resulting value of register a0. `arg1` and `arg2` are the /// values in registers X10 and X11, respectively. While not a hard requirement, the convention /// is that the return value is only for system calls such as `HALT`. Most precompiles use `arg1` @@ -256,86 +256,86 @@ impl<'a> SyscallContext<'a> { } } -pub fn default_syscall_map() -> HashMap> { - let mut syscall_map = HashMap::>::default(); - syscall_map.insert(SyscallCode::HALT, Rc::new(SyscallHalt {})); - syscall_map.insert(SyscallCode::SHA_EXTEND, Rc::new(ShaExtendChip::new())); - syscall_map.insert(SyscallCode::SHA_COMPRESS, Rc::new(ShaCompressChip::new())); +pub fn default_syscall_map() -> HashMap> { + let mut syscall_map = HashMap::>::default(); + syscall_map.insert(SyscallCode::HALT, Arc::new(SyscallHalt {})); + syscall_map.insert(SyscallCode::SHA_EXTEND, Arc::new(ShaExtendChip::new())); + syscall_map.insert(SyscallCode::SHA_COMPRESS, Arc::new(ShaCompressChip::new())); syscall_map.insert( SyscallCode::ED_ADD, - Rc::new(EdAddAssignChip::::new()), + Arc::new(EdAddAssignChip::::new()), ); syscall_map.insert( SyscallCode::ED_DECOMPRESS, - Rc::new(EdDecompressChip::::new()), + Arc::new(EdDecompressChip::::new()), ); syscall_map.insert( SyscallCode::KECCAK_PERMUTE, - Rc::new(KeccakPermuteChip::new()), + Arc::new(KeccakPermuteChip::new()), ); syscall_map.insert( SyscallCode::SECP256K1_ADD, - Rc::new(WeierstrassAddAssignChip::::new()), + Arc::new(WeierstrassAddAssignChip::::new()), ); syscall_map.insert( SyscallCode::SECP256K1_DOUBLE, - Rc::new(WeierstrassDoubleAssignChip::::new()), + Arc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert(SyscallCode::SHA_COMPRESS, Rc::new(ShaCompressChip::new())); + syscall_map.insert(SyscallCode::SHA_COMPRESS, Arc::new(ShaCompressChip::new())); syscall_map.insert( SyscallCode::SECP256K1_DECOMPRESS, - Rc::new(WeierstrassDecompressChip::::new()), + Arc::new(WeierstrassDecompressChip::::new()), ); syscall_map.insert( SyscallCode::BN254_ADD, - Rc::new(WeierstrassAddAssignChip::::new()), + Arc::new(WeierstrassAddAssignChip::::new()), ); syscall_map.insert( SyscallCode::BN254_DOUBLE, - Rc::new(WeierstrassDoubleAssignChip::::new()), + Arc::new(WeierstrassDoubleAssignChip::::new()), ); syscall_map.insert( SyscallCode::BLAKE3_COMPRESS_INNER, - Rc::new(Blake3CompressInnerChip::new()), + Arc::new(Blake3CompressInnerChip::new()), ); syscall_map.insert( SyscallCode::BLS12381_ADD, - Rc::new(WeierstrassAddAssignChip::::new()), + Arc::new(WeierstrassAddAssignChip::::new()), ); syscall_map.insert( SyscallCode::BLS12381_DOUBLE, - Rc::new(WeierstrassDoubleAssignChip::::new()), + Arc::new(WeierstrassDoubleAssignChip::::new()), ); syscall_map.insert( SyscallCode::BLAKE3_COMPRESS_INNER, - Rc::new(Blake3CompressInnerChip::new()), + Arc::new(Blake3CompressInnerChip::new()), ); - syscall_map.insert(SyscallCode::UINT256_MUL, Rc::new(Uint256MulChip::new())); + syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); syscall_map.insert( SyscallCode::ENTER_UNCONSTRAINED, - Rc::new(SyscallEnterUnconstrained::new()), + Arc::new(SyscallEnterUnconstrained::new()), ); syscall_map.insert( SyscallCode::EXIT_UNCONSTRAINED, - Rc::new(SyscallExitUnconstrained::new()), + Arc::new(SyscallExitUnconstrained::new()), ); - syscall_map.insert(SyscallCode::WRITE, Rc::new(SyscallWrite::new())); - syscall_map.insert(SyscallCode::COMMIT, Rc::new(SyscallCommit::new())); + syscall_map.insert(SyscallCode::WRITE, Arc::new(SyscallWrite::new())); + syscall_map.insert(SyscallCode::COMMIT, Arc::new(SyscallCommit::new())); syscall_map.insert( SyscallCode::COMMIT_DEFERRED_PROOFS, - Rc::new(SyscallCommitDeferred::new()), + Arc::new(SyscallCommitDeferred::new()), ); syscall_map.insert( SyscallCode::VERIFY_SP1_PROOF, - Rc::new(SyscallVerifySP1Proof::new()), + Arc::new(SyscallVerifySP1Proof::new()), ); - syscall_map.insert(SyscallCode::HINT_LEN, Rc::new(SyscallHintLen::new())); - syscall_map.insert(SyscallCode::HINT_READ, Rc::new(SyscallHintRead::new())); + syscall_map.insert(SyscallCode::HINT_LEN, Arc::new(SyscallHintLen::new())); + syscall_map.insert(SyscallCode::HINT_READ, Arc::new(SyscallHintRead::new())); syscall_map.insert( SyscallCode::BLS12381_DECOMPRESS, - Rc::new(WeierstrassDecompressChip::::new()), + Arc::new(WeierstrassDecompressChip::::new()), ); - syscall_map.insert(SyscallCode::UINT256_MUL, Rc::new(Uint256MulChip::new())); + syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); syscall_map } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 415c5e4d04..dfbed9acbc 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -31,13 +31,15 @@ use rayon::prelude::*; use sp1_core::air::PublicValues; pub use sp1_core::io::{SP1PublicValues, SP1Stdin}; use sp1_core::runtime::{ExecutionError, Runtime}; -use sp1_core::stark::{Challenge, MachineVerificationError, StarkProvingKey}; +use sp1_core::stark::{Challenge, StarkProvingKey}; +use sp1_core::stark::{Challenger, MachineVerificationError}; +use sp1_core::utils::DIGEST_SIZE; use sp1_core::{ runtime::Program, stark::{ LocalProver, RiscvAir, ShardProof, StarkGenericConfig, StarkMachine, StarkVerifyingKey, Val, }, - utils::{BabyBearPoseidon2, SP1CoreProverError, DIGEST_SIZE}, + utils::{BabyBearPoseidon2, SP1CoreProverError}, }; use sp1_primitives::hash_deferred_proof; use sp1_recursion_circuit::witness::Witnessable; @@ -53,10 +55,12 @@ use sp1_recursion_gnark_ffi::plonk_bn254::PlonkBn254Prover; pub use sp1_recursion_gnark_ffi::Groth16Proof; use sp1_recursion_gnark_ffi::Groth16Prover; use sp1_recursion_program::hints::Hintable; +pub use sp1_recursion_program::machine::ReduceProgramType; use sp1_recursion_program::machine::{ - ReduceProgramType, SP1CompressVerifier, SP1DeferredMemoryLayout, SP1DeferredVerifier, - SP1RecursionMemoryLayout, SP1RecursiveVerifier, SP1ReduceMemoryLayout, SP1RootMemoryLayout, - SP1RootVerifier, + SP1CompressVerifier, SP1DeferredVerifier, SP1RecursiveVerifier, SP1RootVerifier, +}; +pub use sp1_recursion_program::machine::{ + SP1DeferredMemoryLayout, SP1RecursionMemoryLayout, SP1ReduceMemoryLayout, SP1RootMemoryLayout, }; use tracing::instrument; pub use types::*; @@ -220,7 +224,7 @@ impl SP1Prover { for (proof, vkey) in stdin.proofs.iter() { runtime.write_proof(proof.clone(), vkey.clone()); } - runtime.run()?; + runtime.run_untraced()?; Ok(SP1PublicValues::from(&runtime.state.public_values_stream)) } @@ -243,31 +247,19 @@ impl SP1Prover { }) } - /// Reduce shards proofs to a single shard proof using the recursion prover. - #[instrument(name = "compress", level = "info", skip_all)] - pub fn compress( - &self, - vk: &SP1VerifyingKey, - proof: SP1CoreProof, - deferred_proofs: Vec>, - ) -> Result, SP1RecursionProverError> { - let batch_size = 2; - let shard_proofs = &proof.proof.0; - - // Setup the reconstruct commitments flags to false and save its state. - let rc = env::var(RECONSTRUCT_COMMITMENTS_ENV_VAR).unwrap_or_default(); - env::set_var(RECONSTRUCT_COMMITMENTS_ENV_VAR, "false"); - - // Get the leaf challenger. - let mut leaf_challenger = self.core_machine.config().challenger(); - vk.vk.observe_into(&mut leaf_challenger); - shard_proofs.iter().for_each(|proof| { - leaf_challenger.observe(proof.commitment.main_commit); - leaf_challenger.observe_slice(&proof.public_values[0..self.core_machine.num_pv_elts()]); - }); - - // Make sure leaf challenger is not mutable anymore. - let leaf_challenger = leaf_challenger; + /// Generate the inputs for the first layer of recursive proofs. + #[allow(clippy::type_complexity)] + pub fn get_first_layer_inputs<'a>( + &'a self, + vk: &'a SP1VerifyingKey, + leaf_challenger: &'a Challenger, + shard_proofs: &[ShardProof], + deferred_proofs: &[ShardProof], + batch_size: usize, + ) -> ( + Vec>>, + Vec>>, + ) { let mut core_inputs = Vec::new(); let mut reconstruct_challenger = self.core_machine.config().challenger(); vk.vk.observe_into(&mut reconstruct_challenger); @@ -281,7 +273,7 @@ impl SP1Prover { vk: &vk.vk, machine: &self.core_machine, shard_proofs: proofs, - leaf_challenger: &leaf_challenger, + leaf_challenger, initial_reconstruct_challenger: reconstruct_challenger.clone(), is_complete, }); @@ -337,11 +329,46 @@ impl SP1Prover { deferred_digest = Self::hash_deferred_proofs(deferred_digest, batch); } + (core_inputs, deferred_inputs) + } + + /// Reduce shards proofs to a single shard proof using the recursion prover. + #[instrument(name = "compress", level = "info", skip_all)] + pub fn compress( + &self, + vk: &SP1VerifyingKey, + proof: SP1CoreProof, + deferred_proofs: Vec>, + ) -> Result, SP1RecursionProverError> { + // Set the batch size for the reduction tree. + let batch_size = 2; + + let shard_proofs = &proof.proof.0; + // Get the leaf challenger. + let mut leaf_challenger = self.core_machine.config().challenger(); + vk.vk.observe_into(&mut leaf_challenger); + shard_proofs.iter().for_each(|proof| { + leaf_challenger.observe(proof.commitment.main_commit); + leaf_challenger.observe_slice(&proof.public_values[0..self.core_machine.num_pv_elts()]); + }); + + // Setup the reconstruct commitments flags to false and save its state. + let rc = env::var(RECONSTRUCT_COMMITMENTS_ENV_VAR).unwrap_or_default(); + env::set_var(RECONSTRUCT_COMMITMENTS_ENV_VAR, "false"); + // Run the recursion and reduce programs. // Run the recursion programs. let mut records = Vec::new(); + let (core_inputs, deferred_inputs) = self.get_first_layer_inputs( + vk, + &leaf_challenger, + shard_proofs, + &deferred_proofs, + batch_size, + ); + for input in core_inputs { let mut runtime = RecursionRuntime::, Challenge, _>::new( &self.recursion_program, @@ -423,27 +450,12 @@ impl SP1Prover { is_complete, }; - let mut runtime = RecursionRuntime::, Challenge, _>::new( + let proof = self.compress_machine_proof( + input, &self.compress_program, - self.compress_machine.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - witness_stream.extend(input.write()); - - runtime.witness_stream = witness_stream.into(); - runtime.run(); - runtime.print_stats(); - - let mut recursive_challenger = self.compress_machine.config().challenger(); - let mut proof = self.compress_machine.prove::>( &self.compress_pk, - runtime.record, - &mut recursive_challenger, ); - - debug_assert_eq!(proof.shard_proofs.len(), 1); - (proof.shard_proofs.pop().unwrap(), ReduceProgramType::Reduce) + (proof, ReduceProgramType::Reduce) }) .collect(); @@ -462,6 +474,32 @@ impl SP1Prover { }) } + pub fn compress_machine_proof( + &self, + input: impl Hintable, + program: &RecursionProgram, + pk: &StarkProvingKey, + ) -> ShardProof { + let mut runtime = RecursionRuntime::, Challenge, _>::new( + program, + self.compress_machine.config().perm.clone(), + ); + + let mut witness_stream = Vec::new(); + witness_stream.extend(input.write()); + + runtime.witness_stream = witness_stream.into(); + runtime.run(); + runtime.print_stats(); + + let mut recursive_challenger = self.compress_machine.config().challenger(); + self.compress_machine + .prove::>(pk, runtime.record, &mut recursive_challenger) + .shard_proofs + .pop() + .unwrap() + } + /// Wrap a reduce proof into a STARK proven over a SNARK-friendly field. #[instrument(name = "shrink", level = "info", skip_all)] pub fn shrink( @@ -603,7 +641,7 @@ impl SP1Prover { } /// Accumulate deferred proofs into a single digest. - fn hash_deferred_proofs( + pub fn hash_deferred_proofs( prev_digest: [Val; DIGEST_SIZE], deferred_proofs: &[ShardProof], ) -> [Val; 8] { diff --git a/prover/src/types.rs b/prover/src/types.rs index 94ffcc3cee..e3aa597bd9 100644 --- a/prover/src/types.rs +++ b/prover/src/types.rs @@ -31,7 +31,7 @@ pub struct SP1ProvingKey { } /// The information necessary to verify a proof for a given RISC-V program. -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct SP1VerifyingKey { pub vk: StarkVerifyingKey, } @@ -160,7 +160,7 @@ pub struct SP1Groth16ProofData(pub Groth16Proof); pub struct SP1PlonkProofData(pub PlonkBn254Proof); /// An intermediate proof which proves the execution over a range of shards. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] #[serde(bound(serialize = "ShardProof: Serialize"))] #[serde(bound(deserialize = "ShardProof: Deserialize<'de>"))] pub struct SP1ReduceProof { @@ -190,7 +190,7 @@ impl SP1ReduceProof { } /// A proof that can be reduced along with other proofs into one proof. -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone)] pub enum SP1ReduceProofWrapper { Core(SP1ReduceProof), Recursive(SP1ReduceProof), diff --git a/prover/src/utils.rs b/prover/src/utils.rs index da32b12ab2..785b376079 100644 --- a/prover/src/utils.rs +++ b/prover/src/utils.rs @@ -3,6 +3,7 @@ use std::{ io::Read, }; +use futures::Future; use p3_baby_bear::BabyBear; use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; @@ -12,6 +13,7 @@ use sp1_core::{ io::SP1Stdin, runtime::{Program, Runtime}, }; +use tokio::{runtime, task::block_in_place}; use crate::SP1CoreProofData; @@ -84,3 +86,16 @@ pub fn words_to_bytes_be(words: &[u32; 8]) -> [u8; 32] { } bytes } + +/// Utility method for blocking on an async function. If we're already in a tokio runtime, we'll +/// block in place. Otherwise, we'll create a new runtime. +pub fn block_on(fut: impl Future) -> T { + // Handle case if we're already in an tokio runtime. + if let Ok(handle) = runtime::Handle::try_current() { + block_in_place(|| handle.block_on(fut)) + } else { + // Otherwise create a new runtime. + let rt = runtime::Runtime::new().expect("Failed to create a new runtime"); + rt.block_on(fut) + } +} diff --git a/recursion/program/src/machine/compress.rs b/recursion/program/src/machine/compress.rs index 441f648944..f260747d64 100644 --- a/recursion/program/src/machine/compress.rs +++ b/recursion/program/src/machine/compress.rs @@ -8,6 +8,7 @@ use p3_air::Air; use p3_baby_bear::BabyBear; use p3_commit::TwoAdicMultiplicativeCoset; use p3_field::{AbstractField, PrimeField32, TwoAdicField}; +use serde::{Deserialize, Serialize}; use sp1_core::air::MachineAir; use sp1_core::air::{Word, POSEIDON_NUM_WORDS, PV_DIGEST_NUM_WORDS}; use sp1_core::stark::StarkMachine; @@ -41,7 +42,7 @@ pub struct SP1CompressVerifier { } /// The different types of programs that can be verified by the `SP1ReduceVerifier`. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum ReduceProgramType { /// A batch of proofs that are all SP1 Core proofs. Core = 0, diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index 7a53636ec1..58ae33b989 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -26,8 +26,8 @@ pub use provers::{LocalProver, MockProver, NetworkProver, Prover}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sp1_core::stark::{MachineVerificationError, ShardProof}; pub use sp1_prover::{ - CoreSC, Groth16Proof, HashableKey, InnerSC, PlonkBn254Proof, SP1CoreProof, SP1Prover, - SP1ProvingKey, SP1PublicValues, SP1Stdin, SP1VerifyingKey, + CoreSC, Groth16Proof, HashableKey, InnerSC, OuterSC, PlonkBn254Proof, SP1Prover, SP1ProvingKey, + SP1PublicValues, SP1Stdin, SP1VerifyingKey, }; /// A client for interacting with SP1. diff --git a/sdk/src/provers/network.rs b/sdk/src/provers/network.rs index 17e9c5b301..a5b2e58e3c 100644 --- a/sdk/src/provers/network.rs +++ b/sdk/src/provers/network.rs @@ -11,6 +11,7 @@ use crate::{ }; use anyhow::{Context, Result}; use serde::de::DeserializeOwned; +use sp1_prover::utils::block_on; use sp1_prover::{SP1Prover, SP1Stdin}; use tokio::{runtime, time::sleep}; @@ -153,26 +154,19 @@ impl Prover for NetworkProver { } fn prove(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { self.prove_async(&pk.elf, stdin, ProofMode::Core).await }) + block_on(self.prove_async(&pk.elf, stdin, ProofMode::Core)) } fn prove_compressed(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { - self.prove_async(&pk.elf, stdin, ProofMode::Compressed) - .await - }) + block_on(self.prove_async(&pk.elf, stdin, ProofMode::Compressed)) } fn prove_groth16(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { self.prove_async(&pk.elf, stdin, ProofMode::Groth16).await }) + block_on(self.prove_async(&pk.elf, stdin, ProofMode::Groth16)) } fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - let rt = tokio::runtime::Runtime::new()?; - rt.block_on(async { self.prove_async(&pk.elf, stdin, ProofMode::Plonk).await }) + block_on(self.prove_async(&pk.elf, stdin, ProofMode::Plonk)) } }