diff --git a/src/main.rs b/src/main.rs index 92af7f7..29f5a30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use axum::{ Json, Router, }; use clap::Parser; -use num::{BigInt, Num}; +use num::{BigInt, BigUint, Num}; use oracle::*; use serde::{Deserialize, Serialize}; use stark_vrf::{generate_public_key, BaseField, StarkVRF}; @@ -19,8 +19,12 @@ use tracing::debug; #[command(version, about, long_about = None)] struct Args { /// Secret key - #[arg(short, long, required = true)] - secret_key: u64, + #[arg(short, long, required = true, help = "Secret key as hex string")] + secret_key: String, + + /// Port + #[arg(short, long, default_value = "3003")] + port: u16, } fn format(v: T) -> String { @@ -34,14 +38,21 @@ struct InfoResult { public_key_y: String, } -fn get_secret_key() -> String { +fn get_secret_key() -> BigUint { let args = Args::parse(); - args.secret_key.to_string() + BigUint::from_str_radix( + args + .secret_key + .trim_start_matches("0x") + .trim_start_matches("0X"), + 16, + ) + .expect("unable to parse secret_key") } async fn vrf_info() -> Json { let secret_key = get_secret_key(); - let public_key = generate_public_key(secret_key.parse().unwrap()); + let public_key = generate_public_key(secret_key.into()); Json(InfoResult { public_key_x: format(public_key.x), @@ -57,7 +68,7 @@ struct JsonResult { async fn stark_vrf(extract::Json(payload): extract::Json) -> Json { debug!("received payload {payload:?}"); let secret_key = get_secret_key(); - let public_key = generate_public_key(secret_key.parse().unwrap()); + let public_key = generate_public_key(secret_key.clone().into()); println!("public key {public_key}"); @@ -72,9 +83,7 @@ async fn stark_vrf(extract::Json(payload): extract::Json) -> Js .collect(); let ecvrf = StarkVRF::new(public_key).unwrap(); - let proof = ecvrf - .prove(&secret_key.parse().unwrap(), seed.as_slice()) - .unwrap(); + let proof = ecvrf.prove(&secret_key.into(), seed.as_slice()).unwrap(); let sqrt_ratio_hint = ecvrf.hash_to_sqrt_ratio_hint(seed.as_slice()); let rnd = ecvrf.proof_to_hash(&proof).unwrap(); @@ -100,7 +109,7 @@ async fn stark_vrf(extract::Json(payload): extract::Json) -> Js #[tokio::main(flavor = "current_thread")] async fn main() { - Args::parse(); + let args = Args::parse(); tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) @@ -116,11 +125,11 @@ async fn main() { .route("/stark_vrf", post(stark_vrf)) .layer(TraceLayer::new_for_http()); - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000") + let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", &args.port)) .await - .expect("Failed to bind to port 3000, port already in use by another process. Change the port or terminate the other process."); + .expect(&format!("Failed to bind to port {}, port already in use by another process. Change the port or terminate the other process.", &args.port)); - debug!("Server started on http://0.0.0.0:3000"); + debug!("Server started on http://0.0.0.0:{}", &args.port); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal())