diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs
index b0e0834a5a..3353bde863 100644
--- a/src/daft-core/src/datatypes/binary_ops.rs
+++ b/src/daft-core/src/datatypes/binary_ops.rs
@@ -40,12 +40,6 @@ impl DataType {
             (s, o) if s.is_physical() && o.is_physical() => {
                 Ok((Boolean, None, try_physical_supertype(s, o)?))
             }
-            // To maintain existing behaviour. TODO: cleanup
-            (Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => Ok((
-                Boolean,
-                None,
-                try_physical_supertype(&Date.to_physical(), o)?,
-            )),
             (Timestamp(..) | Date, Timestamp(..) | Date) => {
                 let intermediate_type = try_get_supertype(self, other)?;
                 let pt = intermediate_type.to_physical();
diff --git a/tests/expressions/typing/conftest.py b/tests/expressions/typing/conftest.py
index 268d115925..7347182c32 100644
--- a/tests/expressions/typing/conftest.py
+++ b/tests/expressions/typing/conftest.py
@@ -4,6 +4,8 @@
 import itertools
 import sys
 
+import pytz
+
 if sys.version_info < (3, 8):
     pass
 else:
@@ -33,12 +35,6 @@
     (DataType.bool(), pa.array([True, False, None], type=pa.bool_())),
     (DataType.null(), pa.array([None, None, None], type=pa.null())),
     (DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())),
-    (DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
-    # TODO(jay): Some of the fixtures are broken/become very complicated when testing against timestamps
-    # (
-    #     DataType.timestamp(TimeUnit.ms()),
-    #     pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")),
-    # ),
 ]
 
 ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2))
@@ -59,10 +55,63 @@ def binary_data_fixture(request) -> tuple[Series, Series]:
     return (s1, s2)
 
 
+_TIMESTAMP_UNITS = ["ns", "us", "ms"]
+_TIMESTAMP_TIMEZONES = ["US/Eastern", "Africa/Accra"]
+_NAIVE_DATETIMES = [datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2)]
+
+# (DataType, arrow array) pairs: date32 plus tz-naive and tz-aware timestamps in every unit above.
+ALL_TEMPORAL_DTYPES = [
+    (DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())),
+    *[
+        (DataType.timestamp(unit), pa.array([*_NAIVE_DATETIMES, None], type=pa.timestamp(unit)))
+        for unit in _TIMESTAMP_UNITS
+    ],
+    *[
+        (
+            DataType.timestamp(unit, tz),
+            pa.array(
+                # localize() pins the wall-clock value to `tz`; astimezone() on a naive datetime would
+                # first interpret it in the host's local timezone, making the data machine-dependent.
+                [*(pytz.timezone(tz).localize(dt) for dt in _NAIVE_DATETIMES), None],
+                type=pa.timestamp(unit, tz),
+            ),
+        )
+        for tz in _TIMESTAMP_TIMEZONES
+        for unit in _TIMESTAMP_UNITS
+    ],
+]
+
+# Every ordered pair of temporal dtypes, minus pairs mixing a tz-naive with a tz-aware timestamp.
+ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [
+    ((dt1, data1), (dt2, data2))
+    for (dt1, data1), (dt2, data2) in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2)
+    if not (
+        pa.types.is_timestamp(data1.type)
+        and pa.types.is_timestamp(data2.type)
+        and (data1.type.tz is None) != (data2.type.tz is None)
+    )
+]
+
+
+@pytest.fixture(
+    scope="module",
+    params=ALL_TEMPORAL_DATATYPES_BINARY_PAIRS,
+    ids=[f"{dt1}-{dt2}" for (dt1, _), (dt2, _) in ALL_TEMPORAL_DATATYPES_BINARY_PAIRS],
+)
+def binary_temporal_data_fixture(request) -> tuple[Series, Series]:
+    """Returns a (lhs, rhs) pair of Series for one entry of ALL_TEMPORAL_DATATYPES_BINARY_PAIRS"""
+    (dt1, data1), (dt2, data2) = request.param
+    s1 = Series.from_arrow(data1, name="lhs")
+    assert s1.datatype() == dt1
+    s2 = Series.from_arrow(data2, name="rhs")
+    assert s2.datatype() == dt2
+    return (s1, s2)
+
+
 @pytest.fixture(
     scope="module",
-    params=ALL_DTYPES,
-    ids=[f"{dt}" for (dt, _) in ALL_DTYPES],
+    params=ALL_DTYPES + ALL_TEMPORAL_DTYPES,
+    ids=[f"{dt}" for (dt, _) in ALL_DTYPES + ALL_TEMPORAL_DTYPES],
 )
 def unary_data_fixture(request) -> Series:
     """Returns unary permutation of Series' of all DataType pairs"""
diff --git a/tests/expressions/typing/test_compare.py
b/tests/expressions/typing/test_compare.py
index 9ab988e40c..8c174cd0bb 100644
--- a/tests/expressions/typing/test_compare.py
+++ b/tests/expressions/typing/test_compare.py
@@ -26,3 +26,21 @@ def test_comparable(binary_data_fixture, op):
         run_kernel=lambda: op(lhs, rhs),
         resolvable=comparable_type_validation(lhs.datatype(), rhs.datatype()),
     )
+
+
+@pytest.mark.parametrize("op", [ops.eq, ops.ne, ops.lt, ops.le, ops.gt, ops.ge])
+def test_temporal_comparable(binary_temporal_data_fixture, op):
+    """Checks type resolution against runtime behavior for comparisons between temporal Series"""
+    left, right = binary_temporal_data_fixture
+    comparison_expr = op(col(left.name()), col(right.name()))
+    is_resolvable = comparable_type_validation(left.datatype(), right.datatype())
+
+    def compare_series():
+        return op(left, right)
+
+    assert_typing_resolve_vs_runtime_behavior(
+        data=binary_temporal_data_fixture,
+        expr=comparison_expr,
+        run_kernel=compare_series,
+        resolvable=is_resolvable,
+    )