diff --git a/daft/table/table.py b/daft/table/table.py
index 7c9b9aedb3..5217f35f21 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -529,7 +529,7 @@ def read_parquet_into_pyarrow(
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
     schema = pa.schema(fields, metadata=metadata)
-    columns = [pa.chunked_array(c) for c in columns]  # type: ignore
+    columns = [pa.chunked_array(c, type=f.type) for f, c in zip(schema, columns)]  # type: ignore
     return pa.table(columns, schema=schema)
 
 
@@ -556,6 +556,6 @@ def read_parquet_into_pyarrow_bulk(
         coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
     )
     return [
-        pa.table([pa.chunked_array(c) for c in columns], schema=pa.schema(fields, metadata=metadata))  # type: ignore
+        pa.table([pa.chunked_array(c, type=f.type) for f, c in zip(fields, columns)], schema=pa.schema(fields, metadata=metadata))  # type: ignore
         for fields, metadata, columns in bulk_result
     ]
diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py
index 2becf7bd76..b617b5beae 100644
--- a/tests/table/table_io/test_parquet.py
+++ b/tests/table/table_io/test_parquet.py
@@ -15,7 +15,13 @@ from daft.datatype import DataType, TimeUnit
 from daft.logical.schema import Schema
 from daft.runners.partitioning import TableParseParquetOptions, TableReadOptions
-from daft.table import Table, schema_inference, table_io
+from daft.table import (
+    Table,
+    read_parquet_into_pyarrow,
+    read_parquet_into_pyarrow_bulk,
+    schema_inference,
+    table_io,
+)
 
 PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0)
 PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0)
 
@@ -317,7 +323,7 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema)
 
 
 @pytest.mark.parametrize("n_bytes", [0, 1, 2, 7])
-def test_read_empty_parquet_file(tmpdir, n_bytes):
+def test_read_too_small_parquet_file(tmpdir, n_bytes):
     tmpdir = pathlib.Path(tmpdir)
     file_path = tmpdir / "file.parquet"
 
@@ -326,3 +332,31 @@
             f.write(b"0")
     with pytest.raises(ValueError, match="smaller than the minimum size of 12 bytes"):
         Table.read_parquet(file_path.as_posix())
+
+
+def test_read_empty_parquet_file_with_table(tmpdir):
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    papq.write_table(tab, file_path.as_posix())
+    read_back = Table.read_parquet(file_path.as_posix()).to_arrow()
+    assert tab == read_back
+
+
+def test_read_empty_parquet_file_with_pyarrow(tmpdir):
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    papq.write_table(tab, file_path.as_posix())
+    read_back = read_parquet_into_pyarrow(file_path.as_posix())
+    assert tab == read_back
+
+
+def test_read_empty_parquet_file_with_pyarrow_bulk(tmpdir):
+    tmpdir = pathlib.Path(tmpdir)
+    file_path = tmpdir / "file.parquet"
+    tab = pa.table({"x": pa.array([], type=pa.int64())})
+    papq.write_table(tab, file_path.as_posix())
+    read_back = read_parquet_into_pyarrow_bulk([file_path.as_posix()])
+    assert len(read_back) == 1
+    assert tab == read_back[0]
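
Why the added type= argument matters: an empty Parquet file decodes to zero chunks per column, and with zero chunks pa.chunked_array has no data from which to infer an element type. A minimal sketch of the failure mode, assuming a recent pyarrow version (exact exception wording may differ):

    import pyarrow as pa

    # With at least one chunk, the element type is inferred from the data.
    pa.chunked_array([pa.array([1, 2, 3])])  # ChunkedArray of int64

    # Zero chunks and no explicit type: there is nothing to infer from,
    # so pyarrow raises instead of returning an empty ChunkedArray.
    # pa.chunked_array([])

    # Supplying the type from the already-decoded Parquet schema, as the
    # patch does with type=f.type, makes the empty case well defined.
    pa.chunked_array([], type=pa.int64())  # empty ChunkedArray of int64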