Skip to content

Commit

Permalink
K-nearest neighbor (KNN) vector searches support (#165)
Browse files Browse the repository at this point in the history
* Reorder fields

* Initial FirestoreVector type support and low level API implementation

* Example and tests

* Fluent API support and example update
  • Loading branch information
abdolence authored Apr 12, 2024
1 parent 6aeb7dc commit 71cbe43
Show file tree
Hide file tree
Showing 12 changed files with 690 additions and 56 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Library provides a simple API for Google Firestore based on the official gRPC AP
- Transactions;
- Aggregated Queries;
- Streaming batch writes with automatic throttling to avoid time limits from Firestore;
- K-nearest neighbor (KNN) vector search;
- Explaining queries;
- Fluent high-level and strongly typed API;
- Full async based on Tokio runtime;
Expand All @@ -36,7 +37,7 @@ Cargo.toml:

```toml
[dependencies]
firestore = "0.40"
firestore = "0.41"
```

## Examples
Expand Down
90 changes: 90 additions & 0 deletions examples/nearest-vector-query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use firestore::*;
use serde::{Deserialize, Serialize};

pub fn config_env_var(name: &str) -> Result<String, String> {
std::env::var(name).map_err(|e| format!("{}: {}", name, e))
}

// Example structure to play with
#[derive(Debug, Clone, Deserialize, Serialize)]
struct MyTestStructure {
some_id: String,
some_string: String,
some_vec: FirestoreVector,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Logging with debug enabled
let subscriber = tracing_subscriber::fmt()
.with_env_filter("firestore=debug")
.finish();
tracing::subscriber::set_global_default(subscriber)?;

// Create an instance
let db = FirestoreDb::new(&config_env_var("PROJECT_ID")?).await?;

const TEST_COLLECTION_NAME: &'static str = "test-query-vec";

if db
.fluent()
.select()
.by_id_in(TEST_COLLECTION_NAME)
.one("test-0")
.await?
.is_none()
{
println!("Populating a test collection");
let batch_writer = db.create_simple_batch_writer().await?;
let mut current_batch = batch_writer.new_batch();

for i in 0..500 {
let my_struct = MyTestStructure {
some_id: format!("test-{}", i),
some_string: "Test".to_string(),
some_vec: vec![i as f64, (i * 10) as f64, (i * 20) as f64].into(),
};

// Let's insert some data
db.fluent()
.update()
.in_col(TEST_COLLECTION_NAME)
.document_id(&my_struct.some_id)
.object(&my_struct)
.add_to_batch(&mut current_batch)?;
}
current_batch.write().await?;
}

println!("Show sample documents in the test collection");
let as_vec: Vec<MyTestStructure> = db
.fluent()
.select()
.from(TEST_COLLECTION_NAME)
.limit(3)
.obj()
.query()
.await?;

println!("Examples: {:?}", as_vec);

println!("Search for a test collection with a vector closest");

let as_vec: Vec<MyTestStructure> = db
.fluent()
.select()
.from(TEST_COLLECTION_NAME)
.find_nearest(
path!(MyTestStructure::some_vec),
vec![0.0_f64, 0.0_f64, 0.0_f64].into(),
FirestoreFindNearestDistanceMeasure::Euclidean,
5,
)
.obj()
.query()
.await?;

println!("Found: {:?}", as_vec);

Ok(())
}
12 changes: 12 additions & 0 deletions examples/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {

println!("Querying a test collection as a stream using Fluent API");

// Simple query into vector
// Query as a stream our data
let as_vec: Vec<MyTestStructure> = db
.fluent()
.select()
.from(TEST_COLLECTION_NAME)
.obj()
.query()
.await?;

println!("{:?}", as_vec);

// Query as a stream our data with filters and ordering
let object_stream: BoxStream<FirestoreResult<MyTestStructure>> = db
.fluent()
.select()
Expand Down
2 changes: 1 addition & 1 deletion src/db/aggregated_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ impl FirestoreDb {
query_type: Some(run_aggregation_query_request::QueryType::StructuredAggregationQuery(
StructuredAggregationQuery {
aggregations: params.aggregations.iter().map(|agg| agg.into()).collect(),
query_type: Some(gcloud_sdk::google::firestore::v1::structured_aggregation_query::QueryType::StructuredQuery(params.query_params.into())),
query_type: Some(gcloud_sdk::google::firestore::v1::structured_aggregation_query::QueryType::StructuredQuery(params.query_params.try_into()?)),
}
)),
explain_options: None,
Expand Down
2 changes: 1 addition & 1 deletion src/db/listen_changes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl FirestoreDb {
.unwrap_or_else(|| self.get_documents_path())
.clone(),
query_type: Some(target::query_target::QueryType::StructuredQuery(
query_params.into(),
query_params.try_into()?,
)),
})
}
Expand Down
94 changes: 53 additions & 41 deletions src/db/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ impl FirestoreDb {
.as_ref()
.map(|eo| eo.try_into())
.transpose()?,
query_type: Some(run_query_request::QueryType::StructuredQuery(params.into())),
query_type: Some(run_query_request::QueryType::StructuredQuery(
params.try_into()?,
)),
}))
}

Expand Down Expand Up @@ -415,49 +417,59 @@ impl FirestoreQuerySupport for FirestoreDb {
Some((params, consistency_selector)),
move |maybe_params| async move {
if let Some((params, maybe_consistency_selector)) = maybe_params {
let request = gcloud_sdk::tonic::Request::new(PartitionQueryRequest {
page_size: params.page_size as i32,
partition_count: params.partition_count as i64,
parent: params
.query_params
.parent
.as_ref()
.unwrap_or_else(|| self.get_documents_path())
.clone(),
consistency_selector: maybe_consistency_selector.clone(),
query_type: Some(
partition_query_request::QueryType::StructuredQuery(
params.query_params.clone().into(),
),
),
page_token: params.page_token.clone().unwrap_or_default(),
});

match self.client().get().partition_query(request).await {
Ok(response) => {
let partition_response = response.into_inner();
let firestore_cursors: Vec<FirestoreQueryCursor> =
partition_response
.partitions
.into_iter()
.map(|e| e.into())
.collect();

if !partition_response.next_page_token.is_empty() {
Some((
Ok(firestore_cursors),
Some((
params.with_page_token(
partition_response.next_page_token,
match params.query_params.clone().try_into() {
Ok(query_params) => {
let request =
gcloud_sdk::tonic::Request::new(PartitionQueryRequest {
page_size: params.page_size as i32,
partition_count: params.partition_count as i64,
parent: params
.query_params
.parent
.as_ref()
.unwrap_or_else(|| self.get_documents_path())
.clone(),
consistency_selector: maybe_consistency_selector
.clone(),
query_type: Some(
partition_query_request::QueryType::StructuredQuery(
query_params,
),
maybe_consistency_selector,
)),
))
} else {
Some((Ok(firestore_cursors), None))
),
page_token: params
.page_token
.clone()
.unwrap_or_default(),
});

match self.client().get().partition_query(request).await {
Ok(response) => {
let partition_response = response.into_inner();
let firestore_cursors: Vec<FirestoreQueryCursor> =
partition_response
.partitions
.into_iter()
.map(|e| e.into())
.collect();

if !partition_response.next_page_token.is_empty() {
Some((
Ok(firestore_cursors),
Some((
params.with_page_token(
partition_response.next_page_token,
),
maybe_consistency_selector,
)),
))
} else {
Some((Ok(firestore_cursors), None))
}
}
Err(err) => Some((Err(FirestoreError::from(err)), None)),
}
}
Err(err) => Some((Err(FirestoreError::from(err)), None)),
Err(err) => Some((Err(err), None)),
}
} else {
None
Expand Down
87 changes: 80 additions & 7 deletions src/db/query_models.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#![allow(clippy::derive_partial_eq_without_eq)] // Since we may not be able to implement Eq for the changes coming from Firestore protos

use crate::errors::FirestoreError;
use crate::FirestoreValue;
use crate::errors::{
FirestoreError, FirestoreInvalidParametersError, FirestoreInvalidParametersPublicDetails,
};
use crate::{FirestoreValue, FirestoreVector};
use gcloud_sdk::google::firestore::v1::*;
use rsb_derive::Builder;

Expand Down Expand Up @@ -39,13 +41,16 @@ pub struct FirestoreQueryParams {
pub start_at: Option<FirestoreQueryCursor>,
pub end_at: Option<FirestoreQueryCursor>,
pub explain_options: Option<FirestoreExplainOptions>,
pub find_nearest: Option<FirestoreFindNearestOptions>,
}

impl From<FirestoreQueryParams> for StructuredQuery {
fn from(params: FirestoreQueryParams) -> Self {
impl TryFrom<FirestoreQueryParams> for StructuredQuery {
type Error = FirestoreError;

fn try_from(params: FirestoreQueryParams) -> Result<Self, Self::Error> {
let query_filter = params.filter.map(|f| f.into());

StructuredQuery {
Ok(StructuredQuery {
select: params.return_only_fields.map(|select_only_fields| {
structured_query::Projection {
fields: select_only_fields
Expand Down Expand Up @@ -79,9 +84,12 @@ impl From<FirestoreQueryParams> for StructuredQuery {
})
.collect(),
},
find_nearest: params
.find_nearest
.map(|find_nearest| find_nearest.try_into())
.transpose()?,
r#where: query_filter,
find_nearest: None,
}
})
}
}

Expand Down Expand Up @@ -425,3 +433,68 @@ impl TryFrom<&FirestoreExplainOptions> for gcloud_sdk::google::firestore::v1::Ex
})
}
}

#[derive(Debug, PartialEq, Clone, Builder)]
pub struct FirestoreFindNearestOptions {
pub field_name: String,
pub query_vector: FirestoreVector,
pub distance_measure: FirestoreFindNearestDistanceMeasure,
pub neighbors_limit: u32,
}

impl TryFrom<FirestoreFindNearestOptions>
for gcloud_sdk::google::firestore::v1::structured_query::FindNearest
{
type Error = FirestoreError;

fn try_from(options: FirestoreFindNearestOptions) -> Result<Self, Self::Error> {
Ok(structured_query::FindNearest {
vector_field: Some(structured_query::FieldReference {
field_path: options.field_name,
}),
query_vector: Some(Into::<FirestoreValue>::into(options.query_vector).value),
distance_measure: {
let distance_measure: structured_query::find_nearest::DistanceMeasure = options.distance_measure.try_into()?;
distance_measure.into()
},
limit: Some(options.neighbors_limit.try_into().map_err(|e| FirestoreError::InvalidParametersError(
FirestoreInvalidParametersError::new(FirestoreInvalidParametersPublicDetails::new(
"neighbors_limit".to_string(),
format!(
"Invalid value for neighbors_limit: {}. Maximum allowed value is {}. Error: {}",
options.neighbors_limit,
i32::MAX,
e
),
)))
)?),
})
}
}

#[derive(Debug, PartialEq, Clone)]
pub enum FirestoreFindNearestDistanceMeasure {
Euclidean,
Cosine,
DotProduct,
}

impl TryFrom<FirestoreFindNearestDistanceMeasure>
for structured_query::find_nearest::DistanceMeasure
{
type Error = FirestoreError;

fn try_from(measure: FirestoreFindNearestDistanceMeasure) -> Result<Self, Self::Error> {
match measure {
FirestoreFindNearestDistanceMeasure::Euclidean => {
Ok(structured_query::find_nearest::DistanceMeasure::Euclidean)
}
FirestoreFindNearestDistanceMeasure::Cosine => {
Ok(structured_query::find_nearest::DistanceMeasure::Cosine)
}
FirestoreFindNearestDistanceMeasure::DotProduct => {
Ok(structured_query::find_nearest::DistanceMeasure::DotProduct)
}
}
}
}
8 changes: 8 additions & 0 deletions src/firestore_serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@ mod deserializer;
mod serializer;

mod timestamp_serializers;

pub use timestamp_serializers::*;

mod null_serializers;

pub use null_serializers::*;

mod latlng_serializers;

pub use latlng_serializers::*;

mod reference_serializers;

pub use reference_serializers::*;

mod vector_serializers;

pub use vector_serializers::*;

use crate::FirestoreValue;
use gcloud_sdk::google::firestore::v1::Value;

Expand Down
Loading

0 comments on commit 71cbe43

Please sign in to comment.