From cf9a09bb237aa2a0420af920375893f2fddf0c42 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 9 Jul 2024 09:54:34 -0500 Subject: [PATCH] [FEAT]: dyn function registry (#2466) this adds a new variant to `Expr`. `Expr::ScalarFunction` which is pretty similar to `FunctionExpr`, but uses dynamic dispatch and a registry instead of the enum variants. The registry is inspired by datafusion's function registry. when an expr is serialized, it just serializes the name and the inputs. for example `col("text").hash(seed=42)` ```js { "name": "hash", "inputs": [ `col('text')`, // serialized repr of this `lit(42)` // serialized repr of this ] } ``` then when deserializing, it just fetches the appropriate function from the registry. _(errorring if no matches found)_. Also just to make sure everything works, I refactored the `hash` function to use the new paradigm. --- Cargo.lock | 17 +++ src/daft-dsl/Cargo.toml | 3 + src/daft-dsl/src/expr.rs | 17 ++- src/daft-dsl/src/functions/hash.rs | 51 ++++---- src/daft-dsl/src/functions/json/mod.rs | 2 +- src/daft-dsl/src/functions/mod.rs | 11 +- src/daft-dsl/src/functions/registry.rs | 43 ++++++ src/daft-dsl/src/functions/scalar.rs | 160 +++++++++++++++++++++++ src/daft-dsl/src/optimization.rs | 1 + src/daft-dsl/src/python.rs | 6 +- src/daft-plan/src/logical_ops/project.rs | 16 +++ src/daft-plan/src/partitioning.rs | 10 ++ src/daft-table/src/lib.rs | 8 ++ 13 files changed, 309 insertions(+), 36 deletions(-) create mode 100644 src/daft-dsl/src/functions/registry.rs create mode 100644 src/daft-dsl/src/functions/scalar.rs diff --git a/Cargo.lock b/Cargo.lock index c189ea1df7..f4e9e124d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1672,8 +1672,11 @@ dependencies = [ "daft-core", "daft-io", "daft-sketch", + "dashmap", + "erased-serde", "indexmap 2.2.6", "itertools 0.11.0", + "lazy_static", "pyo3", "pyo3-log", "serde", @@ -2015,6 +2018,20 @@ dependencies = [ "serde", ] +[[package]] +name = "dashmap" +version = "6.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "804c8821570c3f8b70230c2ba75ffa5c0f9a4189b9a432b6656c536712acae28" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.3", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "deflate64" version = "0.1.6" diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index d745015f7f..1cd8750d6d 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -6,8 +6,11 @@ common-treenode = {path = "../common/treenode", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-sketch = {path = "../daft-sketch", default-features = false} +dashmap = "6.0.1" +erased-serde = "0.4.5" indexmap = {workspace = true} itertools = {workspace = true} +lazy_static = {workspace = true} pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} serde = {workspace = true} diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index ae4af41591..73000c6466 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -7,10 +7,10 @@ use daft_core::{ use crate::{ functions::{ - function_display, function_semantic_id, + function_display, function_semantic_id, scalar_function_semantic_id, sketch::{HashableVecPercentiles, SketchExpr}, struct_::StructExpr, - FunctionEvaluator, + FunctionEvaluator, ScalarFunction, }, lit, optimization::{get_required_columns, requires_computation}, @@ -58,6 +58,7 @@ pub enum Expr { if_false: ExprRef, predicate: ExprRef, }, + ScalarFunction(ScalarFunction), } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)] @@ -576,6 +577,7 @@ impl Expr { // Agg: Separate path. Agg(agg_expr) => agg_expr.semantic_id(schema), + ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), } } @@ -607,6 +609,7 @@ impl Expr { vec![if_true.clone(), if_false.clone(), predicate.clone()] } FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], + ScalarFunction(sf) => sf.inputs.clone(), } } @@ -658,6 +661,7 @@ impl Expr { func: func.clone(), inputs: children, }, + ScalarFunction(sf) => ScalarFunction(sf.clone()), } } @@ -710,6 +714,8 @@ impl Expr { } Literal(value) => Ok(Field::new("literal", value.get_type())), Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), + ScalarFunction(sf) => sf.to_field(schema), + BinaryOp { op, left, right } => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; @@ -814,6 +820,7 @@ impl Expr { FunctionExpr::Struct(StructExpr::Get(name)) => name, _ => inputs.first().unwrap().name(), }, + ScalarFunction(func) => func.inputs.first().unwrap().name(), BinaryOp { op: _, left, @@ -903,7 +910,8 @@ impl Expr { | Expr::IsIn(..) | Expr::Between(..) | Expr::Function { .. } - | Expr::FillNull(..) => Err(io::Error::new( + | Expr::FillNull(..) + | Expr::ScalarFunction { .. } => Err(io::Error::new( io::ErrorKind::Other, "Unsupported expression for SQL translation", )), @@ -946,6 +954,8 @@ impl Display for Expr { Between(expr, lower, upper) => write!(f, "{expr} in [{lower},{upper}]"), Literal(val) => write!(f, "lit({val})"), Function { func, inputs } => function_display(f, func, inputs), + ScalarFunction(func) => write!(f, "{func}"), + IfElse { if_true, if_false, @@ -1130,6 +1140,7 @@ fn expr_has_agg(expr: &ExprRef) -> bool { Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => expr_has_agg(e), BinaryOp { left, right, .. } => expr_has_agg(left) || expr_has_agg(right), Function { inputs, .. } => inputs.iter().any(expr_has_agg), + ScalarFunction(func) => func.inputs.iter().any(expr_has_agg), IsIn(l, r) | FillNull(l, r) => expr_has_agg(l) || expr_has_agg(r), Between(v, l, u) => expr_has_agg(v) || expr_has_agg(l) || expr_has_agg(u), IfElse { diff --git a/src/daft-dsl/src/functions/hash.rs b/src/daft-dsl/src/functions/hash.rs index 3dc02de788..7449aac195 100644 --- a/src/daft-dsl/src/functions/hash.rs +++ b/src/daft-dsl/src/functions/hash.rs @@ -4,33 +4,25 @@ use daft_core::{ schema::Schema, DataType, IntoSeries, Series, }; +use serde::{Deserialize, Serialize}; -use crate::{ - functions::{FunctionEvaluator, FunctionExpr}, - Expr, ExprRef, -}; +use crate::{Expr, ExprRef}; -pub(super) struct HashEvaluator {} +use super::{ScalarFunction, ScalarUDF}; -impl FunctionEvaluator for HashEvaluator { - fn fn_name(&self) -> &'static str { - "hash" +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub(super) struct HashFunction; + +impl ScalarUDF for HashFunction { + fn as_any(&self) -> &dyn std::any::Any { + self } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [input] | [input, _] => match input.to_field(schema) { - Ok(field) => Ok(Field::new(field.name, DataType::UInt64)), - e => e, - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input arg, got {}", - inputs.len() - ))), - } + fn name(&self) -> &'static str { + "hash" } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [input] => input.hash(None).map(|s| s.into_series()), [input, seed] => { @@ -60,6 +52,19 @@ impl FunctionEvaluator for HashEvaluator { _ => Err(DaftError::ValueError("Expected 2 input arg".to_string())), } } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [input] | [input, _] => match input.to_field(schema) { + Ok(field) => Ok(Field::new(field.name, DataType::UInt64)), + e => e, + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input arg, got {}", + inputs.len() + ))), + } + } } pub fn hash(input: ExprRef, seed: Option) -> ExprRef { @@ -68,9 +73,5 @@ pub fn hash(input: ExprRef, seed: Option) -> ExprRef { None => vec![input], }; - Expr::Function { - func: FunctionExpr::Hash, - inputs, - } - .into() + Expr::ScalarFunction(ScalarFunction::new(HashFunction {}, inputs)).into() } diff --git a/src/daft-dsl/src/functions/json/mod.rs b/src/daft-dsl/src/functions/json/mod.rs index 96e3187e4b..e929da00b4 100644 --- a/src/daft-dsl/src/functions/json/mod.rs +++ b/src/daft-dsl/src/functions/json/mod.rs @@ -14,7 +14,7 @@ pub enum JsonExpr { impl JsonExpr { #[inline] - pub fn query_evaluator(&self) -> &dyn FunctionEvaluator { + pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { use JsonExpr::*; match self { Query(_) => &JsonQueryEvaluator {}, diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 6cb522e593..376c646acb 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -7,6 +7,8 @@ pub mod map; pub mod minhash; pub mod numeric; pub mod partitioning; +pub mod registry; +pub mod scalar; pub mod sketch; pub mod struct_; pub mod temporal; @@ -14,6 +16,7 @@ pub mod uri; pub mod utf8; use std::fmt::{Display, Formatter, Result}; +use std::hash::Hash; use crate::ExprRef; @@ -28,10 +31,12 @@ use self::struct_::StructExpr; use self::temporal::TemporalExpr; use self::utf8::Utf8Expr; use self::{float::FloatExpr, uri::UriExpr}; +pub use scalar::*; + use common_error::DaftResult; use daft_core::datatypes::FieldID; use daft_core::{datatypes::Field, schema::Schema, series::Series}; -use hash::HashEvaluator; + use minhash::{MinHashEvaluator, MinHashExpr}; use serde::{Deserialize, Serialize}; @@ -56,7 +61,6 @@ pub enum FunctionExpr { Python(PythonUDF), Partitioning(PartitioningExpr), Uri(UriExpr), - Hash, MinHash(MinHashExpr), } @@ -84,13 +88,12 @@ impl FunctionExpr { Map(expr) => expr.get_evaluator(), Sketch(expr) => expr.get_evaluator(), Struct(expr) => expr.get_evaluator(), - Json(expr) => expr.query_evaluator(), + Json(expr) => expr.get_evaluator(), Image(expr) => expr.get_evaluator(), Uri(expr) => expr.get_evaluator(), #[cfg(feature = "python")] Python(expr) => expr, Partitioning(expr) => expr.get_evaluator(), - Hash => &HashEvaluator {}, MinHash(_) => &MinHashEvaluator {}, } } diff --git a/src/daft-dsl/src/functions/registry.rs b/src/daft-dsl/src/functions/registry.rs new file mode 100644 index 0000000000..74e42f2fbc --- /dev/null +++ b/src/daft-dsl/src/functions/registry.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use dashmap::DashMap; + +use super::{hash::HashFunction, ScalarUDF}; + +lazy_static::lazy_static! { + pub static ref REGISTRY: Registry = Registry::new(); +} + +pub struct Registry { + functions: DashMap<&'static str, Arc>, +} + +impl Registry { + fn new() -> Self { + let iter: Vec> = vec![Arc::new(HashFunction {})]; + + let functions = iter.into_iter().map(|f| (f.name(), f)).collect(); + + Self { functions } + } + pub fn register(&mut self, function: Arc) -> DaftResult<()> { + if self.functions.contains_key(function.name()) { + Err(DaftError::ValueError(format!( + "function {} already exists", + function.name() + ))) + } else { + self.functions.insert(function.name(), function); + Ok(()) + } + } + + pub fn get(&self, name: &str) -> Option> { + self.functions.get(name).map(|f| f.value().clone()) + } + + pub fn names(&self) -> Vec<&'static str> { + self.functions.iter().map(|pair| pair.name()).collect() + } +} diff --git a/src/daft-dsl/src/functions/scalar.rs b/src/daft-dsl/src/functions/scalar.rs new file mode 100644 index 0000000000..d8d0c8e41a --- /dev/null +++ b/src/daft-dsl/src/functions/scalar.rs @@ -0,0 +1,160 @@ +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::datatypes::FieldID; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use crate::ExprRef; + +use super::registry::REGISTRY; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone)] +pub struct ScalarFunction { + pub udf: Arc, + pub inputs: Vec, +} + +impl ScalarFunction { + pub fn new(udf: UDF, inputs: Vec) -> Self { + Self { + udf: Arc::new(udf), + inputs, + } + } + + pub fn name(&self) -> &str { + self.udf.name() + } + + pub fn to_field(&self, schema: &Schema) -> DaftResult { + self.udf.to_field(&self.inputs, schema) + } +} + +pub trait ScalarUDF: Send + Sync + std::fmt::Debug + erased_serde::Serialize { + fn as_any(&self) -> &dyn Any; + fn name(&self) -> &'static str; + // TODO: evaluate should allow for &[Series | LiteralValue] inputs. + fn evaluate(&self, inputs: &[Series]) -> DaftResult; + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult; +} + +erased_serde::serialize_trait_object!(ScalarUDF); + +pub fn scalar_function_semantic_id(func: &ScalarFunction, schema: &Schema) -> FieldID { + let inputs = func + .inputs + .iter() + .map(|expr| expr.semantic_id(schema).id.to_string()) + .collect::>() + .join(", "); + // TODO: check for function idempotency here. + FieldID::new(format!("Function_{func:?}({inputs})")) +} + +impl PartialEq for ScalarFunction { + fn eq(&self, other: &Self) -> bool { + self.name() == other.name() && self.inputs == other.inputs + } +} + +impl Eq for ScalarFunction {} +impl std::hash::Hash for ScalarFunction { + fn hash(&self, state: &mut H) { + self.name().hash(state); + self.inputs.hash(state); + } +} + +impl Serialize for ScalarFunction { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut struct_ = serializer.serialize_struct("ScalarFunction", 2)?; + struct_.serialize_field("name", &self.name())?; + struct_.serialize_field("inputs", &self.inputs)?; + + struct_.end() + } +} + +impl<'de> Deserialize<'de> for ScalarFunction { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct ScalarFunctionVisitor; + + impl<'de> serde::de::Visitor<'de> for ScalarFunctionVisitor { + type Value = ScalarFunction; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct ScalarFunction") + } + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: serde::de::SeqAccess<'de>, + { + let name: String = seq + .next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &"2"))?; + + match REGISTRY.get(&name) { + None => Err(serde::de::Error::unknown_field(&name, &[])), + Some(udf) => { + let inputs = seq + .next_element::>()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &"2"))?; + + Ok(ScalarFunction { + udf: udf.clone(), + inputs, + }) + } + } + } + fn visit_map(self, mut map: A) -> std::result::Result + where + A: serde::de::MapAccess<'de>, + { + let name = map + .next_key::()? + .ok_or_else(|| serde::de::Error::missing_field("name"))?; + + match REGISTRY.get(&name) { + None => Err(serde::de::Error::unknown_field(&name, &[])), + Some(udf) => { + let inputs = map.next_value::>()?; + Ok(ScalarFunction { + udf: udf.clone(), + inputs, + }) + } + } + } + } + deserializer.deserialize_struct( + "ScalarFunction", + &["name", "inputs"], + ScalarFunctionVisitor, + ) + } +} + +impl Display for ScalarFunction { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "{}(", self.name())?; + for (i, input) in self.inputs.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{input}")?; + } + write!(f, ")")?; + Ok(()) + } +} diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index cbe76c167e..1fa295a1dd 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -27,6 +27,7 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::BinaryOp { .. } | Expr::Cast(..) | Expr::Function { .. } + | Expr::ScalarFunction { .. } | Expr::Not(..) | Expr::IsNull(..) | Expr::NotNull(..) diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 2e82d2c436..716cc79602 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -908,13 +908,13 @@ impl PyExpr { } let cast_seed = seed as u32; use crate::functions::minhash::minhash; - Ok(minhash( + let expr = minhash( self.into(), num_hashes as usize, ngram_size as usize, cast_seed, - ) - .into()) + ); + Ok(expr.into()) } } diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 07881ccf1c..1a5cca0712 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -352,6 +352,22 @@ fn replace_column_with_semantic_id( ) } } + Expr::ScalarFunction(func) => { + let mut func = func.clone(); + let transforms = func + .inputs + .iter() + .map(|e| { + replace_column_with_semantic_id(e.clone(), subexprs_to_replace, schema) + }) + .collect::>(); + if transforms.iter().all(|e| e.is_no()) { + Transformed::No(e) + } else { + func.inputs = transforms.iter().map(|t| t.unwrap()).cloned().collect(); + Transformed::Yes(Expr::ScalarFunction(func).into()) + } + } } } } diff --git a/src/daft-plan/src/partitioning.rs b/src/daft-plan/src/partitioning.rs index 41b9f458f6..21f59a98ad 100644 --- a/src/daft-plan/src/partitioning.rs +++ b/src/daft-plan/src/partitioning.rs @@ -256,6 +256,16 @@ fn translate_clustering_spec_expr( } .into()) } + Expr::ScalarFunction(func) => { + let mut func = func.clone(); + let new_inputs = func + .inputs + .iter() + .map(|e| translate_clustering_spec_expr(e, old_colname_to_new_colname)) + .collect::, _>>()?; + func.inputs = new_inputs; + Ok(Expr::ScalarFunction(func).into()) + } Expr::Not(child) => { let newchild = translate_clustering_spec_expr(child, old_colname_to_new_colname)?; Ok(newchild.not()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 0c5dfb9ae6..531fc28090 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -400,6 +400,14 @@ impl Table { .collect::>>()?; func.evaluate(evaluated_inputs.as_slice(), func) } + ScalarFunction(func) => { + let evaluated_inputs = func + .inputs + .iter() + .map(|e| self.eval_expression(e)) + .collect::>>()?; + func.udf.evaluate(evaluated_inputs.as_slice()) + } Literal(lit_value) => Ok(lit_value.to_series()), IfElse { if_true,