Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Infer timedelta literal as duration #3011

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,7 @@ def lit(item: Any) -> PyExpr: ...
def date_lit(item: int) -> PyExpr: ...
def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ...
def duration_lit(item: int, tu: PyTimeUnit) -> PyExpr: ...
def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ...
def series_lit(item: PySeries) -> PyExpr: ...
def stateless_udf(
Expand Down
10 changes: 9 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import os
import warnings
from datetime import date, datetime, time
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import (
TYPE_CHECKING,
Expand All @@ -23,6 +23,7 @@
from daft.daft import col as _col
from daft.daft import date_lit as _date_lit
from daft.daft import decimal_lit as _decimal_lit
from daft.daft import duration_lit as _duration_lit
from daft.daft import list_sort as _list_sort
from daft.daft import lit as _lit
from daft.daft import series_lit as _series_lit
Expand Down Expand Up @@ -71,6 +72,7 @@ def __get__( # type: ignore[override]


def lit(value: object) -> Expression:
print("lit: ", value)
"""Creates an Expression representing a column with every value set to the provided value

Example:
Expand Down Expand Up @@ -115,6 +117,12 @@ def lit(value: object) -> Expression:
i64_value = pa_time.cast(pa.int64()).as_py()
time_unit = TimeUnit.from_str(pa.type_for_alias(str(pa_time.type)).unit)._timeunit
lit_value = _time_lit(i64_value, time_unit)
elif isinstance(value, timedelta):
# pyo3 timedelta (PyDelta) is not available when running in abi3 mode, workaround
pa_duration = pa.scalar(value)
i64_value = pa_duration.cast(pa.int64()).as_py()
time_unit = TimeUnit.from_str(pa_duration.type.unit)._timeunit
lit_value = _duration_lit(i64_value, time_unit)
elif isinstance(value, Decimal):
sign, digits, exponent = value.as_tuple()
assert isinstance(exponent, int)
Expand Down
20 changes: 18 additions & 2 deletions src/daft-core/src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
NullArray, UInt64Array, Utf8Array,
},
series::Series,
utils::display::{display_date32, display_decimal128, display_time64, display_timestamp},
utils::display::{
display_date32, display_decimal128, display_duration, display_time64, display_timestamp,
},
with_match_daft_types,
};

Expand All @@ -34,7 +36,6 @@

impl_array_str_value!(BooleanArray, "{}");
impl_array_str_value!(ExtensionArray, "{:?}");
impl_array_str_value!(DurationArray, "{}");

fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult<String> {
/// influenced by pythons bytes repr
Expand Down Expand Up @@ -192,6 +193,21 @@
}
}

impl DurationArray {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
|| "None".to_string(),
|val| -> String {
let DataType::Duration(time_unit) = &self.field.dtype else {
panic!("Wrong dtype for DurationArray: {}", self.field.dtype)

Check warning on line 202 in src/daft-core/src/array/ops/repr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/repr.rs#L197-L202

Added lines #L197 - L202 were not covered by tests
};
display_duration(val, time_unit)
},
);
Ok(res)
}

Check warning on line 208 in src/daft-core/src/array/ops/repr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/repr.rs#L204-L208

Added lines #L204 - L208 were not covered by tests
}

impl Decimal128Array {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/utils/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@
)
}

pub fn display_duration(val: i64, unit: &TimeUnit) -> String {
let duration = match unit {
TimeUnit::Nanoseconds => chrono::Duration::nanoseconds(val),
TimeUnit::Microseconds => chrono::Duration::microseconds(val),
TimeUnit::Milliseconds => chrono::Duration::milliseconds(val),
TimeUnit::Seconds => chrono::Duration::seconds(val),

Check warning on line 83 in src/daft-core/src/utils/display.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/utils/display.rs#L78-L83

Added lines #L78 - L83 were not covered by tests
};
format!("{duration}")
}

Check warning on line 86 in src/daft-core/src/utils/display.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/utils/display.rs#L85-L86

Added lines #L85 - L86 were not covered by tests

pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String {
if scale < 0 {
unimplemented!();
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction_bound!(python::date_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::time_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::timestamp_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::duration_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?;
parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?;
Expand Down
18 changes: 15 additions & 3 deletions src/daft-dsl/src/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
use daft_core::{
prelude::*,
utils::display::{
display_date32, display_decimal128, display_series_literal, display_time64,
display_timestamp,
display_date32, display_decimal128, display_duration, display_series_literal,
display_time64, display_timestamp,
},
};
use indexmap::IndexMap;
Expand Down Expand Up @@ -60,6 +60,8 @@
Date(i32),
/// An [`i64`] representing a time in microseconds or nanoseconds since midnight.
Time(i64, TimeUnit),
/// An [`i64`] representing a measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.)
Duration(i64, TimeUnit),
/// A 64-bit floating point number.
Float64(f64),
/// An [`i128`] representing a decimal number with the provided precision and scale.
Expand Down Expand Up @@ -99,6 +101,10 @@
tu.hash(state);
tz.hash(state);
}
Duration(n, tu) => {
n.hash(state);
tu.hash(state);
}
// Wrap float64 in hashable newtype.
Float64(n) => FloatWrapper(*n).hash(state),
Decimal(n, precision, scale) => {
Expand Down Expand Up @@ -141,6 +147,7 @@
Date(val) => write!(f, "{}", display_date32(*val)),
Time(val, tu) => write!(f, "{}", display_time64(*val, tu)),
Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)),
Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)),

Check warning on line 150 in src/daft-dsl/src/lit.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/lit.rs#L150

Added line #L150 was not covered by tests
Float64(val) => write!(f, "{val:.1}"),
Decimal(val, precision, scale) => {
write!(f, "{}", display_decimal128(*val, *precision, *scale))
Expand Down Expand Up @@ -181,6 +188,7 @@
Date(_) => DataType::Date,
Time(_, tu) => DataType::Time(*tu),
Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()),
Duration(_, tu) => DataType::Duration(*tu),
Float64(_) => DataType::Float64,
Decimal(_, precision, scale) => {
DataType::Decimal128(*precision as usize, *scale as usize)
Expand Down Expand Up @@ -215,6 +223,10 @@
let physical = Int64Array::from(("literal", [*val].as_slice()));
TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series()
}
Duration(val, ..) => {
let physical = Int64Array::from(("literal", [*val].as_slice()));
DurationArray::new(Field::new("literal", self.get_type()), physical).into_series()
}
Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(),
Decimal(val, ..) => {
let physical = Int128Array::from(("literal", [*val].as_slice()));
Expand Down Expand Up @@ -259,7 +271,7 @@
display_timestamp(*val, tu, tz).replace('T', " ")
),
// TODO(Colin): Implement the rest of the types in future work for SQL pushdowns.
Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err,
Decimal(..) | Series(..) | Time(..) | Binary(..) | Duration(..) => display_sql_err,

Check warning on line 274 in src/daft-dsl/src/lit.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/lit.rs#L274

Added line #L274 was not covered by tests
#[cfg(feature = "python")]
Python(..) => display_sql_err,
Struct(..) => display_sql_err,
Expand Down
6 changes: 6 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option<String>) -> PyResult<P
Ok(expr.into())
}

#[pyfunction]
pub fn duration_lit(val: i64, tu: PyTimeUnit) -> PyResult<PyExpr> {
let expr = Expr::Literal(LiteralValue::Duration(val, tu.timeunit));
Ok(expr.into())
}

fn decimal_from_digits(digits: Vec<u8>, exp: i32) -> Option<(i128, usize)> {
const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1;
let mut v = 0_i128;
Expand Down
27 changes: 27 additions & 0 deletions tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,33 @@ def test_python_duration() -> None:
assert res == duration


def test_temporal_arithmetic_with_duration_lit() -> None:
df = daft.from_pydict(
{
"duration": [timedelta(days=1)],
"date": [datetime(2021, 1, 1)],
"timestamp": [datetime(2021, 1, 1)],
}
)

df = df.select(
(df["date"] + timedelta(days=1)).alias("add_date"),
(df["date"] - timedelta(days=1)).alias("sub_date"),
(df["timestamp"] + timedelta(days=1)).alias("add_timestamp"),
(df["timestamp"] - timedelta(days=1)).alias("sub_timestamp"),
(df["duration"] + timedelta(days=1)).alias("add_dur"),
(df["duration"] - timedelta(days=1)).alias("sub_dur"),
)

result = df.to_pydict()
assert result["add_date"] == [datetime(2021, 1, 2)]
assert result["sub_date"] == [datetime(2020, 12, 31)]
assert result["add_timestamp"] == [datetime(2021, 1, 2)]
assert result["sub_timestamp"] == [datetime(2020, 12, 31)]
assert result["add_dur"] == [timedelta(days=2)]
assert result["sub_dur"] == [timedelta(0)]


@pytest.mark.parametrize(
"timeunit",
["s", "ms", "us", "ns"],
Expand Down
Loading