# [ENH] Blockstore-based full-text search (#1759)
## Description of changes

*Summarize the changes made by this PR.*

- New functionality
  - Full-text search based on blockstore

## Test plan

*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes

*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
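For orientation, here is how the new pieces fit together. This is a minimal sketch assembled from the tests in the diff below; the in-memory `HashMapBlockfile` and the unigram `NgramTokenizer` setup come from those tests, and `sketch` is just an illustrative wrapper, not part of the PR.

```rust
use tantivy::tokenizer::NgramTokenizer;

// Sketch only: these types are crate-private to the worker crate
// (crate::blockstore, crate::index::fulltext), so this will not compile
// as an external program.
fn sketch() {
    // Two blockfiles: one for positional posting lists, one for term frequencies.
    let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
    let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
    // Unigram tokenizer: every 1-character ngram of the text becomes a token.
    let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(
        NgramTokenizer::new(1, 1, false).unwrap(),
    )));
    let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);

    // Writes must happen inside a transaction; search only sees committed state.
    index.begin_transaction().unwrap();
    index.add_document("hello world", 1).unwrap();
    index.add_document("hello chroma", 2).unwrap();
    index.commit_transaction().unwrap();

    // Substring search returns the offset IDs of the matching documents.
    let hits = index.search("llo chro").unwrap();
    assert_eq!(hits, vec![2]);
}
```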
Showing 8 changed files with 1,324 additions and 458 deletions.
Large diffs are not rendered by default.
`@@ -0,0 +1,4 @@`

```rust
pub mod tokenizer;
mod types;

pub use types::*;
```
`@@ -0,0 +1,88 @@`

```rust
use tantivy::tokenizer::{NgramTokenizer, Token, TokenStream, Tokenizer};

pub(crate) trait ChromaTokenStream {
    fn process(&mut self, sink: &mut dyn FnMut(&Token));
    fn get_tokens(&self) -> &Vec<Token>;
}

// A token stream that has been fully materialized into a Vec, so it can be
// iterated more than once.
pub(crate) struct TantivyChromaTokenStream {
    tokens: Vec<Token>,
}

impl TantivyChromaTokenStream {
    pub fn new(tokens: Vec<Token>) -> Self {
        TantivyChromaTokenStream { tokens }
    }
}

impl ChromaTokenStream for TantivyChromaTokenStream {
    fn process(&mut self, sink: &mut dyn FnMut(&Token)) {
        for token in &self.tokens {
            sink(token);
        }
    }

    fn get_tokens(&self) -> &Vec<Token> {
        &self.tokens
    }
}

pub(crate) trait ChromaTokenizer {
    fn encode(&mut self, text: &str) -> Box<dyn ChromaTokenStream>;
}

pub(crate) struct TantivyChromaTokenizer {
    tokenizer: Box<NgramTokenizer>,
}

impl TantivyChromaTokenizer {
    pub fn new(tokenizer: Box<NgramTokenizer>) -> Self {
        TantivyChromaTokenizer { tokenizer }
    }
}

impl ChromaTokenizer for TantivyChromaTokenizer {
    fn encode(&mut self, text: &str) -> Box<dyn ChromaTokenStream> {
        // Drain tantivy's lazy token stream into a Vec up front.
        let mut token_stream = self.tokenizer.token_stream(text);
        let mut tokens = Vec::new();
        token_stream.process(&mut |token| {
            tokens.push(token.clone());
        });
        Box::new(TantivyChromaTokenStream::new(tokens))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_chroma_tokenizer() {
        let tokenizer: Box<NgramTokenizer> = Box::new(NgramTokenizer::new(1, 1, false).unwrap());
        let mut chroma_tokenizer = TantivyChromaTokenizer::new(tokenizer);
        let mut token_stream = chroma_tokenizer.encode("hello world");
        let mut tokens = Vec::new();
        token_stream.process(&mut |token| {
            tokens.push(token.clone());
        });
        assert_eq!(tokens.len(), 11);
        assert_eq!(tokens[0].text, "h");
        assert_eq!(tokens[1].text, "e");
    }

    #[test]
    fn test_get_tokens() {
        let tokenizer: Box<NgramTokenizer> = Box::new(NgramTokenizer::new(1, 1, false).unwrap());
        let mut chroma_tokenizer = TantivyChromaTokenizer::new(tokenizer);
        let token_stream = chroma_tokenizer.encode("hello world");
        let tokens = token_stream.get_tokens();
        assert_eq!(tokens.len(), 11);
        assert_eq!(tokens[0].text, "h");
        assert_eq!(tokens[1].text, "e");
    }
}
```
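For reference, this is what the unigram configuration used throughout the tests emits. The sketch below assumes, as the 11-token assertions above imply, that tantivy's `NgramTokenizer` ngrams the raw text including whitespace rather than splitting on it; `show_tokens` is illustrative only.

```rust
use tantivy::tokenizer::NgramTokenizer;

// Illustrative only: NgramTokenizer::new(1, 1, false) emits every 1-gram of
// the input, so "hello world" (11 characters, space included) yields 11 tokens.
fn show_tokens() {
    let mut tokenizer = TantivyChromaTokenizer::new(Box::new(
        NgramTokenizer::new(1, 1, false).unwrap(),
    ));
    let stream = tokenizer.encode("hello world");
    for token in stream.get_tokens() {
        // Token::offset_from is the byte offset of the token in the input:
        // ("h", 0), ("e", 1), ..., (" ", 5), ("w", 6), ..., ("d", 10)
        println!("{:?} @ {}", token.text, token.offset_from);
    }
}
```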
`@@ -0,0 +1,276 @@`

```rust
use crate::blockstore::{Blockfile, BlockfileKey, Key, PositionalPostingListBuilder, Value};
use crate::errors::{ChromaError, ErrorCodes};
use crate::index::fulltext::tokenizer::{ChromaTokenStream, ChromaTokenizer};
use std::collections::HashMap;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum FullTextIndexError {
    #[error("Already in a transaction")]
    AlreadyInTransaction,
    #[error("Not in a transaction")]
    NotInTransaction,
}

impl ChromaError for FullTextIndexError {
    fn code(&self) -> ErrorCodes {
        match self {
            FullTextIndexError::AlreadyInTransaction => ErrorCodes::FailedPrecondition,
            FullTextIndexError::NotInTransaction => ErrorCodes::FailedPrecondition,
        }
    }
}

pub(crate) trait FullTextIndex {
    fn begin_transaction(&mut self) -> Result<(), Box<dyn ChromaError>>;
    fn commit_transaction(&mut self) -> Result<(), Box<dyn ChromaError>>;

    // Must be done inside a transaction.
    fn add_document(&mut self, document: &str, offset_id: i32) -> Result<(), Box<dyn ChromaError>>;
    // Only searches committed state.
    fn search(&mut self, query: &str) -> Result<Vec<i32>, Box<dyn ChromaError>>;
}

pub(crate) struct BlockfileFullTextIndex {
    posting_lists_blockfile: Box<dyn Blockfile>,
    frequencies_blockfile: Box<dyn Blockfile>,
    tokenizer: Box<dyn ChromaTokenizer>,
    in_transaction: bool,

    // term -> positional posting list builder for that term
    uncommitted: HashMap<String, PositionalPostingListBuilder>,
    uncommitted_frequencies: HashMap<String, i32>,
}

impl BlockfileFullTextIndex {
    pub(crate) fn new(
        posting_lists_blockfile: Box<dyn Blockfile>,
        frequencies_blockfile: Box<dyn Blockfile>,
        tokenizer: Box<dyn ChromaTokenizer>,
    ) -> Self {
        BlockfileFullTextIndex {
            posting_lists_blockfile,
            frequencies_blockfile,
            tokenizer,
            in_transaction: false,
            uncommitted: HashMap::new(),
            uncommitted_frequencies: HashMap::new(),
        }
    }
}

impl FullTextIndex for BlockfileFullTextIndex {
    fn begin_transaction(&mut self) -> Result<(), Box<dyn ChromaError>> {
        if self.in_transaction {
            return Err(Box::new(FullTextIndexError::AlreadyInTransaction));
        }
        self.posting_lists_blockfile.begin_transaction()?;
        self.frequencies_blockfile.begin_transaction()?;
        self.in_transaction = true;
        Ok(())
    }

    fn commit_transaction(&mut self) -> Result<(), Box<dyn ChromaError>> {
        if !self.in_transaction {
            return Err(Box::new(FullTextIndexError::NotInTransaction));
        }
        self.in_transaction = false;
        for (key, mut value) in self.uncommitted.drain() {
            let positional_posting_list = value.build();
            let blockfilekey = BlockfileKey::new("".to_string(), Key::String(key.to_string()));
            self.posting_lists_blockfile.set(blockfilekey, Value::PositionalPostingListValue(positional_posting_list));
        }
        for (key, value) in self.uncommitted_frequencies.drain() {
            let blockfilekey = BlockfileKey::new("".to_string(), Key::String(key.to_string()));
            self.frequencies_blockfile.set(blockfilekey, Value::Int32Value(value));
        }
        self.posting_lists_blockfile.commit_transaction()?;
        self.frequencies_blockfile.commit_transaction()?;
        self.uncommitted.clear();
        Ok(())
    }

    fn add_document(&mut self, document: &str, offset_id: i32) -> Result<(), Box<dyn ChromaError>> {
        if !self.in_transaction {
            return Err(Box::new(FullTextIndexError::NotInTransaction));
        }
        let tokens = self.tokenizer.encode(document);
        for token in tokens.get_tokens() {
            self.uncommitted_frequencies.entry(token.text.to_string()).and_modify(|e| *e += 1).or_insert(1);
            let builder = self.uncommitted.entry(token.text.to_string()).or_insert(PositionalPostingListBuilder::new());

            // Store starting positions of tokens. These are NOT affected by token filters.
            // For search, we can use the start and end positions to compute offsets to
            // check full string match.
            //
            // See https://docs.rs/tantivy/latest/tantivy/tokenizer/struct.Token.html
            if !builder.contains_doc_id(offset_id) {
                // Casting to i32 is safe since we limit the size of the document.
                builder.add_doc_id_and_positions(offset_id, vec![token.offset_from as i32]);
            } else {
                builder.add_positions_for_doc_id(offset_id, vec![token.offset_from as i32]);
            }
        }
        Ok(())
    }

    fn search(&mut self, query: &str) -> Result<Vec<i32>, Box<dyn ChromaError>> {
        let binding = self.tokenizer.encode(query);
        let tokens = binding.get_tokens();

        // Get query tokens sorted by frequency.
        let mut token_frequencies = vec![];
        for token in tokens {
            let blockfilekey = BlockfileKey::new("".to_string(), Key::String(token.text.to_string()));
            let value = self.frequencies_blockfile.get(blockfilekey);
            match value {
                Ok(Value::Int32Value(frequency)) => {
                    token_frequencies.push((token.text.to_string(), frequency));
                }
                Ok(_) => {
                    return Ok(vec![]);
                }
                Err(_) => {
                    // TODO error handling from blockfile
                    return Ok(vec![]);
                }
            }
        }
        token_frequencies.sort_by(|a, b| a.1.cmp(&b.1));

        // Populate initial candidates with the least-frequent token's posting list.
        // doc ID -> possible starting locations for the query.
        if token_frequencies.is_empty() {
            return Ok(vec![]);
        }
        let mut candidates: HashMap<i32, Vec<i32>> = HashMap::new();
        let first_token = token_frequencies[0].0.as_str();
        let blockfilekey = BlockfileKey::new("".to_string(), Key::String(first_token.to_string()));
        let first_token_positional_posting_list = match self.posting_lists_blockfile.get(blockfilekey).unwrap() {
            Value::PositionalPostingListValue(arr) => arr,
            _ => panic!("Value is not an arrow struct array"),
        };
        let first_token_offset = tokens.iter().find(|t| t.text == first_token).unwrap().offset_from as i32;
        for doc_id in first_token_positional_posting_list.get_doc_ids().values() {
            let positions = first_token_positional_posting_list.get_positions_for_doc_id(*doc_id).unwrap();
            let positions_vec: Vec<i32> = positions.values().iter().map(|x| *x - first_token_offset).collect();
            candidates.insert(*doc_id, positions_vec);
        }

        // Iterate through the rest of the tokens, intersecting the posting lists with the candidates.
        for (token, _) in token_frequencies[1..].iter() {
            let blockfilekey = BlockfileKey::new("".to_string(), Key::String(token.to_string()));
            let positional_posting_list = match self.posting_lists_blockfile.get(blockfilekey).unwrap() {
                Value::PositionalPostingListValue(arr) => arr,
                _ => panic!("Value is not an arrow struct array"),
            };
            let token_offset = tokens.iter().find(|t| t.text == *token).unwrap().offset_from as i32;
            let mut new_candidates: HashMap<i32, Vec<i32>> = HashMap::new();
            for (doc_id, positions) in candidates.iter() {
                let mut new_positions = vec![];
                for position in positions {
                    if let Some(positions_for_doc_id) = positional_posting_list.get_positions_for_doc_id(*doc_id) {
                        for position_for_doc_id in positions_for_doc_id.values() {
                            if position_for_doc_id - token_offset == *position {
                                new_positions.push(*position);
                            }
                        }
                    }
                }
                if !new_positions.is_empty() {
                    new_candidates.insert(*doc_id, new_positions);
                }
            }
            candidates = new_candidates;
        }

        let mut results = vec![];
        for (doc_id, _) in candidates.drain() {
            results.push(doc_id);
        }

        Ok(results)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::blockstore::HashMapBlockfile;
    use crate::index::fulltext::tokenizer::TantivyChromaTokenizer;
    use tantivy::tokenizer::NgramTokenizer;

    #[test]
    fn test_new() {
        let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
        let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
        let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
        let _index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
    }

    #[test]
    fn test_index_single_document() {
        let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
        let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
        let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
        let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
        index.begin_transaction().unwrap();
        index.add_document("hello world", 1).unwrap();
        index.commit_transaction().unwrap();

        let res = index.search("hello");
        assert_eq!(res.unwrap(), vec![1]);
    }

    #[test]
    fn test_search_absent_token() {
        let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
        let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
        let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
        let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
        index.begin_transaction().unwrap();
        index.add_document("hello world", 1).unwrap();
        index.commit_transaction().unwrap();

        let res = index.search("chroma");
        assert!(res.unwrap().is_empty());
    }

    #[test]
    fn test_index_and_search_multiple_documents() {
        let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
        let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
        let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
        let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
        index.begin_transaction().unwrap();
        index.add_document("hello world", 1).unwrap();
        index.add_document("hello chroma", 2).unwrap();
        index.add_document("chroma world", 3).unwrap();
        index.commit_transaction().unwrap();

        let res = index.search("hello").unwrap();
        assert!(res.contains(&1));
        assert!(res.contains(&2));

        let res = index.search("world").unwrap();
        assert!(res.contains(&1));
        assert!(res.contains(&3));

        let res = index.search("llo chro").unwrap();
        assert!(res.contains(&2));
    }

    #[test]
    fn test_special_characters_search() {
        let pl_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-pl").unwrap());
        let freq_blockfile = Box::new(HashMapBlockfile::open(&"in-memory-freqs").unwrap());
        let tokenizer = Box::new(TantivyChromaTokenizer::new(Box::new(NgramTokenizer::new(1, 1, false).unwrap())));
        let mut index = BlockfileFullTextIndex::new(pl_blockfile, freq_blockfile, tokenizer);
        index.begin_transaction().unwrap();
        index.add_document("!!!!", 1).unwrap();
        index.add_document(",,!!", 2).unwrap();
        index.add_document(".!", 3).unwrap();
        index.add_document("!.!.!.!", 4).unwrap();
        index.commit_transaction().unwrap();

        let res = index.search("!!").unwrap();
        assert!(res.contains(&1));
        assert!(res.contains(&2));

        let res = index.search(".!").unwrap();
        assert!(res.contains(&3));
        assert!(res.contains(&4));
    }
}
```
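To make the offset arithmetic in `search` concrete: for the query `llo chro` against document 2 (`hello chroma`), token `l` is stored at document position 2 and sits at query offset 0, `o` is stored at position 4 with query offset 2, and so on; every query token agrees on candidate start `2`, so the document matches. The standalone sketch below (my illustration, no Chroma types) states that invariant directly; with unigram tokens it reduces to substring search.

```rust
// Illustration of the candidate-intersection invariant used by search():
// a document matches the query starting at position p iff, for every query
// token at query offset o, the document contains that token at position p + o.
fn matches_at(doc: &str, query: &str) -> Vec<usize> {
    let doc: Vec<char> = doc.chars().collect();
    let query: Vec<char> = query.chars().collect();
    if query.is_empty() || query.len() > doc.len() {
        return vec![];
    }
    (0..=doc.len() - query.len())
        .filter(|&p| query.iter().enumerate().all(|(o, c)| doc[p + o] == *c))
        .collect()
}

fn main() {
    // "llo chro" occurs in "hello chroma" starting at position 2: each query
    // token's document position minus its query offset is always 2.
    assert_eq!(matches_at("hello chroma", "llo chro"), vec![2]);
}
```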