Skip to content

Commit

Permalink
[FEAT]: allow for implicit coercion between str & date (#3337)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Nov 20, 2024
1 parent 066cde1 commit b89ee3d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/daft-core/src/datatypes/infer_datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ impl<'a> InferDataType<'a> {

Ok((DataType::Boolean, Some(d_type.clone()), d_type))
}

(DataType::Utf8, DataType::Date) | (DataType::Date, DataType::Utf8) => {
// Date is logical, so we cast to intermediate type (date), then compare on the physical type (i32)
Ok((DataType::Boolean, Some(DataType::Date), DataType::Int32))
}
(s, o) if s.is_physical() && o.is_physical() => {
Ok((DataType::Boolean, None, try_physical_supertype(s, o)?))
}
Expand Down
19 changes: 18 additions & 1 deletion tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import itertools
import tempfile
from datetime import datetime, timedelta, timezone
from datetime import date, datetime, timedelta, timezone

import pyarrow as pa
import pytest
Expand Down Expand Up @@ -465,3 +465,20 @@ def test_intervals(op, expected):
expected = {"datetimes": expected}

assert actual == expected


@pytest.mark.parametrize(
    "value",
    [
        date(2020, 1, 1),  # explicit date
        "2020-01-01",  # implicit coercion
    ],
)
def test_date_comparison(value):
    """Check that filtering a date column by ``value`` keeps only 2020-01-01.

    ``value`` is supplied either as a ``datetime.date`` or as the equivalent
    ``"%Y-%m-%d"`` string; both spellings must select the same single row.
    """
    # Three consecutive days as strings, parsed into a "date" column.
    frame = daft.from_pydict(
        {"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]}
    ).with_column("date", col("date_str").str.to_date("%Y-%m-%d"))

    # Equality against either a date object or a string literal.
    matching = frame.filter(col("date") == value)
    actual = matching.select("date").to_pydict()

    assert actual == {"date": [date(2020, 1, 1)]}
8 changes: 8 additions & 0 deletions tests/sql/test_temporal_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ def test_extract():
""").collect()

assert actual.to_pydict() == expected.to_pydict()


def test_date_comparison():
    """Check that a SQL date-vs-string filter matches the DataFrame API.

    The same ``date == '2020-01-01'`` predicate is evaluated once through
    ``daft.sql`` and once through ``DataFrame.filter``; the two results must
    be equal.
    """
    # NOTE(review): the SQL text below refers to the table as `date_df`;
    # presumably daft.sql resolves it from this local variable, so keep the
    # variable name and the query in sync -- confirm against daft.sql docs.
    date_df = daft.from_pydict(
        {"date_str": ["2020-01-01", "2020-01-02", "2020-01-03"]}
    ).with_column("date", daft.col("date_str").str.to_date("%Y-%m-%d"))

    # SQL path.
    actual = daft.sql("select date from date_df where date == '2020-01-01'").to_pydict()

    # DataFrame-API path used as the reference result.
    expected = (
        date_df.filter(daft.col("date") == "2020-01-01")
        .select("date")
        .to_pydict()
    )

    assert actual == expected

0 comments on commit b89ee3d

Please sign in to comment.