Skip to content

Commit

Permalink
Fix: replace ProverInput generic for an enum (#57)
Browse files Browse the repository at this point in the history
* replaced ProverInput generic for an enum

* fix serialization

* cargo-fmt
  • Loading branch information
piniom authored Sep 12, 2024
1 parent 20f712b commit 63509b4
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 50 deletions.
5 changes: 1 addition & 4 deletions bin/cairo-prove/src/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use crate::errors::ProveErrors;
use crate::validate_input;
use crate::Args;
use crate::CairoVersion;
use common::{
cairo0_prover_input::{Cairo0CompiledProgram, Cairo0ProverInput},
cairo_prover_input::{CairoCompiledProgram, CairoProverInput},
};
use common::prover_input::*;
use prover_sdk::sdk::ProverSDK;
use serde_json::Value;

Expand Down
6 changes: 1 addition & 5 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
pub mod cairo0_prover_input;
pub mod cairo_prover_input;
pub mod models;
pub mod prover_input;
pub mod requests;
pub trait ProverInput {
fn serialize(self) -> serde_json::Value;
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
use serde::{Deserialize, Serialize};
use starknet_types_core::felt::Felt;

use crate::ProverInput;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CairoProverInput {
pub program: CairoCompiledProgram,
pub program_input: Vec<Felt>,
pub layout: String,
}
impl ProverInput for CairoProverInput {
fn serialize(self) -> serde_json::Value {
serde_json::to_value(self).unwrap()
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CairoCompiledProgram {
//pub version: u64,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
use serde::{Deserialize, Serialize};

use crate::ProverInput;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Cairo0ProverInput {
pub program: Cairo0CompiledProgram,
pub program_input: serde_json::Value,
pub layout: String,
}

impl ProverInput for Cairo0ProverInput {
fn serialize(self) -> serde_json::Value {
serde_json::to_value(self).unwrap()
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Cairo0CompiledProgram {
pub attributes: Vec<String>,
Expand Down
20 changes: 20 additions & 0 deletions common/src/prover_input/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
mod cairo;
mod cairo0;

pub use cairo::{CairoCompiledProgram, CairoProverInput};
pub use cairo0::{Cairo0CompiledProgram, Cairo0ProverInput};

#[derive(Debug)]
pub enum ProverInput {
Cairo0(Cairo0ProverInput),
Cairo(CairoProverInput),
}

impl ProverInput {
pub fn to_json_value(&self) -> serde_json::Value {
match self {
ProverInput::Cairo0(input) => serde_json::to_value(input).unwrap(),
ProverInput::Cairo(input) => serde_json::to_value(input).unwrap(),
}
}
}
3 changes: 1 addition & 2 deletions prover-sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ pub mod errors;
pub mod sdk;
pub mod sdk_builder;

pub use common::cairo0_prover_input::Cairo0ProverInput;
pub use common::cairo_prover_input::CairoProverInput;
pub use common::prover_input::ProverInput;
28 changes: 12 additions & 16 deletions prover-sdk/src/sdk.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{access_key::ProverAccessKey, errors::SdkErrors, sdk_builder::ProverSDKBuilder};
use common::{requests::AddKeyRequest, ProverInput};
use common::{
prover_input::{Cairo0ProverInput, CairoProverInput, ProverInput},
requests::AddKeyRequest,
};
use ed25519_dalek::{ed25519::signature::SignerMut, VerifyingKey};
use futures::StreamExt;
use reqwest::{Client, Response};
Expand Down Expand Up @@ -39,28 +42,21 @@ impl ProverSDK {
.build()
}

pub async fn prove_cairo0<T>(&self, data: T) -> Result<u64, SdkErrors>
where
T: ProverInput + Send + 'static,
{
self.prove(data, self.prover_cairo0.clone()).await
pub async fn prove_cairo0(&self, data: Cairo0ProverInput) -> Result<u64, SdkErrors> {
self.prove(ProverInput::Cairo0(data), self.prover_cairo0.clone())
.await
}

pub async fn prove_cairo<T>(&self, data: T) -> Result<u64, SdkErrors>
where
T: ProverInput + Send + 'static,
{
self.prove(data, self.prover_cairo.clone()).await
pub async fn prove_cairo(&self, data: CairoProverInput) -> Result<u64, SdkErrors> {
self.prove(ProverInput::Cairo(data), self.prover_cairo.clone())
.await
}

async fn prove<T>(&self, data: T, url: Url) -> Result<u64, SdkErrors>
where
T: ProverInput + Send + 'static,
{
async fn prove(&self, data: ProverInput, url: Url) -> Result<u64, SdkErrors> {
let response = self
.client
.post(url.clone())
.json(&data.serialize())
.json(&data.to_json_value())
.send()
.await?;

Expand Down
5 changes: 1 addition & 4 deletions prover-sdk/tests/prove_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use common::{
cairo0_prover_input::{Cairo0CompiledProgram, Cairo0ProverInput},
cairo_prover_input::{CairoCompiledProgram, CairoProverInput},
};
use common::prover_input::*;
use helpers::fetch_job;
use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK};
use serde_json::Value;
Expand Down
2 changes: 1 addition & 1 deletion prover-sdk/tests/verify_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use common::cairo_prover_input::{CairoCompiledProgram, CairoProverInput};
use common::prover_input::*;
use helpers::fetch_job;
use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK};
use starknet_types_core::felt::Felt;
Expand Down
2 changes: 1 addition & 1 deletion prover/src/prove/cairo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::cairo_prover_input::CairoProverInput;
use common::prover_input::CairoProverInput;
use serde_json::json;

pub async fn root(
Expand Down
2 changes: 1 addition & 1 deletion prover/src/prove/cairo0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::server::AppState;
use crate::threadpool::CairoVersionedInput;
use axum::Json;
use axum::{extract::State, http::StatusCode, response::IntoResponse};
use common::cairo0_prover_input::Cairo0ProverInput;
use common::prover_input::Cairo0ProverInput;
use serde_json::json;

pub async fn root(
Expand Down
2 changes: 1 addition & 1 deletion prover/src/threadpool/run.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{fs, path::PathBuf};

use common::{cairo0_prover_input::Cairo0ProverInput, cairo_prover_input::CairoProverInput};
use common::prover_input::{Cairo0ProverInput, CairoProverInput};
use starknet_types_core::felt::Felt;
use tokio::process::Command;
use tracing::trace;
Expand Down

0 comments on commit 63509b4

Please sign in to comment.