From 8779c3cf95279b6c36298dd47075fba1535a5d9f Mon Sep 17 00:00:00 2001 From: Robert Escriva Date: Fri, 13 Dec 2024 11:43:11 -0800 Subject: [PATCH] [ENH] Parameterized queries. (#3299) This introduces parameterized queries for selecting from the tiny stories data set. --- rust/load/src/bit_difference.rs | 8 +-- rust/load/src/data_sets.rs | 8 +-- rust/load/src/lib.rs | 101 ++++++++++++++++++++++---------- rust/load/src/workloads.rs | 26 +++++--- 4 files changed, 97 insertions(+), 46 deletions(-) diff --git a/rust/load/src/bit_difference.rs b/rust/load/src/bit_difference.rs index 5f1cf0768a1..4fcbabc3cd0 100644 --- a/rust/load/src/bit_difference.rs +++ b/rust/load/src/bit_difference.rs @@ -316,8 +316,8 @@ impl DataSet for SyntheticDataSet { let collection = client.get_or_create_collection(&self.name(), None).await?; let limit = gq.limit.sample(guac); let mut ids = self.sample_ids(gq.skew, guac, limit); - let where_metadata = gq.metadata.map(|m| m.into_where_metadata(guac)); - let where_document = gq.document.map(|m| m.into_where_document(guac)); + let where_metadata = gq.metadata.map(|m| m.to_json(guac)); + let where_document = gq.document.map(|m| m.to_json(guac)); let results = collection .get(GetOptions { ids: ids.clone(), @@ -346,8 +346,8 @@ impl DataSet for SyntheticDataSet { ) -> Result<(), Box> { let collection = client.get_or_create_collection(&self.name(), None).await?; let cluster = self.cluster_by_skew(vq.skew, guac); - let where_metadata = vq.metadata.map(|m| m.into_where_metadata(guac)); - let where_document = vq.document.map(|m| m.into_where_document(guac)); + let where_metadata = vq.metadata.map(|m| m.to_json(guac)); + let where_document = vq.document.map(|m| m.to_json(guac)); let results = collection .query( QueryOptions { diff --git a/rust/load/src/data_sets.rs b/rust/load/src/data_sets.rs index 426493c8321..931b813379e 100644 --- a/rust/load/src/data_sets.rs +++ b/rust/load/src/data_sets.rs @@ -41,10 +41,10 @@ impl DataSet for NopDataSet { async fn 
query( &self, _: &ChromaClient, - _: QueryQuery, + qq: QueryQuery, _: &mut Guacamole, ) -> Result<(), Box> { - tracing::info!("nop query"); + tracing::info!("nop query {qq:?}", qq = qq); Ok(()) } @@ -113,8 +113,8 @@ impl DataSet for TinyStoriesDataSet { ) -> Result<(), Box> { let collection = client.get_collection(&self.name()).await?; let limit = gq.limit.sample(guac); - let where_metadata = gq.metadata.map(|m| m.into_where_metadata(guac)); - let where_document = gq.document.map(|m| m.into_where_document(guac)); + let where_metadata = gq.metadata.map(|m| m.to_json(guac)); + let where_document = gq.document.map(|m| m.to_json(guac)); let results = collection .get(GetOptions { ids: vec![], diff --git a/rust/load/src/lib.rs b/rust/load/src/lib.rs index 5161b71df24..388fa594c15 100644 --- a/rust/load/src/lib.rs +++ b/rust/load/src/lib.rs @@ -264,44 +264,88 @@ impl PartialEq for Skew { } } -/////////////////////////////////////////// MetadataQuery ////////////////////////////////////////// +///////////////////////////////////////// TinyStoriesMixin ///////////////////////////////////////// -/// A metadata query specifies a metadata filter in Chroma. -#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)] -pub enum MetadataQuery { - /// A raw metadata query simply copies the provided filter spec. - #[serde(rename = "raw")] - Raw(serde_json::Value), +#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)] +pub enum TinyStoriesMixin { + #[serde(rename = "numeric")] + Numeric { ratio_selected: f64 }, } -impl MetadataQuery { - /// Convert the metadata query into a JSON value suitable for use in a Chroma query. 
- pub fn into_where_metadata(self, _: &mut Guacamole) -> serde_json::Value { +impl TinyStoriesMixin { + pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value { match self { - MetadataQuery::Raw(json) => json, + Self::Numeric { ratio_selected } => { + let field: &'static str = match uniform(0u8, 5u8)(guac) { + 0 => "i1", + 1 => "i2", + 2 => "i3", + 3 => "f1", + 4 => "f2", + 5 => "f3", + _ => unreachable!(), + }; + let mut center = uniform(0, 1_000_000)(guac); + let window = (1e6 * ratio_selected) as usize; + if window / 2 > center { + center = window / 2 + } + let min = center - window / 2; + let max = center + window / 2; + serde_json::json!({"$and": [{field: {"$gte": min}}, {field: {"$lt": max}}]}) + } } } } -/////////////////////////////////////////// DocumentQuery ////////////////////////////////////////// +//////////////////////////////////////////// WhereMixin //////////////////////////////////////////// -/// A document query specifies a document filter in Chroma. -#[derive(Clone, Debug, Eq, PartialEq, serde::Deserialize, serde::Serialize)] -pub enum DocumentQuery { - // A raw document query simply copies the provided filter spec. - #[serde(rename = "raw")] - Raw(serde_json::Value), +/// A metadata query specifies a metadata filter in Chroma. +#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)] +pub enum WhereMixin { + /// A raw metadata query simply copies the provided filter spec. + #[serde(rename = "query")] + Constant(serde_json::Value), + /// The tiny stories workload. The way these collections were set up, there are three fields + /// each of integer, float, and string. The integer fields are named i1, i2, and i3. The + /// float fields are named f1, f2, and f3. The string fields are named s1, s2, and s3. + /// + /// This mixin selects one of these 6 numeric fields at random and picks a metadata range query + /// to perform on it that will return data according to the mixin. 
+ #[serde(rename = "tiny-stories")] + TinyStories(TinyStoriesMixin), + /// A weighted mix of mixins; each entry pairs an f64 weight with a mixin. + /// Converting to JSON selects one of the mixins at random according to the weights. + #[serde(rename = "select")] + Select(Vec<(f64, WhereMixin)>), } -impl DocumentQuery { - /// Convert the document query into a JSON value suitable for use in a Chroma query. - pub fn into_where_document(self, _: &mut Guacamole) -> serde_json::Value { +impl WhereMixin { + /// Convert the metadata query into a JSON value suitable for use in a Chroma query. + pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value { match self { - DocumentQuery::Raw(json) => json, + Self::Constant(query) => query.clone(), + Self::TinyStories(mixin) => mixin.to_json(guac), + Self::Select(select) => { + let scale: f64 = any(guac); + let mut total = scale * select.iter().map(|(p, _)| *p).sum::(); + for (p, mixin) in select { + if *p < 0.0 { + return serde_json::Value::Null; + } + if *p >= total { + return mixin.to_json(guac); + } + total -= *p; + } + serde_json::Value::Null + } } } } +impl Eq for WhereMixin {} + ///////////////////////////////////////////// GetQuery ///////////////////////////////////////////// /// A get query specifies a get operation in Chroma. 
@@ -318,9 +362,9 @@ pub struct GetQuery { pub skew: Skew, pub limit: Distribution, #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, + pub metadata: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub document: Option, + pub document: Option, } //////////////////////////////////////////// QueryQuery //////////////////////////////////////////// @@ -339,9 +383,9 @@ pub struct QueryQuery { pub skew: Skew, pub limit: Distribution, #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, + pub metadata: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub document: Option, + pub document: Option, } //////////////////////////////////////////// KeySelector /////////////////////////////////////////// @@ -1505,13 +1549,11 @@ mod tests { #[test] fn workload_save_restore() { const TEST_PATH: &str = "workload_save_restore.test.json"; - println!("FINDME {}:{}", file!(), line!()); std::fs::remove_file(TEST_PATH).ok(); // First verse. let mut load = LoadService::default(); load.set_persistent_path_and_load(Some(TEST_PATH.to_string())) .unwrap(); - println!("FINDME {}:{}", file!(), line!()); load.start( "foo".to_string(), "nop".to_string(), @@ -1520,7 +1562,6 @@ mod tests { Throughput::Constant(1.0), ) .unwrap(); - println!("FINDME {}:{}", file!(), line!()); let expected = { // SAFETY(rescrv): Mutex poisoning. let harness = load.harness.lock().unwrap(); @@ -1528,7 +1569,6 @@ mod tests { harness.running[0].clone() }; drop(load); - println!("FINDME {}:{}", file!(), line!()); println!("expected: {:?}", expected); // Second verse. 
let mut load = LoadService::default(); @@ -1537,7 +1577,6 @@ mod tests { let harness = load.harness.lock().unwrap(); assert!(harness.running.is_empty()); } - println!("FINDME {}:{}", file!(), line!()); load.set_persistent_path_and_load(Some(TEST_PATH.to_string())) .unwrap(); { diff --git a/rust/load/src/workloads.rs b/rust/load/src/workloads.rs index ec52567f260..f51c199bcc7 100644 --- a/rust/load/src/workloads.rs +++ b/rust/load/src/workloads.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use crate::{ - Distribution, DocumentQuery, GetQuery, KeySelector, MetadataQuery, QueryQuery, Skew, Workload, + Distribution, GetQuery, KeySelector, QueryQuery, Skew, TinyStoriesMixin, WhereMixin, Workload, }; /// Return a map of all pre-configured workloads. @@ -22,7 +22,9 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))), + document: Some(WhereMixin::Constant( + serde_json::json!({"$contains": "the"}), + )), }), ), ( @@ -30,7 +32,9 @@ pub fn all_workloads() -> HashMap { Workload::Get(GetQuery { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), - metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))), + metadata: Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric { + ratio_selected: 0.01, + })), document: None, }), ), @@ -52,7 +56,9 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))), + document: Some(WhereMixin::Constant( + serde_json::json!({"$contains": "the"}), + )), }), ), ( @@ -60,7 +66,9 @@ pub fn all_workloads() -> HashMap { Workload::Query(QueryQuery { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), - metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))), + metadata: 
Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric { + ratio_selected: 0.01, + })), document: None, }), ), @@ -75,7 +83,9 @@ pub fn all_workloads() -> HashMap { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), metadata: None, - document: Some(DocumentQuery::Raw(serde_json::json!({"$contains": "the"}))), + document: Some(WhereMixin::Constant( + serde_json::json!({"$contains": "the"}), + )), }), ), ( @@ -83,7 +93,9 @@ pub fn all_workloads() -> HashMap { Workload::Get(GetQuery { skew: Skew::Zipf { theta: 0.999 }, limit: Distribution::Constant(10), - metadata: Some(MetadataQuery::Raw(serde_json::json!({"i1": 1000}))), + metadata: Some(WhereMixin::TinyStories(TinyStoriesMixin::Numeric { + ratio_selected: 0.01, + })), document: None, }), ),