From fd74cafc24bc5de835112f12a02959e5051e1d32 Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Fri, 17 Nov 2023 09:50:58 -0800
Subject: [PATCH] [BUG] Fix CSV roundtrip for decimals (actually an f64->decimal casting bug) (#1626)

Co-authored-by: Jay Chia
---
 src/daft-core/src/array/ops/cast.rs | 10 +--------
 tests/io/test_csv_roundtrip.py      |  9 +++++++--
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs
index 2f6babca16..8511ffdff1 100644
--- a/src/daft-core/src/array/ops/cast.rs
+++ b/src/daft-core/src/array/ops/cast.rs
@@ -186,21 +186,13 @@ where
             )?
         } else if can_cast_types(&self_arrow_type, &target_arrow_type) {
             // Cast from logical Arrow2 type to logical Arrow2 type.
-            let arrow_logical = cast(
+            cast(
                 to_cast.data(),
                 &target_arrow_type,
                 CastOptions {
                     wrapped: true,
                     partial: false,
                 },
-            )?;
-            cast(
-                arrow_logical.as_ref(),
-                &target_arrow_physical_type,
-                CastOptions {
-                    wrapped: true,
-                    partial: false,
-                },
             )?
         } else if can_cast_types(&self_physical_arrow_type, &target_arrow_physical_type) {
             // Cast from physical Arrow2 type to physical Arrow2 type.
diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py
index 364ce40c91..1508376653 100644
--- a/tests/io/test_csv_roundtrip.py
+++ b/tests/io/test_csv_roundtrip.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import datetime
+import decimal
 
 import pyarrow as pa
 import pytest
@@ -23,8 +24,12 @@
     ([b"a", b"b", b""], pa.large_binary(), DataType.binary(), DataType.string()),
     ([True, False, None], pa.bool_(), DataType.bool(), DataType.bool()),
     ([None, None, None], pa.null(), DataType.null(), DataType.null()),
-    # TODO: This is broken, needs more investigation into why
-    # ([decimal.Decimal("1.23"), decimal.Decimal("1.24"), None], pa.decimal128(16, 8), DataType.decimal128(16, 8), DataType.float64()),
+    (
+        [decimal.Decimal("1.23"), decimal.Decimal("1.24"), None],
+        pa.decimal128(16, 8),
+        DataType.decimal128(16, 8),
+        DataType.float64(),
+    ),
     ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date(), DataType.date()),
     (
         [datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None],