diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index c9a2b3f2ac..7885c7fd83 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -128,6 +128,11 @@ impl<'a> InferDataType<'a> { Ok((DataType::Boolean, Some(d_type.clone()), d_type)) } + + (DataType::Utf8, DataType::Date) | (DataType::Date, DataType::Utf8) => { + // Date is logical, so we cast to intermediate type (date), then compare on the physical type (i32) + Ok((DataType::Boolean, Some(DataType::Date), DataType::Int32)) + } (s, o) if s.is_physical() && o.is_physical() => { Ok((DataType::Boolean, None, try_physical_supertype(s, o)?)) } diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 1a26b8cffc..52b70b4a46 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -2,7 +2,7 @@ import itertools import tempfile -from datetime import datetime, timedelta, timezone +from datetime import date, datetime, timedelta, timezone import pyarrow as pa import pytest @@ -465,3 +465,20 @@ def test_intervals(op, expected): expected = {"datetimes": expected} assert actual == expected + + +@pytest.mark.parametrize( + "value", + [ + date(2020, 1, 1), # explicit date + "2020-01-01", # implicit coercion + ], +) +def test_date_comparison(value): + date_df = daft.from_pydict({"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]}) + date_df = date_df.with_column("date", col("date_str").str.to_date("%Y-%m-%d")) + actual = date_df.filter(col("date") == value).select("date").to_pydict() + + expected = {"date": [date(2020, 1, 1)]} + + assert actual == expected diff --git a/tests/sql/test_temporal_exprs.py b/tests/sql/test_temporal_exprs.py index d475850839..9067e6b3d1 100644 --- a/tests/sql/test_temporal_exprs.py +++ b/tests/sql/test_temporal_exprs.py @@ -88,3 +88,11 @@ def test_extract(): """).collect() assert actual.to_pydict() == expected.to_pydict() + + +def test_date_comparison(): + date_df = daft.from_pydict({"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]}) + date_df = date_df.with_column("date", daft.col("date_str").str.to_date("%Y-%m-%d")) + expected = date_df.filter(daft.col("date") == "2020-01-01").select("date").to_pydict() + actual = daft.sql("select date from date_df where date == '2020-01-01'").to_pydict() + assert actual == expected