Skip to content

Commit

Permalink
[BUG] Fix CSV roundtrip for decimals (actually an f64->decimal castin…
Browse files Browse the repository at this point in the history
…g bug) (#1626)

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Nov 17, 2023
1 parent 0097aa1 commit fd74caf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
10 changes: 1 addition & 9 deletions src/daft-core/src/array/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,21 +186,13 @@ where
)?
} else if can_cast_types(&self_arrow_type, &target_arrow_type) {
// Cast from logical Arrow2 type to logical Arrow2 type.
let arrow_logical = cast(
cast(
to_cast.data(),
&target_arrow_type,
CastOptions {
wrapped: true,
partial: false,
},
)?;
cast(
arrow_logical.as_ref(),
&target_arrow_physical_type,
CastOptions {
wrapped: true,
partial: false,
},
)?
} else if can_cast_types(&self_physical_arrow_type, &target_arrow_physical_type) {
// Cast from physical Arrow2 type to physical Arrow2 type.
Expand Down
9 changes: 7 additions & 2 deletions tests/io/test_csv_roundtrip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime
import decimal

import pyarrow as pa
import pytest
Expand All @@ -23,8 +24,12 @@
([b"a", b"b", b""], pa.large_binary(), DataType.binary(), DataType.string()),
([True, False, None], pa.bool_(), DataType.bool(), DataType.bool()),
([None, None, None], pa.null(), DataType.null(), DataType.null()),
# TODO: This is broken, needs more investigation into why
# ([decimal.Decimal("1.23"), decimal.Decimal("1.24"), None], pa.decimal128(16, 8), DataType.decimal128(16, 8), DataType.float64()),
(
[decimal.Decimal("1.23"), decimal.Decimal("1.24"), None],
pa.decimal128(16, 8),
DataType.decimal128(16, 8),
DataType.float64(),
),
([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date(), DataType.date()),
(
[datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None],
Expand Down

0 comments on commit fd74caf

Please sign in to comment.