Skip to content

Commit

Permalink
[FEAT]: sql IN operator (#3086)
Browse files Browse the repository at this point in the history
closes #3085
  • Loading branch information
universalmind303 authored Oct 23, 2024
1 parent c69ee3f commit 4ec76ce
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 25 deletions.
62 changes: 38 additions & 24 deletions src/daft-dsl/src/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,13 +442,18 @@ pub fn null_lit() -> ExprRef {
pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {
use daft_core::{datatypes::*, series::IntoSeries};

let dtype = values[0].get_type();
let first_non_null = values.iter().find(|v| !matches!(v, LiteralValue::Null));
let Some(first_non_null) = first_non_null else {
return Ok(Series::full_null("literal", &DataType::Null, values.len()));
};

let dtype = first_non_null.get_type();

// make sure all dtypes are the same
if !values
.windows(2)
.all(|w| w[0].get_type() == w[1].get_type())
{
// make sure all dtypes are the same, or null
if !values.windows(2).all(|w| {
w[0].get_type() == w[1].get_type()
|| matches!(w, [LiteralValue::Null, _] | [_, LiteralValue::Null])
}) {
return Err(DaftError::ValueError(format!(
"All literals must have the same data type. Found: {:?}",
values.iter().map(|lit| lit.get_type()).collect::<Vec<_>>()
Expand All @@ -458,15 +463,17 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {
macro_rules! unwrap_unchecked {
($expr:expr, $variant:ident) => {
match $expr {
LiteralValue::$variant(val, ..) => *val,
LiteralValue::$variant(val, ..) => Some(*val),
LiteralValue::Null => None,
_ => unreachable!("datatype is already checked"),
}
};
}
macro_rules! unwrap_unchecked_ref {
($expr:expr, $variant:ident) => {
match $expr {
LiteralValue::$variant(val) => val.clone(),
LiteralValue::$variant(val) => Some(val.clone()),
LiteralValue::Null => None,
_ => unreachable!("datatype is already checked"),
}
};
Expand All @@ -476,63 +483,64 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> {
DataType::Null => NullArray::full_null("literal", &dtype, values.len()).into_series(),
DataType::Boolean => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Boolean));
BooleanArray::from_values("literal", data).into_series()
BooleanArray::from_iter("literal", data).into_series()
}
DataType::Utf8 => {
let data = values.iter().map(|lit| unwrap_unchecked_ref!(lit, Utf8));
Utf8Array::from_values("literal", data).into_series()
Utf8Array::from_iter("literal", data).into_series()
}
DataType::Binary => {
let data = values.iter().map(|lit| unwrap_unchecked_ref!(lit, Binary));
BinaryArray::from_values("literal", data).into_series()
BinaryArray::from_iter("literal", data).into_series()
}
DataType::Int32 => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int32));
Int32Array::from_values("literal", data).into_series()
Int32Array::from_iter("literal", data).into_series()
}
DataType::UInt32 => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt32));
UInt32Array::from_values("literal", data).into_series()
UInt32Array::from_iter("literal", data).into_series()
}
DataType::Int64 => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int64));
Int64Array::from_values("literal", data).into_series()
Int64Array::from_iter("literal", data).into_series()
}
DataType::UInt64 => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt64));
UInt64Array::from_values("literal", data).into_series()
UInt64Array::from_iter("literal", data).into_series()
}
DataType::Interval => {
let data = values.iter().map(|lit| match lit {
LiteralValue::Interval(iv) => (iv.months, iv.days, iv.nanoseconds),
LiteralValue::Interval(iv) => Some((iv.months, iv.days, iv.nanoseconds)),
LiteralValue::Null => None,
_ => unreachable!("datatype is already checked"),
});
IntervalArray::from_values("literal", data).into_series()
IntervalArray::from_iter("literal", data).into_series()
}
dtype @ DataType::Timestamp(_, _) => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Timestamp));
let physical = Int64Array::from_values("literal", data);
let physical = Int64Array::from_iter("literal", data);
TimestampArray::new(Field::new("literal", dtype), physical).into_series()
}
dtype @ DataType::Date => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Date));
let physical = Int32Array::from_values("literal", data);
let physical = Int32Array::from_iter("literal", data);
DateArray::new(Field::new("literal", dtype), physical).into_series()
}
dtype @ DataType::Time(_) => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Time));
let physical = Int64Array::from_values("literal", data);
let physical = Int64Array::from_iter("literal", data);
TimeArray::new(Field::new("literal", dtype), physical).into_series()
}
DataType::Float64 => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Float64));

Float64Array::from_values("literal", data).into_series()
Float64Array::from_iter("literal", data).into_series()
}
dtype @ DataType::Decimal128 { .. } => {
let data = values.iter().map(|lit| unwrap_unchecked!(lit, Decimal));

let physical = Int128Array::from_values("literal", data);
let physical = Int128Array::from_iter("literal", data);
Decimal128Array::new(Field::new("literal", dtype), physical).into_series()
}
_ => {
Expand Down Expand Up @@ -571,8 +579,14 @@ mod test {
LiteralValue::UInt64(2),
LiteralValue::UInt64(3),
];
let actual = super::literals_to_series(&values);
assert!(actual.is_err());
let expected = vec![None, Some(2), Some(3)];
let expected = UInt64Array::from_iter("literal", expected.into_iter());
let expected = expected.into_series();
let actual = super::literals_to_series(&values).unwrap();
// Series.eq returns false for nulls
for (expected, actual) in expected.u64().iter().zip(actual.u64().iter()) {
assert_eq!(expected, actual);
}
}

#[test]
Expand Down
31 changes: 30 additions & 1 deletion src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,14 @@ impl SQLPlanner {
}
})
}

fn plan_lit(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<LiteralValue> {
if let sqlparser::ast::Expr::Value(v) = expr {
self.value_to_lit(v)
} else {
invalid_operation_err!("Only string, number, boolean and null literals are supported");
}
}
pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<ExprRef> {
use sqlparser::ast::Expr as SQLExpr;
match expr {
Expand Down Expand Up @@ -695,7 +703,28 @@ impl SQLPlanner {
SQLExpr::IsNotDistinctFrom(_, _) => {
unsupported_sql_err!("IS NOT DISTINCT FROM")
}
SQLExpr::InList { .. } => unsupported_sql_err!("IN LIST"),
SQLExpr::InList {
expr,
list,
negated,
} => {
let expr = self.plan_expr(expr)?;
let list = list
.iter()
.map(|e| self.plan_lit(e))
.collect::<SQLPlannerResult<Vec<_>>>()?;
// We should really have a better way to use `is_in` instead of all of this extra wrapping of the values
let series = literals_to_series(&list)?;
let series_lit = LiteralValue::Series(series);
let series_expr = Expr::Literal(series_lit);
let series_expr_arc = Arc::new(series_expr);
let expr = expr.is_in(series_expr_arc);
if *negated {
Ok(expr.not())
} else {
Ok(expr)
}
}
SQLExpr::InSubquery { .. } => {
unsupported_sql_err!("IN subquery")
}
Expand Down
43 changes: 43 additions & 0 deletions tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,46 @@ def test_between():

expected = df.filter(col("integers").between(1, 4)).collect().to_pydict()
assert actual == expected


def test_is_in():
df = daft.from_pydict({"idx": [1, 2, 3], "val": ["foo", "bar", "baz"]})
expected = df.filter(col("val").is_in(["bar", "foo"])).collect().to_pydict()
actual = daft.sql("select * from df where val in ('bar','foo')").collect().to_pydict()
assert actual == expected

# test negated too
expected = df.filter(~col("val").is_in(["bar", "foo"])).collect().to_pydict()
actual = daft.sql("select * from df where val not in ('bar','foo')").collect().to_pydict()
assert actual == expected


def test_is_in_edge_cases():
df = daft.from_pydict(
{
"nums": [1, 2, 3, None, 4, 5],
"strs": ["a", "b", None, "c", "d", "e"],
}
)

# Test with NULL values in the column
actual = daft.sql("SELECT * FROM df WHERE strs IN ('a', 'b')").collect().to_pydict()
expected = df.filter(col("strs").is_in(["a", "b"])).collect().to_pydict()
assert actual == expected

# Test with empty IN list
with pytest.raises(Exception, match="Expected: an expression"):
daft.sql("SELECT * FROM df WHERE nums IN ()").collect()

# Test with numbers and NULL in IN list
actual = daft.sql("SELECT * FROM df WHERE nums IN (1, NULL, 3)").collect().to_pydict()
expected = df.filter(col("nums").is_in([1, None, 3])).collect().to_pydict()
assert actual == expected

# Test with single value
actual = daft.sql("SELECT * FROM df WHERE nums IN (1)").collect().to_pydict()
expected = df.filter(col("nums").is_in([1])).collect().to_pydict()

# Test with mixed types in the IN list
with pytest.raises(Exception, match="All literals must have the same data type"):
daft.sql("SELECT * FROM df WHERE nums IN (1, '2', 3.0)").collect().to_pydict()

0 comments on commit 4ec76ce

Please sign in to comment.