diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index a9b9b8c552..482f2ab168 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -14,8 +14,8 @@ use crate::{ FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray, TimestampArray, }, - DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, TimeUnit, UInt64Array, - Utf8Array, + DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, Int64Array, TimeUnit, + UInt64Array, Utf8Array, }, series::{IntoSeries, Series}, with_match_daft_logical_primitive_types, @@ -255,14 +255,12 @@ where impl DateArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { - // We need to handle casts that Arrow doesn't allow, but our type-system does - let date_array = self .as_arrow() .clone() .to(arrow2::datatypes::DataType::Date32); - match dtype { + DataType::Date => Ok(self.clone().into_series()), DataType::Utf8 => { // TODO: we should move this into our own strftime kernel let year_array = compute::temporal::year(&date_array)?; @@ -279,11 +277,27 @@ impl DateArray { .collect(); Ok(Utf8Array::from((self.name(), Box::new(date_str))).into_series()) } + DataType::Int32 => Ok(self.physical.clone().into_series()), DataType::Float32 => self.cast(&DataType::Int32)?.cast(&DataType::Float32), DataType::Float64 => self.cast(&DataType::Int32)?.cast(&DataType::Float64), + DataType::Timestamp(tu, _) => { + let days_to_unit: i64 = match tu { + TimeUnit::Nanoseconds => 24 * 3_600_000_000_000, + TimeUnit::Microseconds => 24 * 3_600_000_000, + TimeUnit::Milliseconds => 24 * 3_600_000, + TimeUnit::Seconds => 24 * 3_600, + }; + + let units_per_day = Int64Array::from(("units", vec![days_to_unit])).into_series(); + let unit_since_epoch = ((&self.physical.clone().into_series()) * &units_per_day)?; + unit_since_epoch.cast(dtype) + } #[cfg(feature = "python")] DataType::Python => cast_logical_to_python_array(self, dtype), - _ => arrow_cast(&self.physical, dtype), + _ => Err(DaftError::TypeError(format!( + "Cannot cast Date to {}", + dtype + ))), } } } @@ -348,9 +362,7 @@ impl TimestampArray { pub fn cast(&self, dtype: &DataType) -> DaftResult { match dtype { DataType::Timestamp(..) => arrow_logical_cast(self, dtype), - DataType::Date => Err(DaftError::TypeError( - "Cannot cast Timestamp to Date, use .date() instead.".to_string(), - )), + DataType::Date => Ok(self.date()?.into_series()), DataType::Utf8 => { let DataType::Timestamp(unit, timezone) = self.data_type() else { panic!("Wrong dtype for TimestampArray: {}", self.data_type()) diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index 3353bde863..7ef1818c30 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -40,11 +40,16 @@ impl DataType { (s, o) if s.is_physical() && o.is_physical() => { Ok((Boolean, None, try_physical_supertype(s, o)?)) } - (Timestamp(..) | Date, Timestamp(..) | Date) => { + (Timestamp(..), Timestamp(..)) => { let intermediate_type = try_get_supertype(self, other)?; let pt = intermediate_type.to_physical(); Ok((Boolean, Some(intermediate_type), pt)) } + (Timestamp(..), Date) | (Date, Timestamp(..)) => { + let intermediate_type = Date; + let pt = intermediate_type.to_physical(); + Ok((Boolean, Some(intermediate_type), pt)) + } _ => Err(DaftError::TypeError(format!( "Cannot perform comparison on types: {}, {}", self, other diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index 62b4e6ebad..e07769d4ce 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -760,3 +760,13 @@ def test_series_cast_struct_add_col() -> None: casted = series.cast(cast_to) assert casted.datatype() == cast_to assert casted.to_pylist() == [{**x, "baz": None} for x in data.to_pylist()] + + +def test_cast_date_to_timestamp(): + from datetime import date + + input = Series.from_pylist([date(2022, 1, 6)]) + casted = input.cast(DataType.timestamp("us", "UTC")) + # DO ASSERT AS TIMESTAMP + back = casted.dt.date() + assert (input == back).to_pylist() == [True] diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index eca23a53cd..811670de40 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -714,8 +714,14 @@ def test_compare_timestamps_and_int(): def test_compare_timestamps_tz_date(): tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]) - Series.from_pylist([date(2022, 1, 1)]) - assert (tz1 == tz1).to_pylist() == [True] + tz2 = Series.from_pylist([date(2022, 1, 1)]) + assert (tz1 == tz2).to_pylist() == [True] + + +def test_compare_lt_timestamps_tz_date(): + tz1 = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)]) + tz2 = Series.from_pylist([date(2022, 1, 6)]) + assert (tz1 < tz2).to_pylist() == [True] @pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))