diff --git a/bin/cairo-prove/src/prove.rs b/bin/cairo-prove/src/prove.rs index 4e0d70f..00dfd49 100644 --- a/bin/cairo-prove/src/prove.rs +++ b/bin/cairo-prove/src/prove.rs @@ -2,10 +2,7 @@ use crate::errors::ProveErrors; use crate::validate_input; use crate::Args; use crate::CairoVersion; -use common::{ - cairo0_prover_input::{Cairo0CompiledProgram, Cairo0ProverInput}, - cairo_prover_input::{CairoCompiledProgram, CairoProverInput}, -}; +use common::prover_input::*; use prover_sdk::sdk::ProverSDK; use serde_json::Value; diff --git a/common/src/lib.rs b/common/src/lib.rs index 57bedcf..6e7ee43 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,7 +1,3 @@ -pub mod cairo0_prover_input; -pub mod cairo_prover_input; pub mod models; +pub mod prover_input; pub mod requests; -pub trait ProverInput { - fn serialize(self) -> serde_json::Value; -} diff --git a/common/src/cairo_prover_input.rs b/common/src/prover_input/cairo.rs similarity index 77% rename from common/src/cairo_prover_input.rs rename to common/src/prover_input/cairo.rs index 1ff89fe..6c7c370 100644 --- a/common/src/cairo_prover_input.rs +++ b/common/src/prover_input/cairo.rs @@ -1,19 +1,13 @@ use serde::{Deserialize, Serialize}; use starknet_types_core::felt::Felt; -use crate::ProverInput; - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct CairoProverInput { pub program: CairoCompiledProgram, pub program_input: Vec, pub layout: String, } -impl ProverInput for CairoProverInput { - fn serialize(self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } -} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct CairoCompiledProgram { //pub version: u64, diff --git a/common/src/cairo0_prover_input.rs b/common/src/prover_input/cairo0.rs similarity index 80% rename from common/src/cairo0_prover_input.rs rename to common/src/prover_input/cairo0.rs index ae01e9d..1ca399e 100644 --- a/common/src/cairo0_prover_input.rs +++ b/common/src/prover_input/cairo0.rs @@ -1,7 +1,5 @@ use serde::{Deserialize, Serialize}; -use crate::ProverInput; - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Cairo0ProverInput { pub program: Cairo0CompiledProgram, @@ -9,12 +7,6 @@ pub struct Cairo0ProverInput { pub layout: String, } -impl ProverInput for Cairo0ProverInput { - fn serialize(self) -> serde_json::Value { - serde_json::to_value(self).unwrap() - } -} - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Cairo0CompiledProgram { pub attributes: Vec, diff --git a/common/src/prover_input/mod.rs b/common/src/prover_input/mod.rs new file mode 100644 index 0000000..8c925ba --- /dev/null +++ b/common/src/prover_input/mod.rs @@ -0,0 +1,20 @@ +mod cairo; +mod cairo0; + +pub use cairo::{CairoCompiledProgram, CairoProverInput}; +pub use cairo0::{Cairo0CompiledProgram, Cairo0ProverInput}; + +#[derive(Debug)] +pub enum ProverInput { + Cairo0(Cairo0ProverInput), + Cairo(CairoProverInput), +} + +impl ProverInput { + pub fn to_json_value(&self) -> serde_json::Value { + match self { + ProverInput::Cairo0(input) => serde_json::to_value(input).unwrap(), + ProverInput::Cairo(input) => serde_json::to_value(input).unwrap(), + } + } +} diff --git a/prover-sdk/src/lib.rs b/prover-sdk/src/lib.rs index 42680d1..8a1d449 100644 --- a/prover-sdk/src/lib.rs +++ b/prover-sdk/src/lib.rs @@ -3,5 +3,4 @@ pub mod errors; pub mod sdk; pub mod sdk_builder; -pub use common::cairo0_prover_input::Cairo0ProverInput; -pub use common::cairo_prover_input::CairoProverInput; +pub use common::prover_input::ProverInput; diff --git a/prover-sdk/src/sdk.rs b/prover-sdk/src/sdk.rs index 01fcf98..cd2a992 100644 --- a/prover-sdk/src/sdk.rs +++ b/prover-sdk/src/sdk.rs @@ -1,5 +1,8 @@ use crate::{access_key::ProverAccessKey, errors::SdkErrors, sdk_builder::ProverSDKBuilder}; -use common::{requests::AddKeyRequest, ProverInput}; +use common::{ + prover_input::{Cairo0ProverInput, CairoProverInput, ProverInput}, + requests::AddKeyRequest, +}; use ed25519_dalek::{ed25519::signature::SignerMut, VerifyingKey}; use futures::StreamExt; use reqwest::{Client, Response}; @@ -39,28 +42,21 @@ impl ProverSDK { .build() } - pub async fn prove_cairo0(&self, data: T) -> Result - where - T: ProverInput + Send + 'static, - { - self.prove(data, self.prover_cairo0.clone()).await + pub async fn prove_cairo0(&self, data: Cairo0ProverInput) -> Result { + self.prove(ProverInput::Cairo0(data), self.prover_cairo0.clone()) + .await } - pub async fn prove_cairo(&self, data: T) -> Result - where - T: ProverInput + Send + 'static, - { - self.prove(data, self.prover_cairo.clone()).await + pub async fn prove_cairo(&self, data: CairoProverInput) -> Result { + self.prove(ProverInput::Cairo(data), self.prover_cairo.clone()) + .await } - async fn prove(&self, data: T, url: Url) -> Result - where - T: ProverInput + Send + 'static, - { + async fn prove(&self, data: ProverInput, url: Url) -> Result { let response = self .client .post(url.clone()) - .json(&data.serialize()) + .json(&data.to_json_value()) .send() .await?; diff --git a/prover-sdk/tests/prove_test.rs b/prover-sdk/tests/prove_test.rs index 82e67ef..fb0c172 100644 --- a/prover-sdk/tests/prove_test.rs +++ b/prover-sdk/tests/prove_test.rs @@ -1,7 +1,4 @@ -use common::{ - cairo0_prover_input::{Cairo0CompiledProgram, Cairo0ProverInput}, - cairo_prover_input::{CairoCompiledProgram, CairoProverInput}, -}; +use common::prover_input::*; use helpers::fetch_job; use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK}; use serde_json::Value; diff --git a/prover-sdk/tests/verify_test.rs b/prover-sdk/tests/verify_test.rs index 2b28eb9..72d7d2a 100644 --- a/prover-sdk/tests/verify_test.rs +++ b/prover-sdk/tests/verify_test.rs @@ -1,4 +1,4 @@ -use common::cairo_prover_input::{CairoCompiledProgram, CairoProverInput}; +use common::prover_input::*; use helpers::fetch_job; use prover_sdk::{access_key::ProverAccessKey, sdk::ProverSDK}; use starknet_types_core::felt::Felt; diff --git a/prover/src/prove/cairo.rs b/prover/src/prove/cairo.rs index 073c8cf..8a2b145 100644 --- a/prover/src/prove/cairo.rs +++ b/prover/src/prove/cairo.rs @@ -4,7 +4,7 @@ use crate::server::AppState; use crate::threadpool::CairoVersionedInput; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; -use common::cairo_prover_input::CairoProverInput; +use common::prover_input::CairoProverInput; use serde_json::json; pub async fn root( diff --git a/prover/src/prove/cairo0.rs b/prover/src/prove/cairo0.rs index b836e31..b25becb 100644 --- a/prover/src/prove/cairo0.rs +++ b/prover/src/prove/cairo0.rs @@ -4,7 +4,7 @@ use crate::server::AppState; use crate::threadpool::CairoVersionedInput; use axum::Json; use axum::{extract::State, http::StatusCode, response::IntoResponse}; -use common::cairo0_prover_input::Cairo0ProverInput; +use common::prover_input::Cairo0ProverInput; use serde_json::json; pub async fn root( diff --git a/prover/src/threadpool/run.rs b/prover/src/threadpool/run.rs index 73eae8c..96f3a28 100644 --- a/prover/src/threadpool/run.rs +++ b/prover/src/threadpool/run.rs @@ -1,6 +1,6 @@ use std::{fs, path::PathBuf}; -use common::{cairo0_prover_input::Cairo0ProverInput, cairo_prover_input::CairoProverInput}; +use common::prover_input::{Cairo0ProverInput, CairoProverInput}; use starknet_types_core::felt::Felt; use tokio::process::Command; use tracing::trace;