diff --git a/daft/datatype.py b/daft/datatype.py index 2285a96f55..d4e1456399 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -194,13 +194,17 @@ def date(cls) -> DataType: return cls._from_pydatatype(PyDataType.date()) @classmethod - def timestamp(cls, timeunit: TimeUnit, timezone: str | None = None) -> DataType: + def timestamp(cls, timeunit: TimeUnit | str, timezone: str | None = None) -> DataType: """Timestamp DataType.""" + if isinstance(timeunit, str): + timeunit = TimeUnit.from_str(timeunit) return cls._from_pydatatype(PyDataType.timestamp(timeunit._timeunit, timezone)) @classmethod - def duration(cls, timeunit: TimeUnit) -> DataType: + def duration(cls, timeunit: TimeUnit | str) -> DataType: """Duration DataType.""" + if isinstance(timeunit, str): + timeunit = TimeUnit.from_str(timeunit) return cls._from_pydatatype(PyDataType.duration(timeunit._timeunit)) @classmethod diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index f9cc967bf7..3353bde863 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -2,7 +2,7 @@ use std::ops::{Add, Div, Mul, Rem, Sub}; use common_error::{DaftError, DaftResult}; -use crate::impl_binary_trait_by_reference; +use crate::{impl_binary_trait_by_reference, utils::supertype::try_get_supertype}; use super::DataType; @@ -24,27 +24,37 @@ impl DataType { )) }) } - pub fn comparison_op(&self, other: &Self) -> DaftResult<(DataType, DataType)> { + pub fn comparison_op( + &self, + other: &Self, + ) -> DaftResult<(DataType, Option, DataType)> { // Whether a comparison op is supported between the two types. // Returns: // - the output type, + // - an optional intermediate type // - the type at which the comparison should be performed. - use DataType::*; - match (self, other) { - (s, o) if s == o => Ok(s.to_physical()), - (s, o) if s.is_physical() && o.is_physical() => { - try_physical_supertype(s, o).map_err(|_| ()) - } - // To maintain existing behaviour. TODO: cleanup - (Date, o) | (o, Date) if o.is_physical() && o.clone() != Boolean => { - try_physical_supertype(&Date.to_physical(), o).map_err(|_| ()) + let evaluator = || { + use DataType::*; + match (self, other) { + (s, o) if s == o => Ok((Boolean, None, s.to_physical())), + (s, o) if s.is_physical() && o.is_physical() => { + Ok((Boolean, None, try_physical_supertype(s, o)?)) + } + (Timestamp(..) | Date, Timestamp(..) | Date) => { + let intermediate_type = try_get_supertype(self, other)?; + let pt = intermediate_type.to_physical(); + Ok((Boolean, Some(intermediate_type), pt)) + } + _ => Err(DaftError::TypeError(format!( + "Cannot perform comparison on types: {}, {}", + self, other + ))), } - _ => Err(()), - } - .map(|comp_type| (Boolean, comp_type)) - .map_err(|()| { + }; + + evaluator().map_err(|err| { DaftError::TypeError(format!( - "Cannot perform comparison on types: {}, {}", + "Cannot perform comparison on types: {}, {}\nDetails:\n{err}", self, other )) }) diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index ee398fc12e..b2a652d529 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -114,17 +114,24 @@ macro_rules! physical_logic_op { macro_rules! physical_compare_op { ($self:expr, $rhs:expr, $op:ident, $pyop:expr) => {{ - let (output_type, comp_type) = ($self.data_type().comparison_op($rhs.data_type()))?; + let (output_type, intermediate, comp_type) = + ($self.data_type().comparison_op($rhs.data_type()))?; let lhs = $self.into_series(); + let (lhs, rhs) = if let Some(ref it) = intermediate { + (lhs.cast(it)?, $rhs.cast(it)?) + } else { + (lhs, $rhs.clone()) + }; + use DataType::*; if let Boolean = output_type { match comp_type { #[cfg(feature = "python")] - Python => py_binary_op_bool!(lhs, $rhs, $pyop) + Python => py_binary_op_bool!(lhs, rhs, $pyop) .downcast::() .cloned(), _ => with_match_comparable_daft_types!(comp_type, |$T| { - cast_downcast_op!(lhs, $rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op) + cast_downcast_op!(lhs, rhs, &comp_type, <$T as DaftDataType>::ArrayType, $op) }), } } else { diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs index 6ee1db850a..4ffec69b2c 100644 --- a/src/daft-core/src/utils/supertype.rs +++ b/src/daft-core/src/utils/supertype.rs @@ -168,20 +168,20 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { (Duration(_), Date) | (Date, Duration(_)) => Some(Date), (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))), - // None and Some("") timezones + // Some() timezones that are non equal // we cast from more precision to higher precision as that always fits with occasional loss of precision - (Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r)) - if (tz_l.is_none() || tz_l.as_deref() == Some("")) - && (tz_r.is_none() || tz_r.as_deref() == Some("")) => + (Timestamp(tu_l, Some(tz_l)), Timestamp(tu_r, Some(tz_r))) + if !tz_l.is_empty() + && !tz_r.is_empty() && tz_l != tz_r => { let tu = get_time_units(tu_l, tu_r); - Some(Timestamp(tu, None)) + Some(Timestamp(tu, Some("UTC".to_string()))) } // None and Some("") timezones // we cast from more precision to higher precision as that always fits with occasional loss of precision (Timestamp(tu_l, tz_l), Timestamp(tu_r, tz_r)) if // both are none - tz_l.is_none() && tz_r.is_some() + tz_l.is_none() && tz_r.is_none() // both have the same time zone || (tz_l.is_some() && (tz_l == tz_r)) => { let tu = get_time_units(tu_l, tu_r); diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 0c38e86806..9f4036a495 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -436,7 +436,7 @@ impl Expr { | Operator::NotEq | Operator::LtEq | Operator::GtEq => { - let (result_type, _comp_type) = + let (result_type, _intermediate, _comp_type) = left_field.dtype.comparison_op(&right_field.dtype)?; Ok(Field::new(left_field.name.as_str(), result_type)) } diff --git a/tests/expressions/typing/conftest.py b/tests/expressions/typing/conftest.py index 268d115925..e48923b095 100644 --- a/tests/expressions/typing/conftest.py +++ b/tests/expressions/typing/conftest.py @@ -4,6 +4,8 @@ import itertools import sys +import pytz + if sys.version_info < (3, 8): pass else: @@ -33,17 +35,65 @@ (DataType.bool(), pa.array([True, False, None], type=pa.bool_())), (DataType.null(), pa.array([None, None, None], type=pa.null())), (DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())), - (DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())), - # TODO(jay): Some of the fixtures are broken/become very complicated when testing against timestamps - # ( - # DataType.timestamp(TimeUnit.ms()), - # pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp("ms")), - # ), ] ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2)) +ALL_TEMPORAL_DTYPES = [ + (DataType.date(), pa.array([datetime.date(2021, 1, 1), datetime.date(2021, 1, 2), None], type=pa.date32())), + *[ + ( + DataType.timestamp(unit), + pa.array([datetime.datetime(2021, 1, 1), datetime.datetime(2021, 1, 2), None], type=pa.timestamp(unit)), + ) + for unit in ["ns", "us", "ms"] + ], + *[ + ( + DataType.timestamp(unit, "US/Eastern"), + pa.array( + [ + datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("US/Eastern")), + datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("US/Eastern")), + None, + ], + type=pa.timestamp(unit, "US/Eastern"), + ), + ) + for unit in ["ns", "us", "ms"] + ], + *[ + ( + DataType.timestamp(unit, "Africa/Accra"), + pa.array( + [ + datetime.datetime(2021, 1, 1).astimezone(pytz.timezone("Africa/Accra")), + datetime.datetime(2021, 1, 2).astimezone(pytz.timezone("Africa/Accra")), + None, + ], + type=pa.timestamp(unit, "Africa/Accra"), + ), + ) + for unit in ["ns", "us", "ms"] + ], +] + +ALL_DTYPES += ALL_TEMPORAL_DTYPES + +ALL_TEMPORAL_DATATYPES_BINARY_PAIRS = [ + ((dt1, data1), (dt2, data2)) + for (dt1, data1), (dt2, data2) in itertools.product(ALL_TEMPORAL_DTYPES, repeat=2) + if not ( + pa.types.is_timestamp(data1.type) + and pa.types.is_timestamp(data2.type) + and (data1.type.tz is None) ^ (data2.type.tz is None) + ) +] + +ALL_DATATYPES_BINARY_PAIRS += ALL_TEMPORAL_DATATYPES_BINARY_PAIRS + + @pytest.fixture( scope="module", params=ALL_DATATYPES_BINARY_PAIRS, diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index 561f8a6eb7..eca23a53cd 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -2,9 +2,11 @@ import itertools import operator +from datetime import date, datetime import pyarrow as pa import pytest +import pytz from daft import DataType, Series @@ -682,3 +684,51 @@ def test_logicalops_pyobjects(op, expected, expected_self) -> None: assert op(custom_falses, values).datatype() == DataType.bool() assert op(custom_falses, values).to_pylist() == expected assert op(custom_falses, custom_falses).to_pylist() == expected_self + + +@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2)) +def test_compare_timestamps_no_tz(tu1, tu2): + tz1 = Series.from_pylist([datetime(2022, 1, 1)]) + assert (tz1.cast(DataType.timestamp(tu1)) == tz1.cast(DataType.timestamp(tu2))).to_pylist() == [True] + + +def test_compare_timestamps_no_tz_date(): + tz1 = Series.from_pylist([datetime(2022, 1, 1)]) + Series.from_pylist([date(2022, 1, 1)]) + assert (tz1 == tz1).to_pylist() == [True] + + +def test_compare_timestamps_one_tz(): + tz1 = Series.from_pylist([datetime(2022, 1, 1)]) + tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]) + with pytest.raises(ValueError, match="Cannot perform comparison on types"): + assert (tz1 == tz2).to_pylist() == [True] + + +def test_compare_timestamps_and_int(): + tz1 = Series.from_pylist([datetime(2022, 1, 1)]) + tz2 = Series.from_pylist([5]) + with pytest.raises(ValueError, match="Cannot perform comparison on types"): + assert (tz1 == tz2).to_pylist() == [True] + + +def test_compare_timestamps_tz_date(): + tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]) + Series.from_pylist([date(2022, 1, 1)]) + assert (tz1 == tz1).to_pylist() == [True] + + +@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2)) +def test_compare_timestamps_same_tz(tu1, tu2): + tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu1, "UTC")) + tz2 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]).cast(DataType.timestamp(tu2, "UTC")) + assert (tz1 == tz2).to_pylist() == [True] + + +@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2)) +def test_compare_timestamps_diff_tz(tu1, tu2): + utc = datetime(2022, 1, 1, tzinfo=pytz.utc) + eastern = utc.astimezone(pytz.timezone("US/Eastern")) + tz1 = Series.from_pylist([utc]).cast(DataType.timestamp(tu1, "UTC")) + tz2 = Series.from_pylist([eastern]).cast(DataType.timestamp(tu1, "US/Eastern")) + assert (tz1 == tz2).to_pylist() == [True]