Skip to content

Commit

Permalink
rng_query_operator
Browse files Browse the repository at this point in the history
  • Loading branch information
sanketkedia committed Nov 28, 2024
1 parent 9de114c commit d1eec30
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 60 deletions.
70 changes: 11 additions & 59 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
Index, IndexUuid,
};

use super::utils::KMeansAlgorithm;
use super::utils::{rng_query, KMeansAlgorithm};

pub struct VersionsMapInner {
pub versions_map: HashMap<u32, u32>,
Expand Down Expand Up @@ -327,64 +327,16 @@ impl SpannIndexWriter {
if self.distance_function == DistanceFunction::Cosine {
normalized_query = normalize(query)
}
let mut nearby_ids: Vec<usize> = vec![];
let mut nearby_distances: Vec<f32> = vec![];
let mut embeddings: Vec<Vec<f32>> = vec![];
{
let read_guard = self.hnsw_index.inner.read();
let allowed_ids = vec![];
let disallowed_ids = vec![];
let (ids, distances) = read_guard
.query(
&normalized_query,
NUM_CENTROIDS_TO_SEARCH as usize,
&allowed_ids,
&disallowed_ids,
)
.map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?;
for (id, distance) in ids.iter().zip(distances.iter()) {
if *distance <= (1_f32 + QUERY_EPSILON) * distances[0] {
nearby_ids.push(*id);
nearby_distances.push(*distance);
}
}
// Get the embeddings also for distance computation.
for id in nearby_ids.iter() {
let emb = read_guard
.get(*id)
.map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)?
.ok_or(SpannIndexWriterConstructionError::HnswIndexSearchError)?;
embeddings.push(emb);
}
}
// Apply the RNG rule to prune.
let mut res_ids = vec![];
let mut res_distances = vec![];
let mut res_embeddings: Vec<Vec<f32>> = vec![];
// Embeddings that were obtained are already normalized.
for (id, (distance, embedding)) in nearby_ids
.iter()
.zip(nearby_distances.iter().zip(embeddings))
{
let mut rng_accepted = true;
for nbr_embedding in res_embeddings.iter() {
let dist = self
.distance_function
.distance(&embedding[..], &nbr_embedding[..]);
if RNG_FACTOR * dist <= *distance {
rng_accepted = false;
break;
}
}
if !rng_accepted {
continue;
}
res_ids.push(*id);
res_distances.push(*distance);
res_embeddings.push(embedding);
}

Ok((res_ids, res_distances, res_embeddings))
rng_query(
&normalized_query,
self.hnsw_index.clone(),
NUM_CENTROIDS_TO_SEARCH as usize,
QUERY_EPSILON,
RNG_FACTOR,
self.distance_function.clone(),
)
.await
.map_err(|_| SpannIndexWriterConstructionError::HnswIndexSearchError)
}

async fn is_outdated(
Expand Down
80 changes: 80 additions & 0 deletions rust/index/src/spann/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::{cmp::min, collections::HashMap};

use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use rand::{seq::SliceRandom, Rng};
use thiserror::Error;

use crate::{hnsw_provider::HnswIndexRef, Index};

// TODO(Sanket): I don't understand why the reference implementation defined
// max_distance this way.
Expand Down Expand Up @@ -428,6 +432,82 @@ impl<'referred_data> KMeansAlgorithm<'referred_data> {
}
}

#[derive(Error, Debug)]
pub enum RngQueryError {
#[error("Error searching Hnsw graph")]
HnswSearchError,
}

impl ChromaError for RngQueryError {
fn code(&self) -> ErrorCodes {
match self {
Self::HnswSearchError => ErrorCodes::Internal,
}
}
}

// Assumes that query is already normalized.
pub async fn rng_query(
normalized_query: &[f32],
hnsw_index: HnswIndexRef,
k: usize,
rng_epsilon: f32,
rng_factor: f32,
distance_function: DistanceFunction,
) -> Result<(Vec<usize>, Vec<f32>, Vec<Vec<f32>>), RngQueryError> {
let mut nearby_ids: Vec<usize> = vec![];
let mut nearby_distances: Vec<f32> = vec![];
let mut embeddings: Vec<Vec<f32>> = vec![];
{
let read_guard = hnsw_index.inner.read();
let allowed_ids = vec![];
let disallowed_ids = vec![];
let (ids, distances) = read_guard
.query(normalized_query, k, &allowed_ids, &disallowed_ids)
.map_err(|_| RngQueryError::HnswSearchError)?;
for (id, distance) in ids.iter().zip(distances.iter()) {
if *distance <= (1_f32 + rng_epsilon) * distances[0] {
nearby_ids.push(*id);
nearby_distances.push(*distance);
}
}
// Get the embeddings also for distance computation.
for id in nearby_ids.iter() {
let emb = read_guard
.get(*id)
.map_err(|_| RngQueryError::HnswSearchError)?
.ok_or(RngQueryError::HnswSearchError)?;
embeddings.push(emb);
}
}
// Apply the RNG rule to prune.
let mut res_ids = vec![];
let mut res_distances = vec![];
let mut res_embeddings: Vec<Vec<f32>> = vec![];
// Embeddings that were obtained are already normalized.
for (id, (distance, embedding)) in nearby_ids
.iter()
.zip(nearby_distances.iter().zip(embeddings))
{
let mut rng_accepted = true;
for nbr_embedding in res_embeddings.iter() {
let dist = distance_function.distance(&embedding[..], &nbr_embedding[..]);
if rng_factor * dist <= *distance {
rng_accepted = false;
break;
}
}
if !rng_accepted {
continue;
}
res_ids.push(*id);
res_distances.push(*distance);
res_embeddings.push(embedding);
}

Ok((res_ids, res_distances, res_embeddings))
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub(super) mod partition;
pub(super) mod pull_log;
pub(super) mod record_segment_prefetch;
pub(super) mod register;
pub(super) mod spann_centers_search;
pub(super) mod write_segments;

// Required for benchmark
Expand Down
83 changes: 83 additions & 0 deletions rust/worker/src/execution/operators/spann_centers_search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use chroma_distance::DistanceFunction;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::spann::utils::rng_query;
use thiserror::Error;
use tonic::async_trait;

use crate::{
execution::operator::Operator,
segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext},
};

#[derive(Debug)]
pub struct SpannCentersSearchInput {
reader_context: SpannSegmentReaderContext,
// Assumes that query is already normalized in case of cosine.
query: Vec<f32>,
k: usize,
rng_epsilon: f32,
rng_factor: f32,
distance_function: DistanceFunction,
}

#[derive(Debug)]
pub struct SpannCentersSearchOutput {
center_ids: Vec<usize>,
}

#[derive(Error, Debug)]
pub enum SpannCentersSearchError {
#[error("Error creating spann segment reader")]
SpannSegmentReaderCreationError,
#[error("Error querying RNG")]
RngQueryError,
}

impl ChromaError for SpannCentersSearchError {
fn code(&self) -> ErrorCodes {
match self {
Self::SpannSegmentReaderCreationError => ErrorCodes::Internal,
Self::RngQueryError => ErrorCodes::Internal,
}
}
}

#[derive(Debug)]
pub struct SpannCentersSearchOperator {}

impl SpannCentersSearchOperator {
pub fn new() -> Box<Self> {
Box::new(SpannCentersSearchOperator {})
}
}

#[async_trait]
impl Operator<SpannCentersSearchInput, SpannCentersSearchOutput> for SpannCentersSearchOperator {
type Error = SpannCentersSearchError;

async fn run(
&self,
input: &SpannCentersSearchInput,
) -> Result<SpannCentersSearchOutput, SpannCentersSearchError> {
let spann_reader = SpannSegmentReader::from_segment(
&input.reader_context.segment,
&input.reader_context.blockfile_provider,
&input.reader_context.hnsw_provider,
input.reader_context.dimension,
)
.await
.map_err(|_| SpannCentersSearchError::SpannSegmentReaderCreationError)?;
// RNG Query.
let res = rng_query(
&input.query,
spann_reader.index_reader.hnsw_index.clone(),
input.k,
input.rng_epsilon,
input.rng_factor,
input.distance_function.clone(),
)
.await
.map_err(|_| SpannCentersSearchError::RngQueryError)?;
Ok(SpannCentersSearchOutput { center_ids: res.0 })
}
}
10 changes: 9 additions & 1 deletion rust/worker/src/segment/spann_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,17 @@ impl ChromaError for SpannSegmentReaderError {
}
}

#[derive(Debug)]
pub struct SpannSegmentReaderContext {
pub segment: Segment,
pub blockfile_provider: BlockfileProvider,
pub hnsw_provider: HnswIndexProvider,
pub dimension: usize,
}

#[derive(Clone)]
pub(crate) struct SpannSegmentReader<'me> {
index_reader: SpannIndexReader<'me>,
pub index_reader: SpannIndexReader<'me>,
id: SegmentUuid,
}

Expand Down

0 comments on commit d1eec30

Please sign in to comment.