[BUG] pass in pyarrow dtype manually into parquet read (#1650)
* Fixes a bug when reading empty parquet files via the pyarrow reader
samster25 authored Nov 21, 2023
1 parent 4ec44b2 commit 06c2ccf
Showing 2 changed files with 43 additions and 4 deletions.
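A note on the root cause: pa.chunked_array cannot infer a dtype from an empty list of chunks, which is what the reader produces for an empty parquet file, so the fix below threads each field's type from the file schema into the constructor. A minimal sketch of the pyarrow behavior (assuming a recent pyarrow; the values are illustrative):

    import pyarrow as pa

    # With at least one chunk, pyarrow infers the dtype from the data.
    nonempty = pa.chunked_array([pa.array([1, 2, 3])])
    assert nonempty.type == pa.int64()

    # With zero chunks there is nothing to infer from, so the type must be
    # passed explicitly -- this is what type=f.type supplies in the fix.
    empty = pa.chunked_array([], type=pa.int64())
    assert empty.type == pa.int64()
    assert len(empty) == 0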
4 changes: 2 additions & 2 deletions daft/table/table.py
@@ -529,7 +529,7 @@ def read_parquet_into_pyarrow(
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
     schema = pa.schema(fields, metadata=metadata)
-    columns = [pa.chunked_array(c) for c in columns]  # type: ignore
+    columns = [pa.chunked_array(c, type=f.type) for f, c in zip(schema, columns)]  # type: ignore
     return pa.table(columns, schema=schema)


@@ -556,6 +556,6 @@ def read_parquet_into_pyarrow_bulk(
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
     return [
-        pa.table([pa.chunked_array(c) for c in columns], schema=pa.schema(fields, metadata=metadata))  # type: ignore
+        pa.table([pa.chunked_array(c, type=f.type) for f, c in zip(fields, columns)], schema=pa.schema(fields, metadata=metadata))  # type: ignore
         for fields, metadata, columns in bulk_result
     ]
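The two hunks differ only in the iterable they zip against: iterating a pa.Schema yields its pa.Field objects, so zip(schema, columns) in the single-file reader and zip(fields, columns) in the bulk variant pair each column with the same field type. A quick sketch:

    import pyarrow as pa

    fields = [pa.field("x", pa.int64()), pa.field("y", pa.string())]
    schema = pa.schema(fields)

    # A Schema iterates over its Fields, so both zips see identical types.
    assert [f.type for f in schema] == [f.type for f in fields]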
43 changes: 41 additions & 2 deletions tests/table/table_io/test_parquet.py
@@ -15,7 +15,13 @@
 from daft.datatype import DataType, TimeUnit
 from daft.logical.schema import Schema
 from daft.runners.partitioning import TableParseParquetOptions, TableReadOptions
-from daft.table import Table, schema_inference, table_io
+from daft.table import (
+    Table,
+    read_parquet_into_pyarrow,
+    read_parquet_into_pyarrow_bulk,
+    schema_inference,
+    table_io,
+)

 PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0)
 PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0)
@@ -317,7 +323,7 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema)


 @pytest.mark.parametrize("n_bytes", [0, 1, 2, 7])
-def test_read_empty_parquet_file(tmpdir, n_bytes):
+def test_read_too_small_parquet_file(tmpdir, n_bytes):

     tmpdir = pathlib.Path(tmpdir)
     file_path = tmpdir / "file.parquet"
@@ -326,3 +332,36 @@ def test_read_empty_parquet_file(tmpdir, n_bytes):
             f.write(b"0")
     with pytest.raises(ValueError, match="smaller than the minimum size of 12 bytes"):
         Table.read_parquet(file_path.as_posix())
+
+
+def test_read_empty_parquet_file_with_table(tmpdir):
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    with open(file_path, "wb") as f:
+        papq.write_table(tab, file_path.as_posix())
+    read_back = Table.read_parquet(file_path.as_posix()).to_arrow()
+    assert tab == read_back
+
+
+def test_read_empty_parquet_file_with_pyarrow(tmpdir):
+
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    with open(file_path, "wb") as f:
+        papq.write_table(tab, file_path.as_posix())
+    read_back = read_parquet_into_pyarrow(file_path.as_posix())
+    assert tab == read_back
+
+
+def test_read_empty_parquet_file_with_pyarrow_bulk(tmpdir):
+
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    with open(file_path, "wb") as f:
+        papq.write_table(tab, file_path.as_posix())
+    read_back = read_parquet_into_pyarrow_bulk([file_path.as_posix()])
+    assert len(read_back) == 1
+    assert tab == read_back[0]
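An aside on the renamed test above: 12 bytes is the smallest structurally possible parquet file, i.e. the 4-byte "PAR1" magic at the head, a 4-byte footer length, and the 4-byte "PAR1" magic at the tail, leaving no room for data. A small sketch of the same guard outside pytest parametrization (the 8-byte payload and temporary path are illustrative):

    import pathlib
    import tempfile

    import pytest

    from daft.table import Table

    with tempfile.TemporaryDirectory() as d:
        # 8 bytes: head and tail magic with no footer length between them,
        # which is below parquet's 12-byte structural minimum.
        path = pathlib.Path(d) / "truncated.parquet"
        path.write_bytes(b"PAR1PAR1")
        with pytest.raises(ValueError, match="smaller than the minimum size of 12 bytes"):
            Table.read_parquet(path.as_posix())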
