From 519e8200ff1376b73e8b44cbdeebb4cd74acffdd Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 7 Oct 2024 11:07:13 -0500 Subject: [PATCH] improve and test repr --- daft/expressions/expressions.py | 1 - src/daft-core/src/array/ops/repr.rs | 20 +++++++++-- src/daft-core/src/utils/display.rs | 50 +++++++++++++++++++++++---- tests/expressions/test_expressions.py | 27 +++++++++++++++ 4 files changed, 89 insertions(+), 9 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 00da05ad51..ae3655b430 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -72,7 +72,6 @@ def __get__( # type: ignore[override] def lit(value: object) -> Expression: - print("lit: ", value) """Creates an Expression representing a column with every value set to the provided value Example: diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 8d60f697c7..5fbb6bf2c1 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -13,7 +13,9 @@ use crate::{ NullArray, UInt64Array, Utf8Array, }, series::Series, - utils::display::{display_date32, display_decimal128, display_time64, display_timestamp}, + utils::display::{ + display_date32, display_decimal128, display_duration, display_time64, display_timestamp, + }, with_match_daft_types, }; @@ -34,7 +36,6 @@ macro_rules! impl_array_str_value { impl_array_str_value!(BooleanArray, "{}"); impl_array_str_value!(ExtensionArray, "{:?}"); -impl_array_str_value!(DurationArray, "{}"); fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult { /// influenced by pythons bytes repr @@ -192,6 +193,21 @@ impl TimestampArray { } } +impl DurationArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let res = self.get(idx).map_or_else( + || "None".to_string(), + |val| -> String { + let DataType::Duration(time_unit) = &self.field.dtype else { + panic!("Wrong dtype for DurationArray: {}", self.field.dtype) + }; + display_duration(val, time_unit) + }, + ); + Ok(res) + } +} + impl Decimal128Array { pub fn str_value(&self, idx: usize) -> DaftResult { let res = self.get(idx).map_or_else( diff --git a/src/daft-core/src/utils/display.rs b/src/daft-core/src/utils/display.rs index 3e98099f6f..a76d382555 100644 --- a/src/daft-core/src/utils/display.rs +++ b/src/daft-core/src/utils/display.rs @@ -75,14 +75,52 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) - ) } +const UNITS: [&str; 4] = ["d", "h", "m", "s"]; +const SIZES: [[i64; 4]; 4] = [ + [ + 86_400_000_000_000, + 3_600_000_000_000, + 60_000_000_000, + 1_000_000_000, + ], // Nanoseconds + [86_400_000_000, 3_600_000_000, 60_000_000, 1_000_000], // Microseconds + [86_400_000, 3_600_000, 60_000, 1_000], // Milliseconds + [86_400, 3_600, 60, 1], // Seconds +]; + pub fn display_duration(val: i64, unit: &TimeUnit) -> String { - let duration = match unit { - TimeUnit::Nanoseconds => chrono::Duration::nanoseconds(val), - TimeUnit::Microseconds => chrono::Duration::microseconds(val), - TimeUnit::Milliseconds => chrono::Duration::milliseconds(val), - TimeUnit::Seconds => chrono::Duration::seconds(val), + let mut output = String::new(); + let (sizes, suffix, remainder_divisor) = match unit { + TimeUnit::Nanoseconds => (&SIZES[0], "ns", 1_000_000_000), + TimeUnit::Microseconds => (&SIZES[1], "µs", 1_000_000), + TimeUnit::Milliseconds => (&SIZES[2], "ms", 1_000), + TimeUnit::Seconds => (&SIZES[3], "s", 1), }; - format!("{duration}") + + if val == 0 { + return format!("0{}", suffix); + } + + for (i, &size) in sizes.iter().enumerate() { + let whole_num = if i == 0 { + val / size + } else { + (val % sizes[i - 1]) / size + }; + if whole_num != 0 { + output.push_str(&format!("{}{}", whole_num, UNITS[i])); + if val % size != 0 { + output.push(' '); + } + } + } + + let remainder = val % remainder_divisor; + if remainder != 0 && suffix != "s" { + output.push_str(&format!("{}{}", remainder, suffix)); + } + + output } pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String { diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index d3727c2ac3..abae55386b 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -504,6 +504,33 @@ def test_datetime_lit_different_timeunits(timeunit, expected) -> None: assert timestamp_repr == expected +@pytest.mark.parametrize( + "input, expected", + [ + ( + timedelta(days=1), + "lit(1d)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59), + "lit(1d 12h 30m 59s)", + ), + ( + timedelta(days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(1d 12h 30m 59s 123456µs)", + ), + ( + timedelta(weeks=1, days=1, hours=12, minutes=30, seconds=59, microseconds=123456), + "lit(8d 12h 30m 59s 123456µs)", + ), + ], +) +def test_duration_lit(input, expected) -> None: + d = lit(input) + output = repr(d) + assert output == expected + + def test_repr_series_lit() -> None: s = lit(Series.from_pylist([1, 2, 3])) output = repr(s)