Skip to content

Commit

Permalink
feat(connect): add more unresolved functions (#3618)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka authored Dec 19, 2024
1 parent f6002f9 commit b87e0a3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
47 changes: 32 additions & 15 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,20 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
}

match function_name.as_str() {
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt)
.wrap_err("Failed to handle < function"),
">" => handle_binary_op(arguments, daft_dsl::Operator::Gt)
.wrap_err("Failed to handle > function"),
"<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq)
.wrap_err("Failed to handle <= function"),
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
.wrap_err("Failed to handle >= function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus),
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt),
"<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq),
"==" => handle_binary_op(arguments, daft_dsl::Operator::Eq),
">" => handle_binary_op(arguments, daft_dsl::Operator::Gt),
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq),
"count" => handle_count(arguments),
"isnotnull" => handle_isnotnull(arguments),
"isnull" => handle_isnull(arguments),
"not" => not(arguments),
"sum" => handle_sum(arguments),
n => bail!("Unresolved function {n:?} not yet supported"),
}
.wrap_err_with(|| format!("Failed to handle function {function_name:?}"))
}

pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
Expand All @@ -53,6 +51,25 @@ pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::E
Ok(arg.sum())
}

/// If the arguments are exactly one, return it. Otherwise, return an error.
pub fn to_single(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting to a fixed-size array is the infallible-length check:
    // `TryFrom` succeeds iff the Vec holds exactly one element, and on
    // failure hands the original Vec back so we can include it in the error.
    match <[daft_dsl::ExprRef; 1]>::try_from(arguments) {
        Ok([single]) => Ok(single),
        Err(original) => bail!("requires exactly one argument; got {original:?}"),
    }
}

/// Logical NOT over a single-argument function call: unwrap the sole
/// argument and negate it. Errors if the argument count is not exactly one.
pub fn not(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    to_single(arguments).map(|expr| expr.not())
}

pub fn handle_binary_op(
arguments: Vec<daft_dsl::ExprRef>,
op: daft_dsl::Operator,
Expand Down
48 changes: 48 additions & 0 deletions tests/connect/test_unresolved.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from pyspark.sql import functions as F


def test_numeric_equals(spark_session):
    """Test numeric equality comparison with NULL handling."""
    rows = [(1, 10), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "value"])

    collected = frame.withColumn("equals_20", F.col("value") == F.lit(20)).collect()

    # Equality is three-valued: a non-matching number yields False,
    # while comparing NULL to anything yields NULL.
    assert collected[0].equals_20 is False  # 10 == 20
    assert collected[1].equals_20 is None  # NULL == 20


def test_string_equals(spark_session):
    """Test string equality comparison with NULL handling."""
    rows = [(1, "apple"), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "text"])

    collected = frame.withColumn("equals_banana", F.col("text") == F.lit("banana")).collect()

    # Mismatched strings compare False; NULL compared to anything is NULL.
    assert collected[0].equals_banana is False  # apple == banana
    assert collected[1].equals_banana is None  # NULL == banana


@pytest.mark.skip(reason="We believe null-safe equals are not yet implemented")
def test_null_safe_equals(spark_session):
    """Test null-safe equality comparison."""
    rows = [(1, 10), (2, None)]
    frame = spark_session.createDataFrame(rows, ["id", "value"])

    collected = frame.withColumn("null_safe_equals", F.col("value").eqNullSafe(F.lit(10))).collect()

    # eqNullSafe (<=>) never yields NULL: NULL <=> non-NULL is False.
    assert collected[0].null_safe_equals is True  # 10 <=> 10
    assert collected[1].null_safe_equals is False  # NULL <=> 10


def test_not(spark_session):
    """Test logical NOT operation with NULL handling."""
    rows = [(True,), (False,), (None,)]
    frame = spark_session.createDataFrame(rows, ["value"])

    collected = frame.withColumn("not_value", ~F.col("value")).collect()

    # NOT inverts booleans and propagates NULL (three-valued logic).
    assert collected[0].not_value is False  # NOT True
    assert collected[1].not_value is True  # NOT False
    assert collected[2].not_value is None  # NOT NULL

0 comments on commit b87e0a3

Please sign in to comment.