Skip to content

Commit

Permalink
[FEATURE] add min_hash alternate hashers (#3052)
Browse files Browse the repository at this point in the history
Likely also increases performance due to removing heap alloc in some
places.

---------

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
andrewgazelka and Colin Ho authored Oct 30, 2024
1 parent 404a37e commit c803bc9
Show file tree
Hide file tree
Showing 26 changed files with 2,035 additions and 338 deletions.
301 changes: 294 additions & 7 deletions Cargo.lock

Large diffs are not rendered by default.

63 changes: 38 additions & 25 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ daft-csv = {path = "src/daft-csv", default-features = false}
daft-dsl = {path = "src/daft-dsl", default-features = false}
daft-functions = {path = "src/daft-functions", default-features = false}
daft-functions-json = {path = "src/daft-functions-json", default-features = false}
daft-hash = {path = "src/daft-hash", default-features = false}
daft-image = {path = "src/daft-image", default-features = false}
daft-io = {path = "src/daft-io", default-features = false}
daft-json = {path = "src/daft-json", default-features = false}
Expand All @@ -37,29 +38,29 @@ sysinfo = {workspace = true}
[features]
# maturin will turn this on
python = [
"dep:pyo3",
"dep:pyo3-log",
"common-daft-config/python",
"common-display/python",
"common-resource-request/python",
"common-system-info/python",
"daft-core/python",
"daft-csv/python",
"daft-dsl/python",
"daft-local-execution/python",
"daft-io/python",
"daft-functions-json/python",
"daft-functions/python",
"daft-image/python",
"daft-io/python",
"daft-json/python",
"daft-local-execution/python",
"daft-micropartition/python",
"daft-parquet/python",
"daft-plan/python",
"daft-scan/python",
"daft-scheduler/python",
"daft-stats/python",
"daft-sql/python",
"daft-stats/python",
"daft-table/python",
"daft-functions/python",
"daft-functions-json/python",
"common-daft-config/python",
"common-system-info/python",
"common-display/python",
"common-resource-request/python"
"dep:pyo3",
"dep:pyo3-log"
]

[lib]
Expand Down Expand Up @@ -113,35 +114,38 @@ tikv-jemallocator = {version = "0.5.4", features = [
[workspace]
members = [
"src/arrow2",
"src/parquet2",
"src/common/daft-config",
"src/common/display",
"src/common/error",
"src/common/io-config",
"src/common/treenode",
"src/common/daft-config",
"src/common/system-info",
"src/common/treenode",
"src/daft-core",
"src/daft-local-execution",
"src/daft-io",
"src/daft-image",
"src/daft-parquet",
"src/daft-csv",
"src/daft-json",
"src/daft-dsl",
"src/daft-table",
"src/daft-plan",
"src/daft-physical-plan",
"src/daft-functions",
"src/daft-functions-json",
"src/daft-hash",
"src/daft-image",
"src/daft-io",
"src/daft-json",
"src/daft-local-execution",
"src/daft-micropartition",
"src/daft-parquet",
"src/daft-physical-plan",
"src/daft-plan",
"src/daft-scan",
"src/daft-scheduler",
"src/daft-sketch",
"src/daft-functions",
"src/daft-functions-json",
"src/daft-sql",
"src/hyperloglog"
"src/daft-table",
"src/hyperloglog",
"src/parquet2"
]

[workspace.dependencies]
ahash = "0.8.11"
approx = "0.5.1"
async-compat = "0.2.3"
async-compression = {version = "0.4.12", features = [
"tokio",
Expand All @@ -154,7 +158,10 @@ bytes = "1.6.0"
chrono = "0.4.38"
chrono-tz = "0.8.4"
comfy-table = "7.1.1"
common-error = {path = "src/common/error", default-features = false}
daft-hash = {path = "src/daft-hash"}
derivative = "2.2.0"
divan = "0.1.14"
dyn-clone = "1"
futures = "0.3.30"
html-escape = "0.2.13"
Expand All @@ -164,20 +171,25 @@ jaq-core = "1.2.0"
jaq-interpret = "1.2.0"
jaq-parse = "1.0.0"
jaq-std = "1.2.0"
mur3 = "0.1.0"
num-derive = "0.3.3"
num-traits = "0.2"
once_cell = "1.19.0"
path_macro = "1.0.0"
pretty_assertions = "1.4.0"
proptest = "1.5.0"
rand = "^0.8"
rayon = "1.10.0"
regex = "1.10.4"
rstest = "0.18.2"
rustc-hash = "2.0.0"
serde_json = "1.0.116"
sha1 = "0.11.0-pre.4"
sketches-ddsketch = {version = "0.2.2", features = ["use_serde"]}
snafu = {version = "0.7.4", features = ["futures"]}
sqlparser = "0.51.0"
sysinfo = "0.30.12"
tango-bench = "0.6.0"
test-log = "0.2.16"
thiserror = "1.0.63"
tiktoken-rs = "0.5.9"
Expand All @@ -195,6 +207,7 @@ tokio-stream = {version = "0.1.14", features = ["fs", "io-util", "time"]}
tokio-util = "0.7.11"
tracing = "0.1"
url = "2.4.0"
xxhash-rust = "0.8.12"

[workspace.dependencies.arrow2]
path = "src/arrow2"
Expand Down
9 changes: 8 additions & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> PyExpr: ...

# -----
Expand Down Expand Up @@ -1347,7 +1348,13 @@ class PySeries:
def sort(self, descending: bool) -> PySeries: ...
def argsort(self, descending: bool) -> PySeries: ...
def hash(self, seed: PySeries | None = None) -> PySeries: ...
def minhash(self, num_hashes: int, ngram_size: int, seed: int = 1) -> PySeries: ...
def minhash(
self,
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> PySeries: ...
def __invert__(self) -> PySeries: ...
def count(self, mode: CountMode) -> PySeries: ...
def sum(self) -> PySeries: ...
Expand Down
9 changes: 7 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> Expression:
"""
Runs the MinHash algorithm on the series.
Expand All @@ -1204,19 +1205,23 @@ def minhash(
repeating with `num_hashes` permutations. Returns as a list of 32-bit unsigned integers.
Tokens for the ngrams are delimited by spaces.
The hash function used for the initial string hashes is selectable via ``hash_function`` and defaults to MurmurHash3.
The strings are not normalized or pre-processed, so it is recommended
to normalize the strings yourself.
Args:
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of "murmurhash3", "xxhash", or "sha1". Defaults to "murmurhash3".
"""
assert isinstance(num_hashes, int)
assert isinstance(ngram_size, int)
assert isinstance(seed, int)
return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed))
assert isinstance(hash_function, str)
assert hash_function in ["murmurhash3", "xxhash", "sha1"], f"Hash function {hash_function} not found"

return Expression._from_pyexpr(native.minhash(self._expr, num_hashes, ngram_size, seed, hash_function))

def name(self) -> builtins.str:
return self._expr.name()
Expand Down
13 changes: 11 additions & 2 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ def minhash(
num_hashes: int,
ngram_size: int,
seed: int = 1,
hash_function: Literal["murmurhash3", "xxhash", "sha1"] = "murmurhash3",
) -> Series:
"""
Runs the MinHash algorithm on the series.
Expand All @@ -582,15 +583,23 @@ def minhash(
num_hashes: The number of hash permutations to compute.
ngram_size: The number of tokens in each shingle/ngram.
seed (optional): Seed used for generating permutations and the initial string hashes. Defaults to 1.
hash_function (optional): Hash function to use for initial string hashing. One of "murmurhash3", "xxhash", or "sha1". Defaults to "murmurhash3".
"""
if not isinstance(num_hashes, int):
raise ValueError(f"expected an integer for num_hashes but got {type(num_hashes)}")
if not isinstance(ngram_size, int):
raise ValueError(f"expected an integer for ngram_size but got {type(ngram_size)}")
if seed is not None and not isinstance(seed, int):
raise ValueError(f"expected an integer or None for seed but got {type(seed)}")

return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed))
if not isinstance(hash_function, str):
raise ValueError(f"expected str for hash_function but got {type(hash_function)}")
assert hash_function in [
"murmurhash3",
"xxhash",
"sha1",
], f"hash_function must be one of 'murmurhash3', 'xxhash', 'sha1', got {hash_function}"

return Series._from_pyseries(self._series.minhash(num_hashes, ngram_size, seed, hash_function))

def _to_str_values(self) -> Series:
return Series._from_pyseries(self._series.to_str_values())
Expand Down
11 changes: 6 additions & 5 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ common-display = {path = "../common/display", default-features = false}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-py-serde = {path = "../common/py-serde", default-features = false}
daft-hash = {workspace = true}
daft-minhash = {path = "../daft-minhash", default-features = false}
daft-schema = {path = "../daft-schema", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
Expand All @@ -50,17 +51,17 @@ optional = true
version = "0.21.0"

[dependencies.xxhash-rust]
features = ["xxh3", "const_xxh3"]
features = ["xxh3", "const_xxh3", "xxh64"]
version = "0.8.5"

[features]
python = [
"dep:pyo3",
"dep:numpy",
"common-arrow-ffi/python",
"common-error/python",
"common-py-serde/python",
"common-arrow-ffi/python",
"daft-schema/python"
"daft-schema/python",
"dep:numpy",
"dep:pyo3"
]

[lints]
Expand Down
79 changes: 57 additions & 22 deletions src/daft-core/src/array/ops/minhash.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::iter::repeat_with;
use std::{collections::VecDeque, hash::BuildHasher, iter::repeat_with};

use arrow2::array::{MutableArray, MutablePrimitiveArray, PrimitiveArray};
use common_error::{DaftError, DaftResult};
Expand All @@ -14,7 +14,13 @@ use crate::{
impl DaftMinHash for Utf8Array {
type Output = DaftResult<FixedSizeListArray>;

fn minhash(&self, num_hashes: usize, ngram_size: usize, seed: u32) -> Self::Output {
fn minhash(
&self,
num_hashes: usize,
ngram_size: usize,
seed: u32,
hasher: &impl BuildHasher,
) -> Self::Output {
if num_hashes == 0 {
return Err(DaftError::ValueError(
"Number of hashes must be nonzero".into(),
Expand All @@ -24,42 +30,71 @@ impl DaftMinHash for Utf8Array {
return Err(DaftError::ValueError("Ngram size must be nonzero".into()));
}

// generate permutations
// Generate coefficients for MinHash permutation function: (a * x + b) % p
//
// The MinHash algorithm uses a hash function of the form (a * x + b) % p,
// where 'a' and 'b' are permutation coefficients, 'x' is the input hash,
// and 'p' is typically a large prime number.
//
// 1. perm_a (coefficient 'a'):
// - Starts from 1 to ensure 'a' is never zero
// - A non-zero 'a' is crucial for maintaining the bijective property of the permutation
//
// Example of how bijectivity fails if a = 0:
// Let p = 7 (prime number)
// If a = 0, b = 3, the function becomes: (0 * x + 3) % 7 = 3
// This always outputs 3, regardless of the input x, losing the bijective property
//
// 2. perm_b (coefficient 'b'):
// - Range: 0 to (i32::MAX as u64) - 1
// - Can start from 0 as 'b' can be any value without affecting the permutation property
//
// This approach ensures valid and uniformly distributed hash values, which is
// essential for accurate set similarity estimation in MinHash.
let mut rng = fastrand::Rng::with_seed(seed as u64);
let perm_a = repeat_with(|| rng.u64(1..(i32::MAX as u64))).take(num_hashes);
let perm_a_simd = load_simd(perm_a, num_hashes);
let perm_b = repeat_with(|| rng.u64(0..(i32::MAX as u64))).take(num_hashes);
let perm_b_simd = load_simd(perm_b, num_hashes);

let self_arrow = self.as_arrow();
let internal_arrow_representation = self.as_arrow();
let mut output: MutablePrimitiveArray<u32> =
MutablePrimitiveArray::with_capacity(num_hashes * self.len());
for maybe_s in self_arrow {
if let Some(s) = maybe_s {
let minhash_res = daft_minhash::minhash(
s,
(&perm_a_simd, &perm_b_simd),
num_hashes,
ngram_size,
seed,
)?;
output.extend(minhash_res.into_iter().map(Some));
} else {

let mut alloc = VecDeque::new();

for elem in internal_arrow_representation {
let Some(elem) = elem else {
for _ in 0..num_hashes {
output.push_null();
}
}
continue;
};

let minhash_res = daft_minhash::minhash_in(
elem,
(&perm_a_simd, &perm_b_simd),
num_hashes,
ngram_size,
hasher,
&mut alloc,
)?;

output.extend(minhash_res.into_iter().map(Some));
}
let output_immut: PrimitiveArray<u32> = output.into();

let immutable_output: PrimitiveArray<u32> = output.into();
let output_series = Series::from_arrow(
Field::new(self.name(), DataType::UInt32).into(),
Box::new(output_immut),
Box::new(immutable_output),
)?;
let field = Field::new(
self.name(),
DataType::FixedSizeList(Box::new(DataType::UInt32), num_hashes),
);

Ok(FixedSizeListArray::new(
Field::new(
self.name(),
DataType::FixedSizeList(Box::new(DataType::UInt32), num_hashes),
),
field,
output_series,
self.validity().cloned(),
))
Expand Down
Loading

0 comments on commit c803bc9

Please sign in to comment.