diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 9cf99cd79e..47b4090be5 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -442,13 +442,18 @@ pub fn null_lit() -> ExprRef { pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> { use daft_core::{datatypes::*, series::IntoSeries}; - let dtype = values[0].get_type(); + let first_non_null = values.iter().find(|v| !matches!(v, LiteralValue::Null)); + let Some(first_non_null) = first_non_null else { + return Ok(Series::full_null("literal", &DataType::Null, values.len())); + }; + + let dtype = first_non_null.get_type(); - // make sure all dtypes are the same - if !values - .windows(2) - .all(|w| w[0].get_type() == w[1].get_type()) - { + // make sure every non-null literal has the same dtype as the first non-null one + if !values + .iter() + .all(|v| matches!(v, LiteralValue::Null) || v.get_type() == dtype) + { return Err(DaftError::ValueError(format!( "All literals must have the same data type. Found: {:?}", values.iter().map(|lit| lit.get_type()).collect::<Vec<_>>() @@ -458,7 +463,8 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> { macro_rules! unwrap_unchecked { ($expr:expr, $variant:ident) => { match $expr { - LiteralValue::$variant(val, ..) => *val, + LiteralValue::$variant(val, ..) => Some(*val), + LiteralValue::Null => None, _ => unreachable!("datatype is already checked"), } }; @@ -466,7 +472,8 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> { macro_rules! 
unwrap_unchecked_ref { ($expr:expr, $variant:ident) => { match $expr { - LiteralValue::$variant(val) => val.clone(), + LiteralValue::$variant(val) => Some(val.clone()), + LiteralValue::Null => None, _ => unreachable!("datatype is already checked"), } }; @@ -476,63 +483,64 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> { DataType::Null => NullArray::full_null("literal", &dtype, values.len()).into_series(), DataType::Boolean => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Boolean)); - BooleanArray::from_values("literal", data).into_series() + BooleanArray::from_iter("literal", data).into_series() } DataType::Utf8 => { let data = values.iter().map(|lit| unwrap_unchecked_ref!(lit, Utf8)); - Utf8Array::from_values("literal", data).into_series() + Utf8Array::from_iter("literal", data).into_series() } DataType::Binary => { let data = values.iter().map(|lit| unwrap_unchecked_ref!(lit, Binary)); - BinaryArray::from_values("literal", data).into_series() + BinaryArray::from_iter("literal", data).into_series() } DataType::Int32 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int32)); - Int32Array::from_values("literal", data).into_series() + Int32Array::from_iter("literal", data).into_series() } DataType::UInt32 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt32)); - UInt32Array::from_values("literal", data).into_series() + UInt32Array::from_iter("literal", data).into_series() } DataType::Int64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int64)); - Int64Array::from_values("literal", data).into_series() + Int64Array::from_iter("literal", data).into_series() } DataType::UInt64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt64)); - UInt64Array::from_values("literal", data).into_series() + UInt64Array::from_iter("literal", data).into_series() } DataType::Interval => { let data = values.iter().map(|lit| match lit { - LiteralValue::Interval(iv) => (iv.months, iv.days, 
iv.nanoseconds), + LiteralValue::Interval(iv) => Some((iv.months, iv.days, iv.nanoseconds)), + LiteralValue::Null => None, _ => unreachable!("datatype is already checked"), }); - IntervalArray::from_values("literal", data).into_series() + IntervalArray::from_iter("literal", data).into_series() } dtype @ DataType::Timestamp(_, _) => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Timestamp)); - let physical = Int64Array::from_values("literal", data); + let physical = Int64Array::from_iter("literal", data); TimestampArray::new(Field::new("literal", dtype), physical).into_series() } dtype @ DataType::Date => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Date)); - let physical = Int32Array::from_values("literal", data); + let physical = Int32Array::from_iter("literal", data); DateArray::new(Field::new("literal", dtype), physical).into_series() } dtype @ DataType::Time(_) => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Time)); - let physical = Int64Array::from_values("literal", data); + let physical = Int64Array::from_iter("literal", data); TimeArray::new(Field::new("literal", dtype), physical).into_series() } DataType::Float64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Float64)); - Float64Array::from_values("literal", data).into_series() + Float64Array::from_iter("literal", data).into_series() } dtype @ DataType::Decimal128 { .. 
} => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Decimal)); - let physical = Int128Array::from_values("literal", data); + let physical = Int128Array::from_iter("literal", data); Decimal128Array::new(Field::new("literal", dtype), physical).into_series() } _ => { @@ -571,8 +579,14 @@ mod test { LiteralValue::UInt64(2), LiteralValue::UInt64(3), ]; - let actual = super::literals_to_series(&values); - assert!(actual.is_err()); + let expected = vec![None, Some(2), Some(3)]; + let expected = UInt64Array::from_iter("literal", expected.into_iter()); + let expected = expected.into_series(); + let actual = super::literals_to_series(&values).unwrap(); + // Series.eq returns false for nulls + for (expected, actual) in expected.u64().iter().zip(actual.u64().iter()) { + assert_eq!(expected, actual); + } } #[test] diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 55823e5843..d14bc258da 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -650,6 +650,14 @@ impl SQLPlanner { } }) } + + fn plan_lit(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<LiteralValue> { + if let sqlparser::ast::Expr::Value(v) = expr { + self.value_to_lit(v) + } else { + invalid_operation_err!("Only string, number, boolean and null literals are supported"); + } + } pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<ExprRef> { use sqlparser::ast::Expr as SQLExpr; match expr { @@ -695,7 +703,28 @@ impl SQLPlanner { SQLExpr::IsNotDistinctFrom(_, _) => { unsupported_sql_err!("IS NOT DISTINCT FROM") } - SQLExpr::InList { .. 
} => unsupported_sql_err!("IN LIST"), + SQLExpr::InList { + expr, + list, + negated, + } => { + let expr = self.plan_expr(expr)?; + let list = list + .iter() + .map(|e| self.plan_lit(e)) + .collect::<SQLPlannerResult<Vec<_>>>()?; + // We should really have a better way to use `is_in` instead of all of this extra wrapping of the values + let series = literals_to_series(&list)?; + let series_lit = LiteralValue::Series(series); + let series_expr = Expr::Literal(series_lit); + let series_expr_arc = Arc::new(series_expr); + let expr = expr.is_in(series_expr_arc); + if *negated { + Ok(expr.not()) + } else { + Ok(expr) + } + } SQLExpr::InSubquery { .. } => { unsupported_sql_err!("IN subquery") } diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py index 9ff1fa066f..2adfb1db31 100644 --- a/tests/sql/test_exprs.py +++ b/tests/sql/test_exprs.py @@ -92,3 +92,47 @@ def test_between(): expected = df.filter(col("integers").between(1, 4)).collect().to_pydict() assert actual == expected + + +def test_is_in(): + df = daft.from_pydict({"idx": [1, 2, 3], "val": ["foo", "bar", "baz"]}) + expected = df.filter(col("val").is_in(["bar", "foo"])).collect().to_pydict() + actual = daft.sql("select * from df where val in ('bar','foo')").collect().to_pydict() + assert actual == expected + + # test negated too + expected = df.filter(~col("val").is_in(["bar", "foo"])).collect().to_pydict() + actual = daft.sql("select * from df where val not in ('bar','foo')").collect().to_pydict() + assert actual == expected + + +def test_is_in_edge_cases(): + df = daft.from_pydict( + { + "nums": [1, 2, 3, None, 4, 5], + "strs": ["a", "b", None, "c", "d", "e"], + } + ) + + # Test with NULL values in the column + actual = daft.sql("SELECT * FROM df WHERE strs IN ('a', 'b')").collect().to_pydict() + expected = df.filter(col("strs").is_in(["a", "b"])).collect().to_pydict() + assert actual == expected + + # Test with empty IN list + with pytest.raises(Exception, match="Expected: an expression"): + daft.sql("SELECT * FROM df WHERE 
nums IN ()").collect() + + # Test with numbers and NULL in IN list + actual = daft.sql("SELECT * FROM df WHERE nums IN (1, NULL, 3)").collect().to_pydict() + expected = df.filter(col("nums").is_in([1, None, 3])).collect().to_pydict() + assert actual == expected + + # Test with single value + actual = daft.sql("SELECT * FROM df WHERE nums IN (1)").collect().to_pydict() + expected = df.filter(col("nums").is_in([1])).collect().to_pydict() + assert actual == expected + + # Test with mixed types in the IN list + with pytest.raises(Exception, match="All literals must have the same data type"): + daft.sql("SELECT * FROM df WHERE nums IN (1, '2', 3.0)").collect().to_pydict()