Skip to content

Commit

Permalink
[BUG] Fix timestamp timezone parsing bug in CSVs (#1530)
Browse files Browse the repository at this point in the history
Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Oct 25, 2023
1 parent a122029 commit 959d93b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
10 changes: 7 additions & 3 deletions src/daft-csv/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ fn deserialize_datetime<T: chrono::TimeZone>(
fmt_idx: &mut usize,
) -> Option<chrono::DateTime<T>> {
// TODO(Clark): Parse as all candidate formats in a single pass.
-for i in 0..ALL_NAIVE_TIMESTAMP_FMTS.len() {
-let idx = (i + *fmt_idx) % ALL_NAIVE_TIMESTAMP_FMTS.len();
-let fmt = ALL_NAIVE_TIMESTAMP_FMTS[idx];
+for i in 0..ALL_TIMESTAMP_FMTS.len() {
+let idx = (i + *fmt_idx) % ALL_TIMESTAMP_FMTS.len();
+let fmt = ALL_TIMESTAMP_FMTS[idx];
+println!(
+"Deserializing dt: {string} with {fmt} as {:?}",
+chrono::DateTime::parse_from_str(string, fmt)
+);
if let Ok(dt) = chrono::DateTime::parse_from_str(string, fmt) {
*fmt_idx = idx;
return Some(dt.with_timezone(tz));
Expand Down
16 changes: 10 additions & 6 deletions tests/dataframe/test_temporals.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def test_temporal_arithmetic() -> None:


@pytest.mark.parametrize("format", ["csv", "parquet"])
-def test_temporal_file_roundtrip(format) -> None:
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_temporal_file_roundtrip(format, use_native_downloader) -> None:
data = {
"date32": pa.array([1], pa.date32()),
"date64": pa.array([1], pa.date64()),
Expand All @@ -69,9 +70,12 @@ def test_temporal_file_roundtrip(format) -> None:
"timestamp_s": pa.array([1], pa.timestamp("s")),
"timestamp_ms": pa.array([1], pa.timestamp("ms")),
"timestamp_us": pa.array([1], pa.timestamp("us")),
-"timestamp_s_tz": pa.array([1], pa.timestamp("s", tz="UTC")),
-"timestamp_ms_tz": pa.array([1], pa.timestamp("ms", tz="UTC")),
-"timestamp_us_tz": pa.array([1], pa.timestamp("us", tz="UTC")),
+"timestamp_s_utc_tz": pa.array([1], pa.timestamp("s", tz="UTC")),
+"timestamp_ms_utc_tz": pa.array([1], pa.timestamp("ms", tz="UTC")),
+"timestamp_us_utc_tz": pa.array([1], pa.timestamp("us", tz="UTC")),
+"timestamp_s_tz": pa.array([1], pa.timestamp("s", tz="Asia/Singapore")),
+"timestamp_ms_tz": pa.array([1], pa.timestamp("ms", tz="Asia/Singapore")),
+"timestamp_us_tz": pa.array([1], pa.timestamp("us", tz="Asia/Singapore")),
}

pa_table = pa.Table.from_pydict(data)
Expand All @@ -81,10 +85,10 @@ def test_temporal_file_roundtrip(format) -> None:
with tempfile.TemporaryDirectory() as dirname:
if format == "csv":
df.write_csv(dirname)
-df_readback = daft.read_csv(dirname).collect()
+df_readback = daft.read_csv(dirname, use_native_downloader=use_native_downloader).collect()
elif format == "parquet":
df.write_parquet(dirname)
-df_readback = daft.read_parquet(dirname).collect()
+df_readback = daft.read_parquet(dirname, use_native_downloader=use_native_downloader).collect()

assert df.to_pydict() == df_readback.to_pydict()

Expand Down

0 comments on commit 959d93b

Please sign in to comment.