Skip to content

Commit

Permalink
stricter timeunit assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-ho committed Feb 23, 2024
1 parent bde1d35 commit 10dd638
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 46 deletions.
53 changes: 17 additions & 36 deletions src/daft-core/src/array/ops/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
DataType,
};
use arrow2::compute::arithmetics::ArraySub;
use chrono::{NaiveDate, Timelike};
use chrono::{NaiveDate, NaiveTime, Timelike};
use common_error::{DaftError, DaftResult};

use super::as_arrow::AsArrow;
Expand Down Expand Up @@ -114,34 +114,25 @@ impl TimestampArray {
unreachable!("Timestamp array must have Timestamp datatype")
};
let tu = timeunit.to_arrow();
let timeunit_for_cast = if timeunit_for_cast == &TimeUnit::Nanoseconds {
TimeUnit::Nanoseconds
} else {
TimeUnit::Microseconds // default to microseconds
};
if !matches!(
timeunit_for_cast,
TimeUnit::Microseconds | TimeUnit::Nanoseconds
) {
return Err(DaftError::ValueError(format!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}")));
}
let time_arrow = match tz {
Some(tz) => match arrow2::temporal_conversions::parse_offset(tz) {
Ok(tz) => Ok(arrow2::array::PrimitiveArray::<i64>::from_iter(
physical.iter().map(|ts| {
ts.map(|ts| {
let dt =
arrow2::temporal_conversions::timestamp_to_datetime(*ts, tu, &tz);
match timeunit_for_cast {
TimeUnit::Nanoseconds => {
let hour = dt.hour() as i64 * 3_600_000_000_000;
let minute = dt.minute() as i64 * 60_000_000_000;
let second = dt.second() as i64 * 1_000_000_000;
let nanosecond = dt.nanosecond() as i64;
hour + minute + second + nanosecond
}
_ => {
let hour = dt.hour() as i64 * 3_600_000_000;
let minute = dt.minute() as i64 * 60_000_000;
let second = dt.second() as i64 * 1_000_000;
let microsecond = dt.nanosecond() as i64 / 1_000;
hour + minute + second + microsecond
let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap();
match timeunit_for_cast {
TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(),
TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(),
_ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"),
}
}
})
}),
)),
Expand All @@ -154,28 +145,18 @@ impl TimestampArray {
physical.iter().map(|ts| {
ts.map(|ts| {
let dt = arrow2::temporal_conversions::timestamp_to_naive_datetime(*ts, tu);
let time_delta = dt.time() - NaiveTime::from_hms_opt(0,0,0).unwrap();
match timeunit_for_cast {
TimeUnit::Nanoseconds => {
let hour = dt.hour() as i64 * 3_600_000_000_000;
let minute = dt.minute() as i64 * 60_000_000_000;
let second = dt.second() as i64 * 1_000_000_000;
let nanosecond = dt.nanosecond() as i64;
hour + minute + second + nanosecond
}
_ => {
let hour = dt.hour() as i64 * 3_600_000_000;
let minute = dt.minute() as i64 * 60_000_000;
let second = dt.second() as i64 * 1_000_000;
let microsecond = dt.nanosecond() as i64 / 1_000;
hour + minute + second + microsecond
}
TimeUnit::Microseconds => time_delta.num_microseconds().unwrap(),
TimeUnit::Nanoseconds => time_delta.num_nanoseconds().unwrap(),
_ => unreachable!("Only microseconds and nanoseconds time units are supported for the Time dtype, but got {timeunit_for_cast}"),
}
})
}),
)),
}?;
Ok(TimeArray::new(
Field::new(self.name(), DataType::Time(timeunit_for_cast)),
Field::new(self.name(), DataType::Time(*timeunit_for_cast)),
Int64Array::from((self.name(), Box::new(time_arrow))),
))
}
Expand Down
6 changes: 6 additions & 0 deletions src/daft-core/src/python/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ impl PyDataType {

#[staticmethod]
pub fn time(timeunit: PyTimeUnit) -> PyResult<Self> {
if timeunit.timeunit == TimeUnit::Seconds || timeunit.timeunit == TimeUnit::Milliseconds {
return Err(PyValueError::new_err(format!(
"The time unit for time types must be microseconds or nanoseconds, but got: {}",
timeunit.timeunit
)));
}
Ok(DataType::Time(timeunit.timeunit).into())
}

Expand Down
19 changes: 9 additions & 10 deletions tests/series/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,20 +775,19 @@ def test_cast_date_to_timestamp():
assert (input == back).to_pylist() == [True]


def test_cast_timestamp_to_time():
@pytest.mark.parametrize("timeunit", ["us", "ns"])
def test_cast_timestamp_to_time(timeunit):
from datetime import datetime, time

"""Microseconds"""
input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)])
casted = input.cast(DataType.time("us"))
casted = input.cast(DataType.time(timeunit))
assert casted.to_pylist() == [time(12, 34, 56, 78)]

"""Nanoseconds"""
input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)])
casted = input.cast(DataType.time("ns"))
assert casted.to_pylist() == [time(12, 34, 56, 78)]

"""Seconds"""
@pytest.mark.parametrize("timeunit", ["s", "ms"])
def test_cast_timestamp_to_time_unsupported_timeunit(timeunit):
from datetime import datetime

input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)])
casted = input.cast(DataType.time("s"))
assert casted.to_pylist() == [time(12, 34, 56, 78)]
with pytest.raises(ValueError):
input.cast(DataType.time(timeunit))

0 comments on commit 10dd638

Please sign in to comment.