diff --git a/Cargo.lock b/Cargo.lock index b9fbe6b843b..5283f1e99b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1304,11 +1304,13 @@ dependencies = [ "axum", "chromadb", "chrono", + "clap", "figment", "guacamole", "opentelemetry", "opentelemetry-otlp", "opentelemetry_sdk", + "reqwest 0.12.9", "serde", "serde_json", "tokio", diff --git a/rust/load/Cargo.toml b/rust/load/Cargo.toml index 3dadc39abda..4eebef01948 100644 --- a/rust/load/Cargo.toml +++ b/rust/load/Cargo.toml @@ -7,7 +7,8 @@ edition = "2021" async-trait = "0.1.83" axum = "0.7" chromadb = { git = "https://github.com/rescrv/chromadb-rs", rev = "e364e35c34c660d4e8e862436ea600ddc2f46a1e" } -chrono = "0.4.38" +chrono = { version = "0.4.38", features = ["serde"] } +clap = { version = "4", features = ["derive"] } figment = { version = "0.10.12", features = ["env", "yaml", "test"] } guacamole = { version = "0.9", default-features = false } serde.workspace = true @@ -27,3 +28,4 @@ opentelemetry-otlp = "0.27" opentelemetry_sdk = { version = "0.27", features = ["rt-tokio"] } tracing.workspace = true tower-http = { version = "0.6.2", features = ["trace"] } +reqwest = { version = "0.12", features = ["json"] } diff --git a/rust/load/examples/workload-json.rs b/rust/load/examples/workload-json.rs new file mode 100644 index 00000000000..d350bddadb6 --- /dev/null +++ b/rust/load/examples/workload-json.rs @@ -0,0 +1,32 @@ +use chroma_load::{Distribution, GetQuery, QueryQuery, Workload}; + +fn main() { + let w = Workload::Hybrid(vec![ + (1.0, Workload::Nop), + (1.0, Workload::ByName("foo".to_string())), + ( + 1.0, + Workload::Get(GetQuery { + limit: Distribution::Constant(10), + document: None, + metadata: None, + }), + ), + ( + 1.0, + Workload::Query(QueryQuery { + limit: Distribution::Constant(10), + document: None, + metadata: None, + }), + ), + ( + 1.0, + Workload::Delay { + after: chrono::DateTime::parse_from_rfc3339("2021-01-01T00:00:00+00:00").unwrap(), + wrap: Box::new(Workload::Nop), + }, + ), + ]); + println!("{}", serde_json::to_string_pretty(&w).unwrap()); +} diff --git a/rust/load/src/bin/chroma-load-inhibit.rs b/rust/load/src/bin/chroma-load-inhibit.rs new file mode 100644 index 00000000000..d5f0b43bbb5 --- /dev/null +++ b/rust/load/src/bin/chroma-load-inhibit.rs @@ -0,0 +1,18 @@ +//! Inhibit chroma-load on every host provided on the command line. + +#[tokio::main] +async fn main() { + for host in std::env::args().skip(1) { + let client = reqwest::Client::new(); + match client.post(format!("{}/inhibit", host)).send().await { + Ok(resp) => { + if resp.status().is_success() { + println!("Inhibited load on {}", host); + } else { + eprintln!("Failed to inhibit load on {}: {}", host, resp.status()); + } + } + Err(e) => eprintln!("Failed to inhibit load on {}: {}", host, e), + } + } +} diff --git a/rust/load/src/bin/chroma-load-start.rs b/rust/load/src/bin/chroma-load-start.rs new file mode 100644 index 00000000000..b06694719c1 --- /dev/null +++ b/rust/load/src/bin/chroma-load-start.rs @@ -0,0 +1,68 @@ +//! Start a workload on the chroma-load server. + +use clap::Parser; + +use chroma_load::rest::StartRequest; +use chroma_load::{humanize_expires, Workload}; + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + host: String, + #[arg(long)] + name: String, + #[arg(long)] + expires: String, + #[arg(long)] + data_set: String, + #[arg(long)] + workload: String, + #[arg(long)] + throughput: f64, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + let client = reqwest::Client::new(); + let req = StartRequest { + name: args.name, + expires: humanize_expires(&args.expires).unwrap_or(args.expires), + data_set: args.data_set, + workload: Workload::ByName(args.workload), + throughput: args.throughput, + }; + match client + .post(format!("{}/start", args.host)) + .header(reqwest::header::ACCEPT, "application/json") + .json(&req) + .send() + .await + { + Ok(resp) => { + if resp.status().is_success() { + let uuid = match resp.text().await { + Ok(uuid) => uuid, + Err(err) => { + eprintln!("Failed to start workload on {}: {}", args.host, err); + return; + } + }; + println!( + "Started workload on {}:\n{}", + args.host, + // SAFETY(rescrv): serde_json::to_string_pretty should always convert to JSON + // when it just parses as JSON. + uuid, + ); + } else { + eprintln!( + "Failed to start workload on {}: {}", + args.host, + resp.status() + ); + } + } + Err(e) => eprintln!("Failed to start workload on {}: {}", args.host, e), + } +} diff --git a/rust/load/src/bin/chroma-load-status.rs b/rust/load/src/bin/chroma-load-status.rs new file mode 100644 index 00000000000..fcf96ac4116 --- /dev/null +++ b/rust/load/src/bin/chroma-load-status.rs @@ -0,0 +1,55 @@ +//! Inspect chroma-load + +use clap::Parser; + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + host: String, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + let client = reqwest::Client::new(); + match client + .get(&args.host) + .header(reqwest::header::ACCEPT, "application/json") + .send() + .await + { + Ok(resp) => { + if resp.status().is_success() { + let status = match resp.json::().await { + Ok(status) => status, + Err(e) => { + eprintln!("Failed to fetch workload status on {}: {}", args.host, e); + return; + } + }; + if status.inhibited { + println!("inhibited"); + } else { + for running in status.running { + println!( + "{} {} {} {} {}", + running.uuid, + running.expires, + running.name, + running.data_set, + // SAFETY(rescrv): WorkloadSummary always converts to JSON. + serde_json::to_string(&running.workload).unwrap() + ); + } + } + } else { + eprintln!( + "Failed to get workload status on {}: {}", + args.host, + resp.status() + ); + } + } + Err(e) => eprintln!("Failed to get workload status on {}: {}", args.host, e), + } +} diff --git a/rust/load/src/bin/chroma-load-stop.rs b/rust/load/src/bin/chroma-load-stop.rs new file mode 100644 index 00000000000..97a227e112d --- /dev/null +++ b/rust/load/src/bin/chroma-load-stop.rs @@ -0,0 +1,44 @@ +//! Stop a single workload on the chroma-load server. +//! +//! If you are looking to stop traffic for a SEV, see chroma-load-inhibit. + +use clap::Parser; +use uuid::Uuid; + +use chroma_load::rest::StopRequest; + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + host: String, + #[arg(long)] + uuid: String, +} + +#[tokio::main] +async fn main() { + let args = Args::parse(); + let client = reqwest::Client::new(); + let req = StopRequest { + uuid: Uuid::parse_str(&args.uuid).unwrap(), + }; + match client + .post(format!("{}/stop", args.host)) + .json(&req) + .send() + .await + { + Ok(resp) => { + if resp.status().is_success() { + println!("Stopped workload on {}", args.host); + } else { + eprintln!( + "Failed to stop workload on {}: {}", + args.host, + resp.status() + ); + } + } + Err(e) => eprintln!("Failed to stop workload on {}: {}", args.host, e), + } +} diff --git a/rust/load/src/bin/chroma-load-uninhibit.rs b/rust/load/src/bin/chroma-load-uninhibit.rs new file mode 100644 index 00000000000..546f8d85027 --- /dev/null +++ b/rust/load/src/bin/chroma-load-uninhibit.rs @@ -0,0 +1,18 @@ +//! Uninhibit chroma-load on every host provided on the command line. + +#[tokio::main] +async fn main() { + for host in std::env::args().skip(1) { + let client = reqwest::Client::new(); + match client.post(format!("{}/uninhibit", host)).send().await { + Ok(resp) => { + if resp.status().is_success() { + println!("Resumed load on {}", host); + } else { + eprintln!("Failed to uninhibit load on {}: {}", host, resp.status()); + } + } + Err(e) => eprintln!("Failed to uninhibit load on {}: {}", host, e), + } + } +} diff --git a/rust/load/src/lib.rs b/rust/load/src/lib.rs index 83a3a110c7e..27adfd04bd5 100644 --- a/rust/load/src/lib.rs +++ b/rust/load/src/lib.rs @@ -169,22 +169,22 @@ impl DocumentQuery { #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub struct GetQuery { - limit: Distribution, + pub limit: Distribution, #[serde(skip_serializing_if = "Option::is_none")] - metadata: Option, + pub metadata: Option, #[serde(skip_serializing_if = "Option::is_none")] - document: Option, + pub document: Option, } //////////////////////////////////////////// QueryQuery //////////////////////////////////////////// #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub struct QueryQuery { - limit: Distribution, + pub limit: Distribution, #[serde(skip_serializing_if = "Option::is_none")] - metadata: Option, + pub metadata: Option, #[serde(skip_serializing_if = "Option::is_none")] - document: Option, + pub document: Option, } ///////////////////////////////////////////// Workload ///////////////////////////////////////////// @@ -201,6 +201,11 @@ pub enum Workload { Query(QueryQuery), #[serde(rename = "hybrid")] Hybrid(Vec<(f64, Workload)>), + #[serde(rename = "delay")] + Delay { + after: chrono::DateTime, + wrap: Box, + }, } impl Workload { @@ -209,17 +214,23 @@ impl Workload { } pub fn resolve_by_name(&mut self, workloads: &HashMap) -> Result<(), Error> { - if let Workload::ByName(name) = self { - if let Some(workload) = workloads.get(name) { - *self = workload.clone(); - } else { - return Err(Error::InvalidRequest(format!("workload not found: {name}"))); + match self { + Workload::Nop => {} + Workload::ByName(name) => { + if let Some(workload) = workloads.get(name) { + *self = workload.clone(); + } else { + return Err(Error::InvalidRequest(format!("workload not found: {name}"))); + } } - } - if let Workload::Hybrid(hybrid) = self { - for (_, workload) in hybrid { - workload.resolve_by_name(workloads)?; + Workload::Get(_) => {} + Workload::Query(_) => {} + Workload::Hybrid(hybrid) => { + for (_, workload) in hybrid { + workload.resolve_by_name(workloads)?; + } } + Workload::Delay { after: _, wrap } => wrap.resolve_by_name(workloads)?, } Ok(()) } @@ -255,22 +266,40 @@ impl Workload { } Workload::Hybrid(hybrid) => { let scale: f64 = any(guac); - let mut total = scale * hybrid.iter().map(|(p, _)| *p).sum::(); + let mut total = scale + * hybrid + .iter() + .filter_map(|(p, w)| if w.is_active() { Some(*p) } else { None }) + .sum::(); for (p, workload) in hybrid { if *p < 0.0 { return Err(Box::new(Error::InvalidRequest( "hybrid probabilities must be positive".to_string(), ))); } - if *p >= total { - return Box::pin(workload.step(client, data_set, guac)).await; + if workload.is_active() { + if *p >= total { + return Box::pin(workload.step(client, data_set, guac)).await; + } + total -= *p; } - total -= *p; } Err(Box::new(Error::InternalError( "miscalculation of total hybrid probabilities".to_string(), ))) } + Workload::Delay { after: _, wrap } => Box::pin(wrap.step(client, data_set, guac)).await, + } + } + + pub fn is_active(&self) -> bool { + match self { + Workload::Nop => true, + Workload::ByName(_) => true, + Workload::Get(_) => true, + Workload::Query(_) => true, + Workload::Hybrid(hybrid) => hybrid.iter().any(|(_, w)| w.is_active()), + Workload::Delay { after, wrap } => chrono::Utc::now() >= *after && wrap.is_active(), } } } @@ -511,6 +540,8 @@ impl LoadService { } if inhibit.load(std::sync::atomic::Ordering::Relaxed) { tracing::info!("inhibited"); + } else if !spec.workload.is_active() { + tracing::debug!("workload inactive"); } else if let Err(err) = spec .workload .step(&client, &*spec.data_set, &mut guac) @@ -705,6 +736,21 @@ pub async fn entrypoint() { runner.abort(); } +pub fn humanize_expires(expires: &str) -> Option { + if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(expires) { + Some(expires.to_rfc3339()) + } else if let Some(duration) = expires.strip_suffix("s") { + let expires = chrono::Utc::now() + chrono::Duration::seconds(duration.trim().parse().ok()?); + Some(expires.to_rfc3339()) + } else if let Some(duration) = expires.strip_suffix("min") { + let expires = chrono::Utc::now() + + chrono::Duration::seconds(duration.trim().parse::().ok()? * 60i64); + Some(expires.to_rfc3339()) + } else { + Some(expires.to_string()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -722,4 +768,80 @@ mod tests { .unwrap(); tokio::time::sleep(std::time::Duration::from_secs(10)).await; } + + #[test] + fn workload_json() { + let json = r#"{ + "hybrid": [ + [ + 1.0, + "nop" + ], + [ + 1.0, + { + "by_name": "foo" + } + ], + [ + 1.0, + { + "get": { + "limit": { + "Constant": 10 + } + } + } + ], + [ + 1.0, + { + "query": { + "limit": { + "Constant": 10 + } + } + } + ], + [ + 1.0, + { + "delay": { + "after": "2021-01-01T00:00:00Z", + "wrap": "nop" + } + } + ] + ] +}"#; + let workload = Workload::Hybrid(vec![ + (1.0, Workload::Nop), + (1.0, Workload::ByName("foo".to_string())), + ( + 1.0, + Workload::Get(GetQuery { + limit: Distribution::Constant(10), + document: None, + metadata: None, + }), + ), + ( + 1.0, + Workload::Query(QueryQuery { + limit: Distribution::Constant(10), + document: None, + metadata: None, + }), + ), + ( + 1.0, + Workload::Delay { + after: chrono::DateTime::parse_from_rfc3339("2021-01-01T00:00:00+00:00") + .unwrap(), + wrap: Box::new(Workload::Nop), + }, + ), + ]); + assert_eq!(json, serde_json::to_string_pretty(&workload).unwrap()); + } } diff --git a/rust/load/src/rest.rs b/rust/load/src/rest.rs index 2d2db9c73ea..dfc19866691 100644 --- a/rust/load/src/rest.rs +++ b/rust/load/src/rest.rs @@ -21,9 +21,10 @@ impl From<&dyn crate::DataSet> for Description { #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] pub struct Status { + pub inhibited: bool, pub running: Vec, pub data_sets: Vec, - pub workloads: Vec, + pub workloads: Vec, } #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]