Skip to content

Commit

Permalink
stash
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Oct 23, 2024
1 parent d5eccab commit 10d7438
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
9 changes: 9 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ class ResourceRequest:
def __eq__(self, other: ResourceRequest) -> bool: ... # type: ignore[override]
def __ne__(self, other: ResourceRequest) -> bool: ... # type: ignore[override]

class HashFunctionKind(Enum):
"""
Kind of hash function to use for minhash.

Type stub for the Rust ``HashFunctionKind`` enum, which is registered as a
pyclass under the ``daft.daft`` module. The members select the hasher that
minhash builds: ``MurmurHash3``, ``XxHash`` or ``Sha1``.
"""

MurmurHash3: int
XxHash: int
Sha1: int

class FileFormat(Enum):
"""
Format of a file, e.g. Parquet, CSV, and JSON.
Expand Down
10 changes: 9 additions & 1 deletion src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use pyo3::{
types::{PyBytes, PyList},
};

// Scratch placeholder with no implementation yet.
//
// An empty body evaluates to `()`, which cannot satisfy the declared
// `PyResult<HashFunctionKind>` return type, so the stub is made explicit with
// `todo!()`: it type-checks for any return type and fails loudly if it is ever
// called before being implemented.
fn x(x: HashFunctionKind) -> PyResult<HashFunctionKind> {
    todo!("placeholder: not implemented yet")
}

use crate::{
array::{
ops::{
Expand Down Expand Up @@ -316,7 +318,13 @@ impl PySeries {
Ok(self.series.hash(seed_array)?.into_series().into())
}

pub fn minhash(&self, num_hashes: i64, ngram_size: i64, seed: i64) -> PyResult<Self> {
pub fn minhash(
&self,
num_hashes: i64,
ngram_size: i64,
seed: i64,
hash_function: HashFunctionKind,
) -> PyResult<Self> {
if num_hashes <= 0 {
return Err(PyValueError::new_err(format!(
"num_hashes must be positive: {num_hashes}"
Expand Down
47 changes: 25 additions & 22 deletions src/daft-functions/src/minhash.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::hash::{BuildHasher, BuildHasherDefault};
use std::{
hash::{BuildHasher, BuildHasherDefault},
str::FromStr,
};

use common_error::{DaftError, DaftResult};
use daft_core::prelude::*;
Expand All @@ -7,15 +10,16 @@ use daft_dsl::{
ExprRef,
};
use daft_hash::{MurBuildHasher, Sha1Hasher};
use pyo3::{pyclass, pymethods, types::PyType, PyErr, PyResult};
#[cfg(feature = "python")]
use pyo3::pyclass;
use serde::{Deserialize, Serialize};

/// Parameters of the minhash scalar function.
///
/// All four values are consumed when the function is evaluated: the first three
/// are forwarded to the input's `minhash` call, and `hash_function` decides
/// which hasher is built for that call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct MinHashFunction {
/// Passed through as the `num_hashes` argument of the `minhash` call.
pub num_hashes: usize,
/// Passed through as the `ngram_size` argument of the `minhash` call.
pub ngram_size: usize,
/// Passed through as the seed of the `minhash` call; also seeds the
/// MurmurHash3 and XxHash builders (the Sha1 builder is default-constructed).
pub seed: u32,
/// Selects the hasher used for evaluation: MurmurHash3, XxHash or Sha1.
pub hash_function: HashFunctionLiteral,
pub hash_function: HashFunctionKind,
}

#[typetag::serde]
Expand All @@ -37,15 +41,15 @@ impl ScalarUDF for MinHashFunction {
};

match self.hash_function {
HashFunctionLiteral::MurmurHash3 => {
HashFunctionKind::MurmurHash3 => {
let hasher = MurBuildHasher::new(self.seed);
input.minhash(self.num_hashes, self.ngram_size, self.seed, &hasher)
}
HashFunctionLiteral::XxHash => {
HashFunctionKind::XxHash => {
let hasher = xxhash_rust::xxh64::Xxh64Builder::new(self.seed as u64);
input.minhash(self.num_hashes, self.ngram_size, self.seed, &hasher)
}
HashFunctionLiteral::Sha1 => {
HashFunctionKind::Sha1 => {
let hasher = BuildHasherDefault::<Sha1Hasher>::default();
input.minhash(self.num_hashes, self.ngram_size, self.seed, &hasher)
}
Expand Down Expand Up @@ -80,7 +84,7 @@ pub fn minhash(
num_hashes: usize,
ngram_size: usize,
seed: u32,
hash_function: HashFunctionLiteral,
hash_function: HashFunctionKind,
) -> ExprRef {
ScalarFunction::new(
MinHashFunction {
Expand All @@ -94,26 +98,25 @@ pub fn minhash(
.into()
}

// todo: what
#[pyclass]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HashFunctionLiteral {
/// Kind of hash function to use for minhash.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Copy)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum HashFunctionKind {
MurmurHash3,
XxHash,
Sha1,
}

#[pymethods]
impl HashFunctionLiteral {
// todo: is there an updated way to do this? it says it is using a deprecated method
#[classmethod]
fn from_str(_cls: &PyType, s: &str) -> PyResult<Self> {
match s.to_lowercase().as_str() {
"murmurhash3" => Ok(Self::MurmurHash3),
impl FromStr for HashFunctionKind {
type Err = DaftError;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"murmur3" => Ok(Self::MurmurHash3),
"xxhash" => Ok(Self::XxHash),
"sha1" => Ok(Self::Sha1),
_ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid hash function: {}",
_ => Err(DaftError::ValueError(format!(
"Hash function {} not found",
s
))),
}
Expand All @@ -125,15 +128,15 @@ pub mod python {
use daft_dsl::python::PyExpr;
use pyo3::{exceptions::PyValueError, pyfunction, PyResult};

use crate::minhash::HashFunctionLiteral;
use crate::minhash::HashFunctionKind;

#[pyfunction]
pub fn minhash(
expr: PyExpr,
num_hashes: i64,
ngram_size: i64,
seed: i64,
hash_function: HashFunctionLiteral,
hash_function: HashFunctionKind,
) -> PyResult<PyExpr> {
if num_hashes <= 0 {
return Err(PyValueError::new_err(format!(
Expand Down
25 changes: 23 additions & 2 deletions src/daft-sql/src/modules/hashing.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use daft_dsl::ExprRef;
use daft_functions::{
hash::hash,
minhash::{minhash, MinHashFunction},
minhash::{minhash, HashFunctionKind, MinHashFunction},
};
use sqlparser::ast::FunctionArg;

Expand Down Expand Up @@ -74,6 +74,7 @@ impl TryFrom<SQLFunctionArguments> for MinHashFunction {
.and_then(daft_dsl::LiteralValue::as_i64)
.ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))?
as usize;

let seed = args
.get_named("seed")
.map(|arg| {
Expand All @@ -83,10 +84,24 @@ impl TryFrom<SQLFunctionArguments> for MinHashFunction {
})
.transpose()?
.unwrap_or(1) as u32;

let hash_function = args
.get_named("hash_function")
.map(|arg| {
arg.as_literal()
.and_then(daft_dsl::LiteralValue::as_str)
.ok_or_else(|| {
PlannerError::invalid_operation("hash_function must be a string")
})
})
.transpose()?
.unwrap_or("murmur3");

Ok(Self {
num_hashes,
ngram_size,
seed,
hash_function: hash_function.parse()?,
})
}
}
Expand All @@ -103,7 +118,13 @@ impl SQLFunction for SQLMinhash {
let args: MinHashFunction =
planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?;

Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed))
Ok(minhash(
input,
args.num_hashes,
args.ngram_size,
args.seed,
args.hash_function,
))
}
_ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"),
}
Expand Down

0 comments on commit 10d7438

Please sign in to comment.