Skip to content

Commit

Permalink
fix: prove req pattern, fix async stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
nhtyy committed Dec 12, 2024
1 parent dd4b7ae commit b1bfa9d
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 124 deletions.
1 change: 1 addition & 0 deletions crates/sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ network-v2 = [
"dep:backoff",
]
cuda = ["sp1-cuda"]
blocking = []

[build-dependencies]
vergen = { version = "8", default-features = false, features = [
Expand Down
6 changes: 3 additions & 3 deletions crates/sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ProverClient {
}
}

pub async fn setup(&self, elf: Arc<[u8]>) -> (Arc<SP1ProvingKey>) {
pub async fn setup(&self, elf: Arc<[u8]>) -> Arc<SP1ProvingKey> {
self.inner.setup(elf).await
}

Expand Down Expand Up @@ -97,12 +97,12 @@ impl<T: BuildableProver> ProverClientBuilder<T> {
}

impl ProverClientBuilder<NetworkProverBuilder> {
pub fn with_rpc_url(mut self, url: String) -> Self {
pub fn rpc_url(mut self, url: String) -> Self {
self.inner_builder = self.inner_builder.rpc_url(url);
self
}

pub fn with_private_key(mut self, key: String) -> Self {
pub fn private_key(mut self, key: String) -> Self {
self.inner_builder = self.inner_builder.private_key(key);
self
}
Expand Down
127 changes: 82 additions & 45 deletions crates/sdk/src/local/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl LocalProver {
fn sp1_prover(&self) -> &SP1Prover {
&self.prover
}

pub fn prove(&self, pk: Arc<SP1ProvingKey>, stdin: SP1Stdin) -> LocalProofRequest {
LocalProofRequest::new(self, &pk, stdin)
}
}

pub struct LocalProverBuilder {}
Expand All @@ -61,7 +65,7 @@ impl LocalProverBuilder {

pub struct LocalProofRequest<'a> {
pub prover: &'a LocalProver,
pub pk: Arc<SP1ProvingKey>,
pub pk: &'a Arc<SP1ProvingKey>,
pub stdin: SP1Stdin,
pub mode: Mode,
pub timeout: u64,
Expand All @@ -70,7 +74,7 @@ pub struct LocalProofRequest<'a> {
}

impl<'a> LocalProofRequest<'a> {
pub fn new(prover: &'a LocalProver, pk: Arc<SP1ProvingKey>, stdin: SP1Stdin) -> Self {
pub fn new(prover: &'a LocalProver, pk: &'a Arc<SP1ProvingKey>, stdin: SP1Stdin) -> Self {
Self {
prover,
pk,
Expand Down Expand Up @@ -117,9 +121,24 @@ impl<'a> LocalProofRequest<'a> {
self
}

pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
Self::run_inner(
&self.prover.prover,
&**self.pk,
self.stdin,
self.mode,
self.timeout,
self.version,
self.sp1_prover_opts,
)
}
}

impl LocalProver {
fn run_inner(
prover: Arc<SP1Prover<DefaultProverComponents>>,
pk: Arc<SP1ProvingKey>,
prover: &SP1Prover<DefaultProverComponents>,
pk: &SP1ProvingKey,
stdin: SP1Stdin,
mode: Mode,
timeout: u64,
Expand Down Expand Up @@ -202,19 +221,6 @@ impl<'a> LocalProofRequest<'a> {

unreachable!()
}

pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
Self::run_inner(
Arc::clone(&self.prover.prover),
self.pk,
self.stdin,
self.mode,
self.timeout,
self.version,
self.sp1_prover_opts,
)
}
}

#[async_trait]
Expand All @@ -227,12 +233,14 @@ impl Prover for LocalProver {
})
.await
.unwrap();

result
}

#[cfg(feature = "blocking")]
fn setup_sync(&self, elf: &[u8]) -> Arc<SP1ProvingKey> {
let (pk, _vk) = self.prover.setup(elf);

Arc::new(pk)
}

Expand Down Expand Up @@ -262,23 +270,25 @@ impl Prover for LocalProver {
stdin: SP1Stdin,
opts: ProofOpts,
) -> Result<SP1ProofWithPublicValues> {
let prover = Arc::clone(&self.prover);
let mut req = self.prove(pk, stdin);

task::spawn_blocking(move || {
let context = SP1Context::default();
if let Some(mode) = opts.mode {
req.mode = mode;
}

LocalProofRequest::run_inner(
prover,
pk,
stdin,
opts.mode,
opts.timeout,
SP1_CIRCUIT_VERSION.to_string(),
SP1ProverOpts::default(),
)
})
.await
.unwrap()
if let Some(timeout) = opts.timeout {
req.timeout = timeout;
}

if let Some(version) = opts.version {
req.version = version;
}

if let Some(sp1_prover_opts) = opts.sp1_prover_opts {
req.sp1_prover_opts = sp1_prover_opts;
}

req.await
}

#[cfg(feature = "blocking")]
Expand All @@ -288,18 +298,25 @@ impl Prover for LocalProver {
stdin: SP1Stdin,
opts: ProofOpts,
) -> Result<SP1ProofWithPublicValues> {
let context = SP1Context::default();
let mut req = self.prove(pk, stdin);

LocalProofRequest::run_inner(
Arc::clone(&self.prover),
pk,
stdin,
opts.mode,
opts.timeout,
SP1_CIRCUIT_VERSION.to_string(),
SP1ProverOpts::default(),
context,
)
if let Some(mode) = opts.mode {
req.mode = mode;
}

if let Some(timeout) = opts.timeout {
req.timeout = timeout;
}

if let Some(version) = opts.version {
req.version = version;
}

if let Some(sp1_prover_opts) = opts.sp1_prover_opts {
req.sp1_prover_opts = sp1_prover_opts;
}

req.run()
}

async fn verify(
Expand All @@ -308,6 +325,7 @@ impl Prover for LocalProver {
vk: Arc<SP1VerifyingKey>,
) -> Result<(), SP1VerificationError> {
let prover = Arc::clone(&self.prover);

task::spawn_blocking(move || verify::verify(&prover, SP1_CIRCUIT_VERSION, &proof, &vk))
.await
.unwrap()
Expand All @@ -331,9 +349,28 @@ impl Default for LocalProver {

impl<'a> IntoFuture for LocalProofRequest<'a> {
type Output = Result<SP1ProofWithPublicValues>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { self.run() })
let LocalProofRequest { prover, pk, stdin, mode, timeout, version, sp1_prover_opts } = self;

let pk = Arc::clone(pk);
let prover = prover.prover.clone();

Box::pin(async move {
task::spawn_blocking(move || {
LocalProofRequest::run_inner(
&prover,
&**pk,
stdin,
mode,
timeout,
version,
sp1_prover_opts,
)
})
.await
.expect("To be able to join prove handle")
})
}
}
95 changes: 21 additions & 74 deletions crates/sdk/src/network-v2/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,63 +62,6 @@ impl NetworkProver {
}
}

/// Sets the proof mode to core.
pub fn core(mut self) -> Self {
self.network_client.mode = Mode::Core;
self
}

/// Sets the proof mode to compressed.
pub fn compressed(mut self) -> Self {
self.network_client.mode = Mode::Compressed;
self
}

/// Sets the proof mode to plonk.
pub fn plonk(mut self) -> Self {
self.network_client.mode = Mode::Plonk;
self
}

/// Sets the proof mode to groth16.
pub fn groth16(mut self) -> Self {
self.network_client.mode = Mode::Groth16;
self
}

/// Sets the RPC URL for the prover network.
///
/// This configures the endpoint that will be used for all network operations.
/// If not set, the default RPC URL will be used.
pub fn timeout(mut self, timeout: u64) -> Self {
self.network_client.timeout = Some(timeout);
self
}

/// Sets the cycle limit for the prover network.
///
/// See `get_cycle_limit` for more details the final cycle limit is determined.
pub fn cycle_limit(mut self, limit: u64) -> Self {
self.network_client.cycle_limit = Some(limit);
self
}

/// Skips simulation when determining the cycle limit.
///
/// See `get_cycle_limit` for more details the final cycle limit is determined.
pub fn skip_simulation(mut self, skip: bool) -> Self {
self.network_client.skip_simulation = skip;
self
}

/// Sets the fulfillment strategy for the prover network.
///
/// See `request_proof` for more details the final cycle limit is determined.
pub fn strategy(mut self, strategy: FulfillmentStrategy) -> Self {
self.network_client.strategy = Some(strategy);
self
}

/// Get the cycle limit to used for a proof request.
///
/// The cycle limit is determined according to the following priority:
Expand Down Expand Up @@ -152,17 +95,13 @@ impl NetworkProver {
}

/// Registers a program if it is not already registered.
pub async fn register_program(
&self,
vk: &SP1VerifyingKey,
elf: &[u8],
) -> Result<VerifyingKeyHash> {
async fn register_program(&self, vk: &SP1VerifyingKey, elf: &[u8]) -> Result<VerifyingKeyHash> {
self.network_client.register_program(vk, elf).await
}

/// Requests a proof from the prover network, returning the request ID.
#[allow(clippy::too_many_arguments)]
pub async fn request_proof(
async fn request_proof(
&self,
vk_hash: &VerifyingKeyHash,
stdin: &SP1Stdin,
Expand Down Expand Up @@ -197,7 +136,7 @@ impl NetworkProver {
///
/// The proof request must have already been submitted. This function will return a
/// `RequestTimedOut` error if the request does not received a response within the timeout.
pub async fn wait_proof<P: DeserializeOwned>(
async fn wait_proof<P: DeserializeOwned>(
&self,
request_id: &RequestId,
timeout_secs: u64,
Expand Down Expand Up @@ -254,14 +193,6 @@ impl NetworkProver {
}
}

pub fn prove_with_options<'a>(
&'a self,
pk: Arc<SP1ProvingKey>,
stdin: SP1Stdin,
) -> NetworkProofRequest<'a> {
NetworkProofRequest::new(self, pk, stdin)
}

/// Creates a new network prover builder. See [`NetworkProverBuilder`] for more details.
pub fn builder() -> NetworkProverBuilder {
NetworkProverBuilder::new()
Expand Down Expand Up @@ -327,8 +258,23 @@ impl<'a> NetworkProofRequest<'a> {
}
}

pub fn with_mode(mut self, mode: Mode) -> Self {
self.mode = mode.into();
pub fn groth16(mut self) -> Self {
self.mode = ProofMode::Groth16;
self
}

pub fn plonk(mut self) -> Self {
self.mode = ProofMode::Plonk;
self
}

pub fn core(mut self) -> Self {
self.mode = ProofMode::Core;
self
}

pub fn compressed(mut self) -> Self {
self.mode = ProofMode::Compressed;
self
}

Expand Down Expand Up @@ -445,6 +391,7 @@ impl Prover for NetworkProver {
.with_mode(opts.mode)
.with_timeout(opts.timeout)
.with_cycle_limit(opts.cycle_limit);

request.run_inner().await
}

Expand Down
2 changes: 1 addition & 1 deletion crates/sdk/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ impl<'a> IntoFuture for DynProofRequest<'a> {
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

fn into_future(self) -> Self::IntoFuture {
Box::pin(self.prover.prove_with_options(self.pk, self.stdin, self.opts))
self.prover.prove_with_options(self.pk, self.stdin, self.opts)
}
}
2 changes: 1 addition & 1 deletion crates/sdk/src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use strum_macros::EnumString;
use thiserror::Error;

use crate::install::try_install_circuit_artifacts;
use crate::opts::ProofOpts;
use crate::local::SP1VerificationError;
use crate::opts::ProofOpts;
use crate::{proof::SP1Proof, proof::SP1ProofKind, proof::SP1ProofWithPublicValues};

/// Verify that an SP1 proof is valid given its vkey and metadata.
Expand Down

0 comments on commit b1bfa9d

Please sign in to comment.