Skip to content

Commit

Permalink
[ENH] Support full-text-search mixins. (#3323)
Browse files Browse the repository at this point in the history
  • Loading branch information
rescrv authored Dec 18, 2024
1 parent e512eed commit 1e0030c
Show file tree
Hide file tree
Showing 4 changed files with 1,025 additions and 14 deletions.
8 changes: 4 additions & 4 deletions rust/load/src/bit_difference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use guacamole::{FromGuacamole, Guacamole, Zipf};
use siphasher::sip::SipHasher24;
use tracing::Instrument;

use crate::words::WORDS;
use crate::words::MANY_WORDS;
use crate::{DataSet, GetQuery, KeySelector, QueryQuery, Skew, UpsertQuery};

const EMBEDDING_BYTES: usize = 128;
Expand Down Expand Up @@ -101,7 +101,7 @@ impl Document {
pub fn embedding(&self) -> Vec<f32> {
let mut result = vec![];
let words = self.content.split_whitespace().collect::<Vec<_>>();
for word in WORDS.iter() {
for word in MANY_WORDS.iter() {
if words.contains(word) {
result.push(1.0);
} else {
Expand All @@ -114,7 +114,7 @@ impl Document {

impl From<[u8; EMBEDDING_BYTES]> for Document {
fn from(embedding: [u8; EMBEDDING_BYTES]) -> Document {
let document = WORDS
let document = MANY_WORDS
.iter()
.enumerate()
.filter_map(|(idx, word)| {
Expand Down Expand Up @@ -388,7 +388,7 @@ mod tests {

#[test]
fn constants() {
assert_eq!(EMBEDDING_SIZE, WORDS.len());
assert_eq!(EMBEDDING_SIZE, MANY_WORDS.len());
}

mod synthethic {
Expand Down
14 changes: 14 additions & 0 deletions rust/load/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ pub enum WhereMixin {
/// A raw metadata query simply copies the provided filter spec.
#[serde(rename = "query")]
Constant(serde_json::Value),
/// Search for a word from the provided set of words with skew.
#[serde(rename = "fts")]
FullTextSearch(Skew),
/// The tiny stories workload. The way these collections were setup, there are three fields
/// each of integer, float, and string. The integer fields are named i1, i2, and i3. The
/// float fields are named f1, f2, and f3. The string fields are named s1, s2, and s3.
Expand All @@ -325,6 +328,17 @@ impl WhereMixin {
pub fn to_json(&self, guac: &mut Guacamole) -> serde_json::Value {
match self {
Self::Constant(query) => query.clone(),
Self::FullTextSearch(skew) => {
const WORDS: &[&str] = words::FEW_WORDS;
let word = match skew {
Skew::Uniform => WORDS[uniform(0, WORDS.len() as u64)(guac) as usize],
Skew::Zipf { theta } => {
let z = Zipf::from_alpha(WORDS.len() as u64, *theta);
WORDS[z.next(guac) as usize]
}
};
serde_json::json!({"$contains": word.to_string()})
}
Self::TinyStories(mixin) => mixin.to_json(guac),
Self::Select(select) => {
let scale: f64 = any(guac);
Expand Down
Loading

0 comments on commit 1e0030c

Please sign in to comment.