Skip to content

Commit

Permalink
improve and test repr
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 7, 2024
1 parent c79f043 commit 519e820
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 9 deletions.
1 change: 0 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __get__( # type: ignore[override]


def lit(value: object) -> Expression:
print("lit: ", value)
"""Creates an Expression representing a column with every value set to the provided value
Example:
Expand Down
20 changes: 18 additions & 2 deletions src/daft-core/src/array/ops/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use crate::{
NullArray, UInt64Array, Utf8Array,
},
series::Series,
utils::display::{display_date32, display_decimal128, display_time64, display_timestamp},
utils::display::{
display_date32, display_decimal128, display_duration, display_time64, display_timestamp,
},
with_match_daft_types,
};

Expand All @@ -34,7 +36,6 @@ macro_rules! impl_array_str_value {

impl_array_str_value!(BooleanArray, "{}");
impl_array_str_value!(ExtensionArray, "{:?}");
impl_array_str_value!(DurationArray, "{}");

fn pretty_print_bytes(bytes: &[u8], max_len: usize) -> DaftResult<String> {
/// influenced by pythons bytes repr
Expand Down Expand Up @@ -192,6 +193,21 @@ impl TimestampArray {
}
}

impl DurationArray {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
|| "None".to_string(),
|val| -> String {
let DataType::Duration(time_unit) = &self.field.dtype else {
panic!("Wrong dtype for DurationArray: {}", self.field.dtype)

Check warning on line 202 in src/daft-core/src/array/ops/repr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/repr.rs#L197-L202

Added lines #L197 - L202 were not covered by tests
};
display_duration(val, time_unit)
},
);
Ok(res)
}

Check warning on line 208 in src/daft-core/src/array/ops/repr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/repr.rs#L204-L208

Added lines #L204 - L208 were not covered by tests
}

impl Decimal128Array {
pub fn str_value(&self, idx: usize) -> DaftResult<String> {
let res = self.get(idx).map_or_else(
Expand Down
50 changes: 44 additions & 6 deletions src/daft-core/src/utils/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,52 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option<String>) -
)
}

const UNITS: [&str; 4] = ["d", "h", "m", "s"];
const SIZES: [[i64; 4]; 4] = [
[
86_400_000_000_000,
3_600_000_000_000,
60_000_000_000,
1_000_000_000,
], // Nanoseconds
[86_400_000_000, 3_600_000_000, 60_000_000, 1_000_000], // Microseconds
[86_400_000, 3_600_000, 60_000, 1_000], // Milliseconds
[86_400, 3_600, 60, 1], // Seconds
];

pub fn display_duration(val: i64, unit: &TimeUnit) -> String {
let duration = match unit {
TimeUnit::Nanoseconds => chrono::Duration::nanoseconds(val),
TimeUnit::Microseconds => chrono::Duration::microseconds(val),
TimeUnit::Milliseconds => chrono::Duration::milliseconds(val),
TimeUnit::Seconds => chrono::Duration::seconds(val),
let mut output = String::new();
let (sizes, suffix, remainder_divisor) = match unit {
TimeUnit::Nanoseconds => (&SIZES[0], "ns", 1_000_000_000),

Check warning on line 94 in src/daft-core/src/utils/display.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/utils/display.rs#L94

Added line #L94 was not covered by tests
TimeUnit::Microseconds => (&SIZES[1], "µs", 1_000_000),
TimeUnit::Milliseconds => (&SIZES[2], "ms", 1_000),
TimeUnit::Seconds => (&SIZES[3], "s", 1),

Check warning on line 97 in src/daft-core/src/utils/display.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/utils/display.rs#L96-L97

Added lines #L96 - L97 were not covered by tests
};
format!("{duration}")

if val == 0 {
return format!("0{}", suffix);

Check warning on line 101 in src/daft-core/src/utils/display.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/utils/display.rs#L101

Added line #L101 was not covered by tests
}

for (i, &size) in sizes.iter().enumerate() {
let whole_num = if i == 0 {
val / size
} else {
(val % sizes[i - 1]) / size
};
if whole_num != 0 {
output.push_str(&format!("{}{}", whole_num, UNITS[i]));
if val % size != 0 {
output.push(' ');
}
}
}

let remainder = val % remainder_divisor;
if remainder != 0 && suffix != "s" {
output.push_str(&format!("{}{}", remainder, suffix));
}

output
}

pub fn display_decimal128(val: i128, _precision: u8, scale: i8) -> String {
Expand Down
27 changes: 27 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,33 @@ def test_datetime_lit_different_timeunits(timeunit, expected) -> None:
assert timestamp_repr == expected


@pytest.mark.parametrize(
"input, expected",
[
(
timedelta(days=1),
"lit(1d)",
),
(
timedelta(days=1, hours=12, minutes=30, seconds=59),
"lit(1d 12h 30m 59s)",
),
(
timedelta(days=1, hours=12, minutes=30, seconds=59, microseconds=123456),
"lit(1d 12h 30m 59s 123456µs)",
),
(
timedelta(weeks=1, days=1, hours=12, minutes=30, seconds=59, microseconds=123456),
"lit(8d 12h 30m 59s 123456µs)",
),
],
)
def test_duration_lit(input, expected) -> None:
d = lit(input)
output = repr(d)
assert output == expected


def test_repr_series_lit() -> None:
s = lit(Series.from_pylist([1, 2, 3]))
output = repr(s)
Expand Down

0 comments on commit 519e820

Please sign in to comment.