From 7e45d8ffda02ff43ac9ff1e82b6f67ddf3bf74d4 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 7 Oct 2024 10:22:29 -0500 Subject: [PATCH 1/3] timedelta to duration --- daft/daft/__init__.pyi | 1 + daft/expressions/expressions.py | 10 +++++++++- src/daft-core/src/array/ops/repr.rs | 20 ++++++++++++++++++-- src/daft-core/src/utils/display.rs | 10 ++++++++++ src/daft-dsl/src/lib.rs | 1 + src/daft-dsl/src/lit.rs | 18 +++++++++++++++--- src/daft-dsl/src/python.rs | 6 ++++++ tests/dataframe/test_temporals.py | 27 +++++++++++++++++++++++++++ 8 files changed, 87 insertions(+), 6 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index c90817dfc2..f6a496b1d7 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1134,6 +1134,7 @@ def lit(item: Any) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... def time_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... +def duration_lit(item: int, tu: PyTimeUnit) -> PyExpr: ... def decimal_lit(sign: bool, digits: tuple[int, ...], exp: int) -> PyExpr: ... def series_lit(item: PySeries) -> PyExpr: ... def stateless_udf( diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2701aebc77..00da05ad51 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3,7 +3,7 @@ import math import os import warnings -from datetime import date, datetime, time +from datetime import date, datetime, time, timedelta from decimal import Decimal from typing import ( TYPE_CHECKING, @@ -23,6 +23,7 @@ from daft.daft import col as _col from daft.daft import date_lit as _date_lit from daft.daft import decimal_lit as _decimal_lit +from daft.daft import duration_lit as _duration_lit from daft.daft import list_sort as _list_sort from daft.daft import lit as _lit from daft.daft import series_lit as _series_lit @@ -71,6 +72,7 @@ def __get__( # type: ignore[override] def lit(value: object) -> Expression: + print("lit: ", value) """Creates an Expression representing a column with every value set to the provided value Example: @@ -115,6 +117,12 @@ def lit(value: object) -> Expression: i64_value = pa_time.cast(pa.int64()).as_py() time_unit = TimeUnit.from_str(pa.type_for_alias(str(pa_time.type)).unit)._timeunit lit_value = _time_lit(i64_value, time_unit) + elif isinstance(value, timedelta): + # pyo3 timedelta (PyDelta) is not available when running in abi3 mode, workaround + pa_duration = pa.scalar(value) + i64_value = pa_duration.cast(pa.int64()).as_py() + time_unit = TimeUnit.from_str(pa_duration.type.unit)._timeunit + lit_value = _duration_lit(i64_value, time_unit) elif isinstance(value, Decimal): sign, digits, exponent = value.as_tuple() assert isinstance(exponent, int) diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 8d60f697c7..5fbb6bf2c1 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -13,7 +13,9 @@ use crate::{ NullArray, UInt64Array, Utf8Array, }, series::Series, - utils::display::{display_date32, display_decimal128, display_time64, display_timestamp}, + utils::display::{ + display_date32, display_decimal128, display_duration, display_time64, display_timestamp, + }, with_match_daft_types, }; @@ -34,7 +36,6 @@ macro_rules! impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ExtensionArray, "{:?}"); -impl_array_str_value!(DurationArray, "{}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -192,6 +193,21 @@ impl TimestampArray { } } +impl DurationArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let res = self.get(idx).map_or_else( + || "None".to_string(), + |val| -> String { + let DataType::Duration(time_unit) = &self.field.dtype else { + panic!("Wrong dtype for DurationArray: {}", self.field.dtype) + }; + display_duration(val, time_unit) + }, + ); + Ok(res) + } +} + impl Decimal128Array { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( diff --git a/src/daft-core/src/utils/display.rs b/src/daft-core/src/utils/display.rs index 37593cda8a..3e98099f6f 100644 --- a/src/daft-core/src/utils/display.rs +++ b/src/daft-core/src/utils/display.rs @@ -75,6 +75,16 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) - ) } +pub fn display_duration(val: i64, unit: &TimeUnit) -> String { + let duration = match unit { + TimeUnit::Nanoseconds => chrono::Duration::nanoseconds(val), + TimeUnit::Microseconds => chrono::Duration::microseconds(val), + TimeUnit::Milliseconds => chrono::Duration::milliseconds(val), + TimeUnit::Seconds => chrono::Duration::seconds(val), + }; + format!("{duration}") +} + pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String { if scale < 0 { unimplemented!(); diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 2fa99115e3..c3f5d68594 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -35,6 +35,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction_bound!(python::date_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::time_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::timestamp_lit, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(python::duration_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::decimal_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::series_lit, parent)?)?; parent.add_function(wrap_pyfunction_bound!(python::stateless_udf, parent)?)?; diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 55888d73f8..74c680b3ab 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -10,8 +10,8 @@ use common_hashable_float_wrapper::FloatWrapper; use daft_core::{ prelude::*, utils::display::{ - display_date32, display_decimal128, display_series_literal, display_time64, - display_timestamp, + display_date32, display_decimal128, display_duration, display_series_literal, + display_time64, display_timestamp, }, }; use indexmap::IndexMap; @@ -60,6 +60,8 @@ pub enum LiteralValue { Date(i32), /// An [`i64`] representing a time in microseconds or nanoseconds since midnight. Time(i64, TimeUnit), + /// An [`i64`] representing a measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.) + Duration(i64, TimeUnit), /// A 64-bit floating point number. Float64(f64), /// An [`i128`] representing a decimal number with the provided precision and scale. @@ -99,6 +101,10 @@ impl Hash for LiteralValue { tu.hash(state); tz.hash(state); } + Duration(n, tu) => { + n.hash(state); + tu.hash(state); + } // Wrap float64 in hashable newtype. Float64(n) => FloatWrapper(*n).hash(state), Decimal(n, precision, scale) => { @@ -141,6 +147,7 @@ impl Display for LiteralValue { Date(val) => write!(f, "{}", display_date32(*val)), Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), + Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)), Float64(val) => write!(f, "{val:.1}"), Decimal(val, precision, scale) => { write!(f, "{}", display_decimal128(*val, *precision, *scale)) @@ -181,6 +188,7 @@ impl LiteralValue { Date(_) => DataType::Date, Time(_, tu) => DataType::Time(*tu), Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), + Duration(_, tu) => DataType::Duration(*tu), Float64(_) => DataType::Float64, Decimal(_, precision, scale) => { DataType::Decimal128(*precision as usize, *scale as usize) @@ -215,6 +223,10 @@ impl LiteralValue { let physical = Int64Array::from(("literal", [*val].as_slice())); TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() } + Duration(val, ..) => { + let physical = Int64Array::from(("literal", [*val].as_slice())); + DurationArray::new(Field::new("literal", self.get_type()), physical).into_series() + } Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), Decimal(val, ..) => { let physical = Int128Array::from(("literal", [*val].as_slice())); @@ -259,7 +271,7 @@ impl LiteralValue { display_timestamp(*val, tu, tz).replace('T', " ") ), // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. - Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err, + Decimal(..) | Series(..) | Time(..) | Binary(..) | Duration(..) => display_sql_err, #[cfg(feature = "python")] Python(..) => display_sql_err, Struct(..) => display_sql_err, diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index d67e522ec0..edd3f5bcb4 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -46,6 +46,12 @@ pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option) -> PyResult

PyResult { + let expr = Expr::Literal(LiteralValue::Duration(val, tu.timeunit)); + Ok(expr.into()) +} + fn decimal_from_digits(digits: Vec, exp: i32) -> Option<(i128, usize)> { const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1; let mut v = 0_i128; diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 8843028b01..599e63eaf9 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -152,6 +152,33 @@ def test_python_duration() -> None: assert res == duration +def test_temporal_arithmetic_with_duration_lit() -> None: + df = daft.from_pydict( + { + "duration": [timedelta(days=1)], + "date": [datetime(2021, 1, 1)], + "timestamp": [datetime(2021, 1, 1)], + } + ) + + df = df.select( + (df["date"] + timedelta(days=1)).alias("add_date"), + (df["date"] - timedelta(days=1)).alias("sub_date"), + (df["timestamp"] + timedelta(days=1)).alias("add_timestamp"), + (df["timestamp"] - timedelta(days=1)).alias("sub_timestamp"), + (df["duration"] + timedelta(days=1)).alias("add_dur"), + (df["duration"] - timedelta(days=1)).alias("sub_dur"), + ) + + result = df.to_pydict() + assert result["add_date"] == [datetime(2021, 1, 2)] + assert result["sub_date"] == [datetime(2020, 12, 31)] + assert result["add_timestamp"] == [datetime(2021, 1, 2)] + assert result["sub_timestamp"] == [datetime(2020, 12, 31)] + assert result["add_dur"] == [timedelta(days=2)] + assert result["sub_dur"] == [timedelta(0)] + + @pytest.mark.parametrize( "timeunit", ["s", "ms", "us", "ns"], From c79f0433c1d5efd66a989f8b69326338a754e76c Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 7 Oct 2024 10:30:10 -0500 Subject: [PATCH 2/3] repr is bad --- src/daft-core/src/array/ops/repr.rs | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 5fbb6bf2c1..8d60f697c7 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -13,9 +13,7 @@ use crate::{ NullArray, UInt64Array, Utf8Array, }, series::Series, - utils::display::{ - display_date32, display_decimal128, display_duration, display_time64, display_timestamp, - }, + utils::display::{display_date32, display_decimal128, display_time64, display_timestamp}, with_match_daft_types, }; @@ -36,6 +34,7 @@ macro_rules! impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ExtensionArray, "{:?}"); +impl_array_str_value!(DurationArray, "{}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -193,21 +192,6 @@ impl TimestampArray { } } -impl DurationArray { - pub fn str_value(&self, idx: usize) -> DaftResult { - let res = self.get(idx).map_or_else( - || "None".to_string(), - |val| -> String { - let DataType::Duration(time_unit) = &self.field.dtype else { - panic!("Wrong dtype for DurationArray: {}", self.field.dtype) - }; - display_duration(val, time_unit) - }, - ); - Ok(res) - } -} - impl Decimal128Array { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( From 519e8200ff1376b73e8b44cbdeebb4cd74acffdd Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 7 Oct 2024 11:07:13 -0500 Subject: [PATCH 3/3] improve and test repr --- daft/expressions/expressions.py | 1 - src/daft-core/src/array/ops/repr.rs | 20 +++++++++-- src/daft-core/src/utils/display.rs | 50 +++++++++++++++++++++++---- tests/expressions/test_expressions.py | 27 +++++++++++++++ 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 00da05ad51..ae3655b430 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -72,7 +72,6 @@ def __get__( # type: ignore[override] def lit(value: object) -> Expression: - print("lit: ", value) """Creates an Expression representing a column with every value set to the provided value Example: diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 8d60f697c7..5fbb6bf2c1 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -13,7 +13,9 @@ use crate::{ NullArray, UInt64Array, Utf8Array, }, series::Series, - utils::display::{display_date32, display_decimal128, display_time64, display_timestamp}, + utils::display::{ + display_date32, display_decimal128, display_duration, display_time64, display_timestamp, + }, with_match_daft_types, }; @@ -34,7 +36,6 @@ macro_rules! impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ExtensionArray, "{:?}"); -impl_array_str_value!(DurationArray, "{}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -192,6 +193,21 @@ impl TimestampArray { } } +impl DurationArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let res = self.get(idx).map_or_else( + || "None".to_string(), + |val| -> String { + let DataType::Duration(time_unit) = &self.field.dtype else { + panic!("Wrong dtype for DurationArray: {}", self.field.dtype) + }; + display_duration(val, time_unit) + }, + ); + Ok(res) + } +} + impl Decimal128Array { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( diff --git a/src/daft-core/src/utils/display.rs b/src/daft-core/src/utils/display.rs index 3e98099f6f..a76d382555 100644 --- a/src/daft-core/src/utils/display.rs +++ b/src/daft-core/src/utils/display.rs @@ -75,14 +75,52 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) - ) } +const UNITS: [&str; 4] = ["d", "h", "m", "s"]; +const SIZES: [[i64; 4]; 4] = [ + [ + 86_400_000_000_000, + 3_600_000_000_000, + 60_000_000_000, + 1_000_000_000, + ], // Nanoseconds + [86_400_000_000, 3_600_000_000, 60_000_000, 1_000_000], // Microseconds + [86_400_000, 3_600_000, 60_000, 1_000], // Milliseconds + [86_400, 3_600, 60, 1], // Seconds +]; + pub fn display_duration(val: i64, unit: &TimeUnit) -> String { - let duration = match unit { - TimeUnit::Nanoseconds => chrono::Duration::nanoseconds(val), - TimeUnit::Microseconds => chrono::Duration::microseconds(val), - TimeUnit::Milliseconds => chrono::Duration::milliseconds(val), - TimeUnit::Seconds => chrono::Duration::seconds(val), + let mut output = String::new(); + let (sizes, suffix, remainder_divisor) = match unit { + TimeUnit::Nanoseconds => (&SIZES[0], "ns", 1_000_000_000), + TimeUnit::Microseconds => (&SIZES[1], "µs", 1_000_000), + TimeUnit::Milliseconds => (&SIZES[2], "ms", 1_000), + TimeUnit::Seconds => (&SIZES[3], "s", 1), }; - format!("{duration}") + + if val == 0 { + return format!("0{}", suffix); + } + + for (i, &size) in sizes.iter().enumerate() { + let whole_num = if i == 0 { + val / size + } else { + (val % sizes[i - 1]) / size + }; + if whole_num != 0 { + output.push_str(&format!("{}{}", whole_num, UNITS[i])); + if val % size != 0 { + output.push(' '); + } + } + } + + let remainder = val % remainder_divisor; + if remainder != 0 && suffix != "s" { + output.push_str(&format!("{}{}", remainder, suffix)); + } + + output } pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String { diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index d3727c2ac3..abae55386b 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -504,6 +504,33 @@ def test_datetime_lit_different_timeunits(timeunit, expected) -> None: assert timestamp_repr == expected +@pytest.mark.parametrize( + "input, expected", + [ + ( + timedelta(days=1), + "lit(1d)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59), + "lit(1d 12h 30m 59s)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(1d 12h 30m 59s 123456µs)", + ), + ( + timedelta(weeks=1, days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(8d 12h 30m 59s 123456µs)", + ), + ], +) +def test_duration_lit(input, expected) -> None: + d = lit(input) + output = repr(d) + assert output == expected + + def test_repr_series_lit() -> None: s = lit(Series.from_pylist([1, 2, 3])) output = repr(s)