Skip to content

Commit

Permalink
[ENH] Fetch posting list operator (#3214)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - Operator for fetching the posting list for a given head. Currently clones the posting list. Can revisit in future if it ends up becoming a performance bottleneck.
 - New functionality
   - ...

## Test plan
*How are these changes tested?*
- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
None
  • Loading branch information
sanketkedia authored Dec 5, 2024
1 parent adfd60f commit 37a6103
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 1 deletion.
54 changes: 54 additions & 0 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,8 @@ pub enum SpannIndexReaderError {
BlockfileReaderConstructionError,
#[error("Spann index uninitialized")]
UninitializedIndex,
#[error("Error reading posting list")]
PostingListReadError,
}

impl ChromaError for SpannIndexReaderError {
Expand All @@ -1392,15 +1394,23 @@ impl ChromaError for SpannIndexReaderError {
Self::HnswIndexConstructionError => ErrorCodes::Internal,
Self::BlockfileReaderConstructionError => ErrorCodes::Internal,
Self::UninitializedIndex => ErrorCodes::Internal,
Self::PostingListReadError => ErrorCodes::Internal,
}
}
}

#[derive(Debug)]
pub struct SpannPosting {
pub doc_offset_id: u32,
pub doc_embedding: Vec<f32>,
}

#[derive(Clone)]
pub struct SpannIndexReader<'me> {
pub posting_lists: BlockfileReader<'me, u32, SpannPostingList<'me>>,
pub hnsw_index: HnswIndexRef,
pub versions_map: BlockfileReader<'me, u32, u32>,
pub dimensionality: usize,
}

impl<'me> SpannIndexReader<'me> {
Expand Down Expand Up @@ -1490,8 +1500,52 @@ impl<'me> SpannIndexReader<'me> {
posting_lists: postings_list_reader,
hnsw_index: hnsw_reader,
versions_map: versions_map_reader,
dimensionality,
})
}

async fn is_outdated(
&self,
doc_offset_id: u32,
doc_version: u32,
) -> Result<bool, SpannIndexReaderError> {
let actual_version = self
.versions_map
.get("", doc_offset_id)
.await
.map_err(|_| SpannIndexReaderError::PostingListReadError)?
.ok_or(SpannIndexReaderError::PostingListReadError)?;
Ok(actual_version == 0 || doc_version < actual_version)
}

pub async fn fetch_posting_list(
&self,
head_id: u32,
) -> Result<Vec<SpannPosting>, SpannIndexReaderError> {
let res = self
.posting_lists
.get("", head_id)
.await
.map_err(|_| SpannIndexReaderError::PostingListReadError)?
.ok_or(SpannIndexReaderError::PostingListReadError)?;

let mut posting_lists = Vec::with_capacity(res.doc_offset_ids.len());
for (index, doc_offset_id) in res.doc_offset_ids.iter().enumerate() {
if self
.is_outdated(*doc_offset_id, res.doc_versions[index])
.await?
{
continue;
}
posting_lists.push(SpannPosting {
doc_offset_id: *doc_offset_id,
doc_embedding: res.doc_embeddings
[index * self.dimensionality..(index + 1) * self.dimensionality]
.to_vec(),
});
}
Ok(posting_lists)
}
}

#[cfg(test)]
Expand Down
1 change: 1 addition & 0 deletions rust/worker/src/execution/operators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub(super) mod pull_log;
pub(super) mod record_segment_prefetch;
pub(super) mod register;
pub(super) mod spann_centers_search;
pub(super) mod spann_fetch_pl;
pub(super) mod write_segments;

// Required for benchmark
Expand Down
78 changes: 78 additions & 0 deletions rust/worker/src/execution/operators/spann_fetch_pl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::spann::types::SpannPosting;
use thiserror::Error;
use tonic::async_trait;

use crate::{
execution::operator::{Operator, OperatorType},
segment::spann_segment::{SpannSegmentReader, SpannSegmentReaderContext},
};

#[derive(Debug)]
pub struct SpannFetchPlInput {
// TODO(Sanket): Ship the reader instead of constructing here.
reader_context: SpannSegmentReaderContext,
head_id: u32,
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct SpannFetchPlOutput {
posting_list: Vec<SpannPosting>,
}

#[derive(Error, Debug)]
pub enum SpannFetchPlError {
#[error("Error creating spann segment reader")]
SpannSegmentReaderCreationError,
#[error("Error querying reader")]
SpannSegmentReaderError,
}

impl ChromaError for SpannFetchPlError {
fn code(&self) -> ErrorCodes {
match self {
Self::SpannSegmentReaderCreationError => ErrorCodes::Internal,
Self::SpannSegmentReaderError => ErrorCodes::Internal,
}
}
}

#[derive(Debug)]
pub struct SpannFetchPlOperator {}

impl SpannFetchPlOperator {
#[allow(dead_code)]
pub fn new() -> Box<Self> {
Box::new(SpannFetchPlOperator {})
}
}

#[async_trait]
impl Operator<SpannFetchPlInput, SpannFetchPlOutput> for SpannFetchPlOperator {
type Error = SpannFetchPlError;

async fn run(
&self,
input: &SpannFetchPlInput,
) -> Result<SpannFetchPlOutput, SpannFetchPlError> {
let spann_reader = SpannSegmentReader::from_segment(
&input.reader_context.segment,
&input.reader_context.blockfile_provider,
&input.reader_context.hnsw_provider,
input.reader_context.dimension,
)
.await
.map_err(|_| SpannFetchPlError::SpannSegmentReaderCreationError)?;
let posting_list = spann_reader
.fetch_posting_list(input.head_id)
.await
.map_err(|_| SpannFetchPlError::SpannSegmentReaderError)?;
Ok(SpannFetchPlOutput { posting_list })
}

// This operator is IO bound.
fn get_type(&self) -> OperatorType {
OperatorType::IO
}
}
15 changes: 14 additions & 1 deletion rust/worker/src/segment/spann_segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use chroma_blockstore::provider::BlockfileProvider;
use chroma_distance::DistanceFunctionError;
use chroma_error::{ChromaError, ErrorCodes};
use chroma_index::spann::types::{
SpannIndexFlusher, SpannIndexReader, SpannIndexReaderError, SpannIndexWriterError,
SpannIndexFlusher, SpannIndexReader, SpannIndexReaderError, SpannIndexWriterError, SpannPosting,
};
use chroma_index::IndexUuid;
use chroma_index::{hnsw_provider::HnswIndexProvider, spann::types::SpannIndexWriter};
Expand Down Expand Up @@ -327,6 +327,8 @@ pub enum SpannSegmentReaderError {
SpannSegmentReaderCreateError,
#[error("Spann segment is uninitialized")]
UninitializedSegment,
#[error("Error reading key")]
KeyReadError,
}

impl ChromaError for SpannSegmentReaderError {
Expand All @@ -340,6 +342,7 @@ impl ChromaError for SpannSegmentReaderError {
Self::PostingListInvalidFilePath => ErrorCodes::Internal,
Self::SpannSegmentReaderCreateError => ErrorCodes::Internal,
Self::UninitializedSegment => ErrorCodes::Internal,
Self::KeyReadError => ErrorCodes::Internal,
}
}
}
Expand Down Expand Up @@ -456,6 +459,16 @@ impl<'me> SpannSegmentReader<'me> {
id: segment.id,
})
}

pub async fn fetch_posting_list(
&self,
head_id: u32,
) -> Result<Vec<SpannPosting>, SpannSegmentReaderError> {
self.index_reader
.fetch_posting_list(head_id)
.await
.map_err(|_| SpannSegmentReaderError::KeyReadError)
}
}

#[cfg(test)]
Expand Down

0 comments on commit 37a6103

Please sign in to comment.