diff --git a/src/daft-core/src/array/ops/date.rs b/src/daft-core/src/array/ops/date.rs index eefb053393..8bba15ee91 100644 --- a/src/daft-core/src/array/ops/date.rs +++ b/src/daft-core/src/array/ops/date.rs @@ -6,7 +6,7 @@ use crate::{ DataType, }; use arrow2::compute::arithmetics::ArraySub; -use chrono::{NaiveDate, Timelike}; +use chrono::{NaiveDate, NaiveTime, Timelike}; use common_error::{DaftError, DaftResult}; use super::as_arrow::AsArrow; @@ -114,11 +114,12 @@ impl TimestampArray { unreachable!("Timestamp array must have Timestamp datatype") }; let tu = timeunit.to_arrow(); - let timeunit_for_cast = if timeunit_for_cast == &TimeUnit::Nanoseconds { - TimeUnit::Nanoseconds - } else { - TimeUnit::Microseconds // default to microseconds - }; + if !matches!( + timeunit_for_cast, + TimeUnit::Microseconds | TimeUnit::Nanoseconds + ) { + return Err(DaftError::ValueError(format!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"))); + } let time_arrow = match tz { Some(tz) => match arrow2::temporal_conversions::parse_offset(tz) { Ok(tz) => Ok(arrow2::array::PrimitiveArray::::from_iter( @@ -126,22 +127,12 @@ impl TimestampArray { ts.map(|ts| { let dt = arrow2::temporal_conversions::timestamp_to_datetime(*ts, tu, &tz); - match timeunit_for_cast { - TimeUnit::Nanoseconds => { - let hour = dt.hour() as i64 * 3_600_000_000_000; - let minute = dt.minute() as i64 * 60_000_000_000; - let second = dt.second() as i64 * 1_000_000_000; - let nanosecond = dt.nanosecond() as i64; - hour + minute + second + nanosecond - } - _ => { - let hour = dt.hour() as i64 * 3_600_000_000; - let minute = dt.minute() as i64 * 60_000_000; - let second = dt.second() as i64 * 1_000_000; - let microsecond = dt.nanosecond() as i64 / 1_000; - hour + minute + second + microsecond + let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap(); + match timeunit_for_cast { + TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(), + TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(), + _ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"), } - } }) }), )), @@ -154,28 +145,18 @@ impl TimestampArray { physical.iter().map(|ts| { ts.map(|ts| { let dt = arrow2::temporal_conversions::timestamp_to_naive_datetime(*ts, tu); + let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap(); match timeunit_for_cast { - TimeUnit::Nanoseconds => { - let hour = dt.hour() as i64 * 3_600_000_000_000; - let minute = dt.minute() as i64 * 60_000_000_000; - let second = dt.second() as i64 * 1_000_000_000; - let nanosecond = dt.nanosecond() as i64; - hour + minute + second + nanosecond - } - _ => { - let hour = dt.hour() as i64 * 3_600_000_000; - let minute = dt.minute() as i64 * 60_000_000; - let second = dt.second() as i64 * 1_000_000; - let microsecond = dt.nanosecond() as i64 / 1_000; - hour + minute + second + microsecond - } + TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(), + TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(), + _ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"), } }) }), )), }?; Ok(TimeArray::new( - Field::new(self.name(), DataType::Time(timeunit_for_cast)), + Field::new(self.name(), DataType::Time(*timeunit_for_cast)), Int64Array::from((self.name(), Box::new(time_arrow))), )) } diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 1317d7fd70..63131eb228 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -162,6 +162,12 @@ impl PyDataType { #[staticmethod] pub fn time(timeunit: PyTimeUnit) -> PyResult { + if timeunit.timeunit == TimeUnit::Seconds || timeunit.timeunit == TimeUnit::Milliseconds { + return Err(PyValueError::new_err(format!( + "The time unit for time types must be microseconds or nanoseconds, but got: {}", + timeunit.timeunit + ))); + } Ok(DataType::Time(timeunit.timeunit).into()) } diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index 996c53317c..eb664e0118 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -775,20 +775,19 @@ def test_cast_date_to_timestamp(): assert (input == back).to_pylist() == [True] -def test_cast_timestamp_to_time(): +@pytest.mark.parametrize("timeunit", ["us", "ns"]) +def test_cast_timestamp_to_time(timeunit): from datetime import datetime, time - """Microseconds""" input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) - casted = input.cast(DataType.time("us")) + casted = input.cast(DataType.time(timeunit)) assert casted.to_pylist() == [time(12, 34, 56, 78)] - """Nanoseconds""" - input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) - casted = input.cast(DataType.time("ns")) - assert casted.to_pylist() == [time(12, 34, 56, 78)] - """Seconds""" +@pytest.mark.parametrize("timeunit", ["s", "ms"]) +def test_cast_timestamp_to_time_unsupported_timeunit(timeunit): + from datetime import datetime + input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) - casted = input.cast(DataType.time("s")) - assert casted.to_pylist() == [time(12, 34, 56, 78)] + with pytest.raises(ValueError): + input.cast(DataType.time(timeunit))