Skip to content

Commit

Permalink
Fixes bug when comparing dates and timestamps
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Dec 16, 2023
1 parent fa782d8 commit b69c6df
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 12 deletions.
30 changes: 21 additions & 9 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use crate::{
FixedShapeTensorArray, ImageArray, LogicalArray, LogicalArrayImpl, TensorArray,
TimestampArray,
},
DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, TimeUnit, UInt64Array,
Utf8Array,
DaftArrowBackedType, DaftLogicalType, DataType, Field, ImageMode, Int64Array, TimeUnit,
UInt64Array, Utf8Array,
},
series::{IntoSeries, Series},
with_match_daft_logical_primitive_types,
Expand Down Expand Up @@ -255,14 +255,12 @@ where

impl DateArray {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
// We need to handle casts that Arrow doesn't allow, but our type-system does

let date_array = self
.as_arrow()
.clone()
.to(arrow2::datatypes::DataType::Date32);

match dtype {
DataType::Date => Ok(self.clone().into_series()),
DataType::Utf8 => {
// TODO: we should move this into our own strftime kernel
let year_array = compute::temporal::year(&date_array)?;
Expand All @@ -279,11 +277,27 @@ impl DateArray {
.collect();
Ok(Utf8Array::from((self.name(), Box::new(date_str))).into_series())
}
DataType::Int32 => Ok(self.physical.clone().into_series()),
DataType::Float32 => self.cast(&DataType::Int32)?.cast(&DataType::Float32),
DataType::Float64 => self.cast(&DataType::Int32)?.cast(&DataType::Float64),
DataType::Timestamp(tu, _) => {
let days_to_unit: i64 = match tu {
TimeUnit::Nanoseconds => 24 * 3_600_000_000_000,
TimeUnit::Microseconds => 24 * 3_600_000_000,
TimeUnit::Milliseconds => 24 * 3_600_000,
TimeUnit::Seconds => 24 * 3_600,
};

let units_per_day = Int64Array::from(("units", vec![days_to_unit])).into_series();
let unit_since_epoch = ((&self.physical.clone().into_series()) * &units_per_day)?;
unit_since_epoch.cast(dtype)
}
#[cfg(feature = "python")]
DataType::Python => cast_logical_to_python_array(self, dtype),
_ => arrow_cast(&self.physical, dtype),
_ => Err(DaftError::TypeError(format!(
"Cannot cast Date to {}",
dtype
))),
}
}
}
Expand Down Expand Up @@ -348,9 +362,7 @@ impl TimestampArray {
pub fn cast(&self, dtype: &DataType) -> DaftResult<Series> {
match dtype {
DataType::Timestamp(..) => arrow_logical_cast(self, dtype),
DataType::Date => Err(DaftError::TypeError(
"Cannot cast Timestamp to Date, use .date() instead.".to_string(),
)),
DataType::Date => Ok(self.date()?.into_series()),
DataType::Utf8 => {
let DataType::Timestamp(unit, timezone) = self.data_type() else {
panic!("Wrong dtype for TimestampArray: {}", self.data_type())
Expand Down
7 changes: 6 additions & 1 deletion src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ impl DataType {
(s, o) if s.is_physical() && o.is_physical() => {
Ok((Boolean, None, try_physical_supertype(s, o)?))
}
(Timestamp(..) | Date, Timestamp(..) | Date) => {
(Timestamp(..), Timestamp(..)) => {
let intermediate_type = try_get_supertype(self, other)?;
let pt = intermediate_type.to_physical();
Ok((Boolean, Some(intermediate_type), pt))
}
(Timestamp(..), Date) | (Date, Timestamp(..)) => {
let intermediate_type = Date;
let pt = intermediate_type.to_physical();
Ok((Boolean, Some(intermediate_type), pt))
}
_ => Err(DaftError::TypeError(format!(
"Cannot perform comparison on types: {}, {}",
self, other
Expand Down
10 changes: 10 additions & 0 deletions tests/series/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,13 @@ def test_series_cast_struct_add_col() -> None:
casted = series.cast(cast_to)
assert casted.datatype() == cast_to
assert casted.to_pylist() == [{**x, "baz": None} for x in data.to_pylist()]


def test_cast_date_to_timestamp():
    """Casting a Date series to a Timestamp yields that dtype and round-trips to the same date."""
    from datetime import date

    target_dtype = DataType.timestamp("us", "UTC")
    dates = Series.from_pylist([date(2022, 1, 6)])

    casted = dates.cast(target_dtype)
    # Check the cast actually produced the requested timestamp type; the
    # round-trip below would not catch a cast that left the dtype unchanged.
    assert casted.datatype() == target_dtype

    back = casted.dt.date()
    assert (dates == back).to_pylist() == [True]
10 changes: 8 additions & 2 deletions tests/series/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,14 @@ def test_compare_timestamps_and_int():

def test_compare_timestamps_tz_date():
    """A tz-aware timestamp at midnight UTC compares equal to the same calendar date."""
    timestamps = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    dates = Series.from_pylist([date(2022, 1, 1)])

    # Sanity check: a timestamp series is equal to itself.
    assert (timestamps == timestamps).to_pylist() == [True]
    # Mixed comparison: timestamp on the left, date on the right.
    assert (timestamps == dates).to_pylist() == [True]


def test_compare_lt_timestamps_tz_date():
    """A tz-aware timestamp on 2022-01-01 orders strictly before the date 2022-01-06."""
    earlier_timestamp = Series.from_pylist([datetime(2022, 1, 1, tzinfo=pytz.utc)])
    later_date = Series.from_pylist([date(2022, 1, 6)])

    is_before = (earlier_timestamp < later_date).to_pylist()
    assert is_before == [True]


@pytest.mark.parametrize("tu1, tu2", itertools.product(["ns", "us", "ms"], repeat=2))
Expand Down

0 comments on commit b69c6df

Please sign in to comment.