Skip to content

Commit

Permalink
[ENH] Support RI-4 and RI-5 hybrid read workloads.
Browse files Browse the repository at this point in the history
  • Loading branch information
rescrv committed Dec 2, 2024
1 parent 7591522 commit 03b4be4
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 11 deletions.
50 changes: 41 additions & 9 deletions rust/load/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,32 @@ pub struct QueryQuery {
pub enum Workload {
#[serde(rename = "nop")]
Nop,
#[serde(rename = "by_name")]
ByName(String),
#[serde(rename = "get")]
Get(GetQuery),
#[serde(rename = "query")]
Query(QueryQuery),
#[serde(rename = "hybrid")]
Hybrid(Vec<(f64, Workload)>),
}

impl Workload {
pub fn description(&self) -> String {
serde_json::to_string_pretty(self).unwrap()
}

pub fn resolve_by_name(&mut self, workloads: &HashMap<String, Workload>) -> Result<(), Error> {
if let Workload::ByName(name) = self {
if let Some(workload) = workloads.get(name) {
*self = workload.clone();
} else {
return Err(Error::InvalidRequest(format!("workload not found: {name}")));
}
}
Ok(())
}

pub async fn step(
&self,
client: &ChromaClient,
Expand All @@ -212,6 +227,12 @@ impl Workload {
tracing::info!("nop");
Ok(())
}
Workload::ByName(_) => {
tracing::error!("cannot step by name; by_name should be resolved");
Err(Box::new(Error::InternalError(
"cannot step by name".to_string(),
)))
}
Workload::Get(get) => {
data_set
.get(client, get.clone(), guac)
Expand All @@ -224,6 +245,24 @@ impl Workload {
.instrument(tracing::info_span!("query"))
.await
}
Workload::Hybrid(hybrid) => {
let scale: f64 = any(guac);
let mut total = scale * hybrid.iter().map(|(p, _)| *p).sum::<f64>();
for (p, workload) in hybrid {
if *p < 0.0 {
return Err(Box::new(Error::InvalidRequest(
"hybrid probabilities must be positive".to_string(),
)));
}
if *p >= total {
return Box::pin(workload.step(client, data_set, guac)).await;
}
total -= *p;
}
Err(Box::new(Error::InternalError(
"miscalculation of total hybrid probabilities".to_string(),
)))
}
}
}
}
Expand Down Expand Up @@ -367,21 +406,14 @@ impl LoadService {
&self,
name: String,
data_set: String,
workload: String,
mut workload: Workload,
expires: chrono::DateTime<chrono::FixedOffset>,
throughput: f64,
) -> Result<Uuid, Error> {
let Some(data_set) = self.data_sets().iter().find(|ds| ds.name() == data_set) else {
return Err(Error::NotFound("data set not found".to_string()));
};
let Some(workload) = self
.workloads()
.iter()
.find(|(name, _)| **name == workload)
.map(|(_, wl)| wl)
else {
return Err(Error::NotFound("workload not found".to_string()));
};
workload.resolve_by_name(self.workloads())?;
// SAFETY(rescrv): Mutex poisoning.
let mut harness = self.harness.lock().unwrap();
Ok(harness.start(name, workload.clone(), data_set, expires, throughput))
Expand Down
4 changes: 2 additions & 2 deletions rust/load/src/rest.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use uuid::Uuid;

use crate::WorkloadSummary;
use crate::{Workload, WorkloadSummary};

#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct Description {
Expand Down Expand Up @@ -29,7 +29,7 @@ pub struct Status {
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
pub struct StartRequest {
pub name: String,
pub workload: String,
pub workload: Workload,
pub data_set: String,
pub expires: String,
pub throughput: f64,
Expand Down
50 changes: 50 additions & 0 deletions rust/load/src/workloads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,55 @@ pub fn all_workloads() -> HashMap<String, Workload> {
document: None,
}),
),
(
"hybrid-fts-vector".to_string(),
Workload::Hybrid(vec![
(
0.3,
Workload::Get(GetQuery {
limit: Distribution::Constant(10),
metadata: None,
document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))),
}),
),
(
0.7,
Workload::Query(QueryQuery {
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
document: None,
}),
),
]),
),
(
"hybrid-fts-md-vector".to_string(),
Workload::Hybrid(vec![
(
0.5,
Workload::Get(GetQuery {
limit: Distribution::Constant(10),
metadata: None,
document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))),
}),
),
(
0.25,
Workload::Get(GetQuery {
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
document: None,
}),
),
(
0.25,
Workload::Query(QueryQuery {
limit: Distribution::Constant(10),
metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))),
document: None,
}),
),
]),
),
])
}

0 comments on commit 03b4be4

Please sign in to comment.