Skip to content

Commit

Permalink
[ENH][chroma-load] Support delay on workloads.
Browse files Browse the repository at this point in the history
This adds support for delaying a workload until a given time.  Workloads
advertise whether they are "active"; an "inactive" workload is silently skipped (a NOP).
  • Loading branch information
rescrv committed Dec 2, 2024
1 parent 22807bb commit dcc5522
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 20 deletions.
2 changes: 1 addition & 1 deletion rust/load/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ edition = "2021"
async-trait = "0.1.83"
axum = "0.7"
chromadb = { git = "https://github.com/rescrv/chromadb-rs", rev = "e364e35c34c660d4e8e862436ea600ddc2f46a1e" }
chrono = "0.4.38"
chrono = { version = "0.4.38", features = ["serde"] }
figment = { version = "0.10.12", features = ["env", "yaml", "test"] }
guacamole = { version = "0.9", default-features = false }
serde.workspace = true
Expand Down
32 changes: 32 additions & 0 deletions rust/load/examples/workload-json.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use chroma_load::{Distribution, GetQuery, QueryQuery, Workload};

/// Builds a hybrid workload with equally weighted `Nop`, `ByName`, `Get`, `Query`, and `Delay`
/// members and prints it to stdout as pretty-printed JSON.
fn main() {
    let get = Workload::Get(GetQuery {
        limit: Distribution::Constant(10),
        document: None,
        metadata: None,
    });
    let query = Workload::Query(QueryQuery {
        limit: Distribution::Constant(10),
        document: None,
        metadata: None,
    });
    // The literal is a fixed, well-formed RFC 3339 timestamp, so a parse failure here would be
    // a bug in the example itself.
    let start = chrono::DateTime::parse_from_rfc3339("2021-01-01T00:00:00+00:00").unwrap();
    let delayed = Workload::Delay(start, Box::new(Workload::Nop));

    let members = vec![
        (1.0, Workload::Nop),
        (1.0, Workload::ByName(String::from("foo"))),
        (1.0, get),
        (1.0, query),
        (1.0, delayed),
    ];
    let hybrid = Workload::Hybrid(members);

    let rendered = serde_json::to_string_pretty(&hybrid).unwrap();
    println!("{}", rendered);
}
141 changes: 122 additions & 19 deletions rust/load/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,22 +169,22 @@ impl DocumentQuery {

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct GetQuery {
limit: Distribution,
pub limit: Distribution,
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<MetadataQuery>,
pub metadata: Option<MetadataQuery>,
#[serde(skip_serializing_if = "Option::is_none")]
document: Option<DocumentQuery>,
pub document: Option<DocumentQuery>,
}

//////////////////////////////////////////// QueryQuery ////////////////////////////////////////////

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct QueryQuery {
limit: Distribution,
pub limit: Distribution,
#[serde(skip_serializing_if = "Option::is_none")]
metadata: Option<MetadataQuery>,
pub metadata: Option<MetadataQuery>,
#[serde(skip_serializing_if = "Option::is_none")]
document: Option<DocumentQuery>,
pub document: Option<DocumentQuery>,
}

///////////////////////////////////////////// Workload /////////////////////////////////////////////
Expand All @@ -201,6 +201,8 @@ pub enum Workload {
Query(QueryQuery),
#[serde(rename = "hybrid")]
Hybrid(Vec<(f64, Workload)>),
#[serde(rename = "delay")]
Delay(chrono::DateTime<chrono::FixedOffset>, Box<Workload>),
}

impl Workload {
Expand All @@ -209,17 +211,23 @@ impl Workload {
}

pub fn resolve_by_name(&mut self, workloads: &HashMap<String, Workload>) -> Result<(), Error> {
if let Workload::ByName(name) = self {
if let Some(workload) = workloads.get(name) {
*self = workload.clone();
} else {
return Err(Error::InvalidRequest(format!("workload not found: {name}")));
match self {
Workload::Nop => {}
Workload::ByName(name) => {
if let Some(workload) = workloads.get(name) {
*self = workload.clone();
} else {
return Err(Error::InvalidRequest(format!("workload not found: {name}")));
}
}
}
if let Workload::Hybrid(hybrid) = self {
for (_, workload) in hybrid {
workload.resolve_by_name(workloads)?;
Workload::Get(_) => {}
Workload::Query(_) => {}
Workload::Hybrid(hybrid) => {
for (_, workload) in hybrid {
workload.resolve_by_name(workloads)?;
}
}
Workload::Delay(_, w) => w.resolve_by_name(workloads)?,
}
Ok(())
}
Expand Down Expand Up @@ -255,22 +263,40 @@ impl Workload {
}
Workload::Hybrid(hybrid) => {
let scale: f64 = any(guac);
let mut total = scale * hybrid.iter().map(|(p, _)| *p).sum::<f64>();
let mut total = scale
* hybrid
.iter()
.filter_map(|(p, w)| if w.is_active() { Some(*p) } else { None })
.sum::<f64>();
for (p, workload) in hybrid {
if *p < 0.0 {
return Err(Box::new(Error::InvalidRequest(
"hybrid probabilities must be positive".to_string(),
)));
}
if *p >= total {
return Box::pin(workload.step(client, data_set, guac)).await;
if workload.is_active() {
if *p >= total {
return Box::pin(workload.step(client, data_set, guac)).await;
}
total -= *p;
}
total -= *p;
}
Err(Box::new(Error::InternalError(
"miscalculation of total hybrid probabilities".to_string(),
)))
}
Workload::Delay(_, w) => Box::pin(w.step(client, data_set, guac)).await,
}
}

/// Reports whether this workload is currently active.
///
/// `Nop`, `ByName`, `Get`, and `Query` always report active.  A `Hybrid` is active when at
/// least one of its members is active.  A `Delay` is active once the current UTC time has
/// reached its timestamp and the wrapped workload is itself active.
pub fn is_active(&self) -> bool {
    match self {
        Workload::Nop | Workload::ByName(_) | Workload::Get(_) | Workload::Query(_) => true,
        Workload::Hybrid(members) => members.iter().any(|(_, member)| member.is_active()),
        Workload::Delay(start, inner) => {
            // Check the clock first; the inner workload is only consulted once the delay has
            // elapsed.
            let started = chrono::Utc::now() >= *start;
            started && inner.is_active()
        }
    }
}
}
Expand Down Expand Up @@ -511,6 +537,8 @@ impl LoadService {
}
if inhibit.load(std::sync::atomic::Ordering::Relaxed) {
tracing::info!("inhibited");
} else if !spec.workload.is_active() {
tracing::debug!("workload inactive");
} else if let Err(err) = spec
.workload
.step(&client, &*spec.data_set, &mut guac)
Expand Down Expand Up @@ -722,4 +750,79 @@ mod tests {
.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}

// Checks that a hybrid workload containing `Nop`, `ByName`, `Get`, `Query`, and `Delay`
// members serializes, via `serde_json::to_string_pretty`, to exactly the JSON text below.
//
// The comparison is on the exact string, so the whitespace inside the raw string literal is
// significant: it has to match the pretty-printer's output character for character.
#[test]
fn workload_json() {
let json = r#"{
"hybrid": [
[
1.0,
"nop"
],
[
1.0,
{
"by_name": "foo"
}
],
[
1.0,
{
"get": {
"limit": {
"Constant": 10
}
}
}
],
[
1.0,
{
"query": {
"limit": {
"Constant": 10
}
}
}
],
[
1.0,
{
"delay": [
"2021-01-01T00:00:00Z",
"nop"
]
}
]
]
}"#;
// Build the same workload in code.  The `metadata` and `document` fields are `None`, and the
// expected JSON above contains no such keys for the `get` and `query` entries.
let workload = Workload::Hybrid(vec![
(1.0, Workload::Nop),
(1.0, Workload::ByName("foo".to_string())),
(
1.0,
Workload::Get(GetQuery {
limit: Distribution::Constant(10),
document: None,
metadata: None,
}),
),
(
1.0,
Workload::Query(QueryQuery {
limit: Distribution::Constant(10),
document: None,
metadata: None,
}),
),
(
1.0,
Workload::Delay(
// Parsed with an explicit `+00:00` offset; the expected JSON spells the same
// instant with a `Z` suffix.
chrono::DateTime::parse_from_rfc3339("2021-01-01T00:00:00+00:00").unwrap(),
Box::new(Workload::Nop),
),
),
]);
// Exact text comparison against the pretty-printed serialization.
assert_eq!(json, serde_json::to_string_pretty(&workload).unwrap());
}
}

0 comments on commit dcc5522

Please sign in to comment.