[ENH] Blockstore-based full-text search (#1759)
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
	 - Full-text search built on the blockstore (see the sketch below)
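
A quick sketch of the core matching rule (illustrative only; none of these names are from this diff): each n-gram maps to the byte offsets at which it occurs in each document, and a query matches a document at start `s` iff every query token at in-query offset `o` occurs at `s + o`.

```rust
use std::collections::{HashMap, HashSet};

/// Token text -> (doc id -> byte offsets at which the token occurs).
type Postings = HashMap<String, HashMap<i32, HashSet<i32>>>;

/// True iff every query token (text, offset-in-query) occurs in `doc_id`
/// at `start + offset`, i.e. the query matches starting at byte `start`.
fn matches_at(postings: &Postings, query: &[(&str, i32)], doc_id: i32, start: i32) -> bool {
    query.iter().all(|(token, offset)| {
        postings
            .get(*token)
            .and_then(|docs| docs.get(&doc_id))
            .map_or(false, |positions| positions.contains(&(start + offset)))
    })
}
```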

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
beggers authored Feb 23, 2024
1 parent 248796d commit 61da5f4
Showing 8 changed files with 1,324 additions and 458 deletions.
1,406 changes: 948 additions & 458 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions rust/worker/Cargo.toml
@@ -37,6 +37,7 @@ aws-smithy-types = "1.1.0"
aws-config = { version = "1.1.2", features = ["behavior-version-latest"] }
arrow = "50.0.0"
roaring = "0.10.3"
tantivy = "0.21.1"

[build-dependencies]
tonic-build = "0.10"
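tantivy is pulled in only for its tokenizers; the index itself stays on the blockstore. A minimal sketch (not part of this diff) of the `NgramTokenizer` API at 0.21, as the code below uses it:

```rust
use tantivy::tokenizer::{NgramTokenizer, TokenStream, Tokenizer};

fn main() {
    // (1, 1, false): unigrams over the whole string, not just the prefix.
    let mut tokenizer = NgramTokenizer::new(1, 1, false).unwrap();
    let mut stream = tokenizer.token_stream("hi");
    while stream.advance() {
        let token = stream.token();
        // Prints "h" 0..1, then "i" 1..2.
        println!("{:?} {}..{}", token.text, token.offset_from, token.offset_to);
    }
}
```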
4 changes: 4 additions & 0 deletions rust/worker/src/blockstore/positional_posting_list_value.rs
@@ -79,6 +79,10 @@ impl PositionalPostingListBuilder {
self.positions.insert(doc_id, positions);
Ok(())
}

pub(crate) fn contains_doc_id(&self, doc_id: i32) -> bool {
self.doc_ids.contains(&doc_id)
}

pub(crate) fn add_positions_for_doc_id(
&mut self,
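The new `contains_doc_id` check lets a caller branch between creating a doc's position list and appending to it, which is how `add_document` below uses it. A sketch of that pattern (assuming, as the hunk above suggests, that both add methods return `Result`):

```rust
fn record_position(builder: &mut PositionalPostingListBuilder, doc_id: i32, position: i32) {
    if !builder.contains_doc_id(doc_id) {
        // First time we see this doc id: create its positions list.
        builder.add_doc_id_and_positions(doc_id, vec![position]).unwrap();
    } else {
        // Doc id already present: append the new position.
        builder.add_positions_for_doc_id(doc_id, vec![position]).unwrap();
    }
}
```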
1 change: 1 addition & 0 deletions rust/worker/src/blockstore/types.rs
@@ -121,6 +121,7 @@ impl Ord for BlockfileKey

#[derive(Debug, Clone)]
pub(crate) enum Value {
Int32Value(i32),
Int32ArrayValue(Int32Array),
PositionalPostingListValue(PositionalPostingList),
StringValue(String),
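The new `Int32Value` variant is what the frequencies blockfile stores per term. A sketch of a round trip, assuming the `HashMapBlockfile` API used by the tests in this diff:

```rust
fn frequency_round_trip() {
    let mut freqs = HashMapBlockfile::open(&"in-memory-freqs").unwrap();
    freqs.begin_transaction().unwrap();
    freqs.set(
        BlockfileKey::new("".to_string(), Key::String("hello".to_string())),
        Value::Int32Value(3),
    );
    freqs.commit_transaction().unwrap();
    let key = BlockfileKey::new("".to_string(), Key::String("hello".to_string()));
    match freqs.get(key) {
        Ok(Value::Int32Value(frequency)) => assert_eq!(frequency, 3),
        _ => panic!("expected an Int32Value"),
    }
}
```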
4 changes: 4 additions & 0 deletions rust/worker/src/index/fulltext/mod.rs
@@ -0,0 +1,4 @@
pub mod tokenizer;
mod types;

pub use types::*;
88 changes: 88 additions & 0 deletions rust/worker/src/index/fulltext/tokenizer.rs
@@ -0,0 +1,88 @@
use crate::errors::{ChromaError, ErrorCodes};

use tantivy::tokenizer::{NgramTokenizer, Token, Tokenizer, TokenStream};

pub(crate) trait ChromaTokenStream {
fn process(&mut self, sink: &mut dyn FnMut(&Token));
fn get_tokens(&self) -> &Vec<Token>;
}

pub(crate) struct TantivyChromaTokenStream {
tokens: Vec<Token>
}

impl TantivyChromaTokenStream {
pub fn new(tokens: Vec<Token>) -> Self {
TantivyChromaTokenStream {
tokens,
}
}
}

impl ChromaTokenStream for TantivyChromaTokenStream {
fn process(&mut self, sink: &mut dyn FnMut(&Token)) {
for token in &self.tokens {
sink(token);
}
}

fn get_tokens(&self) -> &Vec<Token> {
&self.tokens
}
}

pub(crate) trait ChromaTokenizer {
fn encode(&mut self, text: &str) -> Box<dyn ChromaTokenStream>;
}

pub(crate) struct TantivyChromaTokenizer {
tokenizer: Box<NgramTokenizer>
}

impl TantivyChromaTokenizer {
pub fn new(tokenizer: Box<NgramTokenizer>) -> Self {
TantivyChromaTokenizer {
tokenizer,
}
}
}

impl ChromaTokenizer for TantivyChromaTokenizer {
fn encode(&mut self, text: &str) -> Box<dyn ChromaTokenStream> {
let mut token_stream = self.tokenizer.token_stream(text);
let mut tokens = Vec::new();
token_stream.process(&mut |token| {
tokens.push(token.clone());
});
Box::new(TantivyChromaTokenStream::new(tokens))
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_chroma_tokenizer() {
let tokenizer: Box<NgramTokenizer> = Box::new(NgramTokenizer::new(1, 1, false).unwrap());
let mut chroma_tokenizer = TantivyChromaTokenizer::new(tokenizer);
let mut token_stream = chroma_tokenizer.encode("hello world");
let mut tokens = Vec::new();
token_stream.process(&mut |token| {
tokens.push(token.clone());
});
assert_eq!(tokens.len(), 11);
assert_eq!(tokens[0].text, "h");
assert_eq!(tokens[1].text, "e");
}

#[test]
fn test_get_tokens() {
let tokenizer: Box<NgramTokenizer> = Box::new(NgramTokenizer::new(1, 1, false).unwrap());
let mut chroma_tokenizer = TantivyChromaTokenizer::new(tokenizer);
let token_stream = chroma_tokenizer.encode("hello world");
let tokens = token_stream.get_tokens();
assert_eq!(tokens.len(), 11);
assert_eq!(tokens[0].text, "h");
assert_eq!(tokens[1].text, "e");
}
}
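`TantivyChromaTokenStream` eagerly materializes the token vector so `get_tokens` can hand back a plain `&Vec<Token>`. The byte offsets carried on each `Token` are what the full-text index below keys on; a small usage sketch:

```rust
let mut chroma_tokenizer = TantivyChromaTokenizer::new(
    Box::new(NgramTokenizer::new(1, 1, false).unwrap()),
);
let stream = chroma_tokenizer.encode("ab");
for token in stream.get_tokens() {
    // "a" at bytes 0..1, then "b" at bytes 1..2.
    println!("{} {}..{}", token.text, token.offset_from, token.offset_to);
}
```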
276 changes: 276 additions & 0 deletions rust/worker/src/index/fulltext/types.rs
@@ -0,0 +1,276 @@
use crate::errors::{ChromaError, ErrorCodes};
use thiserror::Error;

use std::collections::HashMap;
use crate::blockstore::{Blockfile, BlockfileKey, Key, PositionalPostingListBuilder, Value};
use crate::index::fulltext::tokenizer::{ChromaTokenizer, ChromaTokenStream};

#[derive(Error, Debug)]
pub enum FullTextIndexError {
#[error("Already in a transaction")]
AlreadyInTransaction,
#[error("Not in a transaction")]
NotInTransaction,
}

impl ChromaError for FullTextIndexError {
fn code(&self) -> ErrorCodes {
match self {
FullTextIndexError::AlreadyInTransaction => ErrorCodes::FailedPrecondition,
FullTextIndexError::NotInTransaction => ErrorCodes::FailedPrecondition,
}
}
}

pub(crate) trait FullTextIndex {
fn begin_transaction(&mut self) -> Result<(), Box<dyn ChromaError>>;
fn commit_transaction(&mut self) -> Result<(), Box<dyn ChromaError>>;

// Must be done inside a transaction.
fn add_document(&mut self, document: &str, offset_id: i32) -> Result<(), Box<dyn ChromaError>>;
// Only searches committed state.
fn search(&mut self, query: &str) -> Result<Vec<i32>, Box<dyn ChromaError>>;
}
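// Typical use of the trait above (a sketch mirroring the tests below):
// begin_transaction(), add_document() once per document with its offset id,
// commit_transaction(), then search(), which returns matching offset ids.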

pub(crate) struct BlockfileFullTextIndex {
posting_lists_blockfile: Box<dyn Blockfile>,
frequencies_blockfile: Box<dyn Blockfile>,
tokenizer: Box<dyn ChromaTokenizer>,
in_transaction: bool,

// term -> positional posting list builder for that term
uncommitted: HashMap<String, PositionalPostingListBuilder>,
uncommitted_frequencies: HashMap<String, i32>,
}

impl BlockfileFullTextIndex {
pub(crate) fn new(posting_lists_blockfile: Box<dyn Blockfile>, frequencies_blockfile: Box<dyn Blockfile>, tokenizer: Box<dyn ChromaTokenizer>) -> Self {
BlockfileFullTextIndex {
posting_lists_blockfile,
frequencies_blockfile,
tokenizer,
in_transaction: false,
uncommitted: HashMap::new(),
uncommitted_frequencies: HashMap::new(),
}
}
}

impl FullTextIndex for BlockfileFullTextIndex {
fn begin_transaction(&mut self) -> Result<(), Box<dyn ChromaError>> {
if self.in_transaction {
return Err(Box::new(FullTextIndexError::AlreadyInTransaction));
}
self.posting_lists_blockfile.begin_transaction()?;
self.frequencies_blockfile.begin_transaction()?;
self.in_transaction = true;
Ok(())
}

fn commit_transaction(&mut self) -> Result<(), Box<dyn ChromaError>> {
if !self.in_transaction {
return Err(Box::new(FullTextIndexError::NotInTransaction));
}
self.in_transaction = false;
for (key, mut value) in self.uncommitted.drain() {
let positional_posting_list = value.build();
let blockfilekey = BlockfileKey::new("".to_string(), Key::String(key.to_string()));
self.posting_lists_blockfile.set(blockfilekey, Value::PositionalPostingListValue(positional_posting_list));
}
for (key, value) in self.uncommitted_frequencies.drain() {
let blockfilekey = BlockfileKey::new("".to_string(), Key::String(key.to_string()));
self.frequencies_blockfile.set(blockfilekey, Value::Int32Value(value));
}
self.posting_lists_blockfile.commit_transaction()?;
self.frequencies_blockfile.commit_transaction()?;
self.uncommitted.clear();
Ok(())
}

fn add_document(&mut self, document: &str, offset_id: i32) -> Result<(), Box<dyn ChromaError>> {
if !self.in_transaction {
return Err(Box::new(FullTextIndexError::NotInTransaction));
}
let tokens = self.tokenizer.encode(document);
for token in tokens.get_tokens() {
self.uncommitted_frequencies.entry(token.text.to_string()).and_modify(|e| *e += 1).or_insert(1);
let builder = self.uncommitted.entry(token.text.to_string()).or_insert(PositionalPostingListBuilder::new());

// Store starting positions of tokens. These are NOT affected by token filters.
// For search, we can use the start and end positions to compute offsets to
// check full string match.
//
// See https://docs.rs/tantivy/latest/tantivy/tokenizer/struct.Token.html
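//
// Worked example with the 1-gram tokenizer used in the tests: indexing
// "hello" stores 'l' at positions [2, 3] and 'o' at [4]. A search for "lo"
// seeds candidate starts from 'l' ({2, 3}) and keeps a start s only if 'o'
// occurs at s + 1; only s = 3 survives, matching "lo" at byte 3.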
if !builder.contains_doc_id(offset_id) {
// Casting to i32 is safe since we limit the size of the document.
builder.add_doc_id_and_positions(offset_id, vec![token.offset_from as i32]);
} else {
builder.add_positions_for_doc_id(offset_id, vec![token.offset_from as i32]);
}
}
Ok(())
}

fn search(&mut self, query: &str) -> Result<Vec<i32>, Box<dyn ChromaError>> {
let binding = self.tokenizer.encode(query);
let tokens = binding.get_tokens();

// Get query tokens sorted by ascending frequency (rarest first).
let mut token_frequencies = vec![];
for token in tokens {
let blockfilekey = BlockfileKey::new("".to_string(), Key::String(token.text.to_string()));
let value = self.frequencies_blockfile.get(blockfilekey);
match value {
Ok(Value::Int32Value(frequency)) => {
token_frequencies.push((token.text.to_string(), frequency));
},
Ok(_) => {
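// A frequency should always be an Int32Value; treat anything else as no matches.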
return Ok(vec![]);
}
Err(_) => {
// TODO error handling from blockfile
return Ok(vec![]);
}
}
}
token_frequencies.sort_by(|a, b| a.1.cmp(&b.1));

// Populate initial candidates with the least-frequent token's posting list.
// doc ID -> possible starting locations for the query.
let mut candidates: HashMap<i32, Vec<i32>> = HashMap::new();
if token_frequencies.is_empty() {
return Ok(vec![]);
}
// Seed from token_frequencies[0] (the rarest token); the intersection loop
// below deliberately skips it.
let (first_token, _) = &token_frequencies[0];
let blockfilekey = BlockfileKey::new("".to_string(), Key::String(first_token.to_string()));
let first_token_positional_posting_list = match self.posting_lists_blockfile.get(blockfilekey).unwrap() {
Value::PositionalPostingListValue(arr) => arr,
_ => panic!("Value is not an arrow struct array"),
};
let first_token_offset = tokens.iter().find(|t| t.text == *first_token).unwrap().offset_from as i32;
for doc_id in first_token_positional_posting_list.get_doc_ids().values() {
let positions = first_token_positional_posting_list.get_positions_for_doc_id(*doc_id).unwrap();
let positions_vec: Vec<i32> = positions.values().iter().map(|x| *x - first_token_offset).collect();
candidates.insert(*doc_id, positions_vec);
}

// Iterate through the rest of the tokens, intersecting the posting lists with the candidates.
for (token, _) in token_frequencies[1..].iter() {
let blockfilekey = BlockfileKey::new("".to_string(), Key::String(token.to_string()));
let positional_posting_list = match self.posting_lists_blockfile.get(blockfilekey).unwrap() {
Value::PositionalPostingListValue(arr) => arr,
_ => panic!("Value is not an arrow struct array"),
};
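// Align this token by its offset within the query; `find` takes the first
// occurrence of this token's text in the query's token list.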
let token_offset = tokens.iter().find(|t| t.text == *token).unwrap().offset_from as i32;
let mut new_candidates: HashMap<i32, Vec<i32>> = HashMap::new();
for (doc_id, positions) in candidates.iter() {
let mut new_positions = vec![];
for position in positions {
if let Some(positions_for_doc_id) = positional_posting_list.get_positions_for_doc_id(*doc_id) {
for position_for_doc_id in positions_for_doc_id.values() {
if position_for_doc_id - token_offset == *position {
new_positions.push(*position);
}
}
}
}
if !new_positions.is_empty() {
new_candidates.insert(*doc_id, new_positions);
}
}
candidates = new_candidates;
}

let mut results = vec![];
for (doc_id, _) in candidates.drain() {
results.push(doc_id);
}

Ok(results)
}
}

#[cfg(test)]
mod test {
use super::*;
use tantivy::tokenizer::NgramTokenizer;
use crate::blockstore::HashMapBlockfile;
use crate::index::fulltext::tokenizer::TantivyChromaTokenizer;

#[test]
fn test_new() {
let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
let _index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
}

#[test]
fn test_index_single_document() {
let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
index.begin_transaction().unwrap();
index.add_document("hello world", 1).unwrap();
index.commit_transaction().unwrap();

let res = index.search("hello");
assert_eq!(res.unwrap(), vec![1]);
}

#[test]
fn test_search_absent_token() {
let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
index.begin_transaction().unwrap();
index.add_document("hello world", 1).unwrap();
index.commit_transaction().unwrap();

let res = index.search("chroma");
assert!(res.unwrap().is_empty());
}

#[test]
fn test_index_and_search_multiple_documents() {
let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
index.begin_transaction().unwrap();
index.add_document("hello world", 1).unwrap();
index.add_document("hello chroma", 2).unwrap();
index.add_document("chroma world", 3).unwrap();
index.commit_transaction().unwrap();

let res = index.search("hello").unwrap();
assert!(res.contains(&1));
assert!(res.contains(&2));

let res = index.search("world").unwrap();
assert!(res.contains(&1));
assert!(res.contains(&3));

let res = index.search("llo chro").unwrap();
assert!(res.contains(&2));
}

#[test]
fn test_special_characters_search() {
let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
index.begin_transaction().unwrap();
index.add_document("!!!!", 1).unwrap();
index.add_document(",,!!", 2).unwrap();
index.add_document(".!", 3).unwrap();
index.add_document("!.!.!.!", 4).unwrap();
index.commit_transaction().unwrap();

let res = index.search("!!").unwrap();
assert!(res.contains(&1));
assert!(res.contains(&2));

let res = index.search(".!").unwrap();
assert!(res.contains(&3));
assert!(res.contains(&4));
}
}