Skip to content

Commit

Permalink
feat(python): Raise DuplicateError if given a pyarrow Table object …
Browse files Browse the repository at this point in the history
…with duplicate column names (#20624)
  • Loading branch information
alexander-beedie authored Jan 9, 2025
1 parent 323b88c commit 53a493f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
12 changes: 9 additions & 3 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
from collections import Counter
from collections.abc import Generator, Mapping, Sequence
from datetime import date, datetime, time, timedelta
from functools import singledispatch
Expand Down Expand Up @@ -52,7 +53,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import DataOrientationWarning, ShapeError
from polars.exceptions import DataOrientationWarning, DuplicateError, ShapeError
from polars.meta import thread_pool_size

with contextlib.suppress(ImportError): # Module not available when building docs
Expand Down Expand Up @@ -209,7 +210,7 @@ def _parse_schema_overrides(

schema_overrides = _parse_schema_overrides(schema_overrides)

# Fast path for empty schema
# fast path for empty schema
if not schema:
columns = (
[f"column_{i}" for i in range(n_expected)] if n_expected is not None else []
Expand Down Expand Up @@ -1163,7 +1164,6 @@ def arrow_to_pydf(
column_names, schema_overrides = _unpack_schema(
(schema or data.schema.names), schema_overrides=schema_overrides
)

try:
if column_names != data.schema.names:
data = data.rename_columns(column_names)
Expand All @@ -1180,6 +1180,12 @@ def arrow_to_pydf(
# supply the arrow schema so the metadata is intact
pydf = PyDataFrame.from_arrow_record_batches(batches, data.schema)

# arrow tables allow duplicate names; we don't
if len(data.columns) != pydf.width():
col_name, _ = Counter(column_names).most_common(1)[0]
msg = f"column {col_name!r} appears more than once; names must be unique"
raise DuplicateError(msg)

if rechunk:
pydf = pydf.rechunk()

Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polars._utils.construction.utils import try_get_type_hints
from polars.datatypes import numpy_char_code_to_dtype
from polars.dependencies import dataclasses, pydantic
from polars.exceptions import ShapeError
from polars.exceptions import DuplicateError, ShapeError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -723,6 +723,21 @@ def test_init_arrow() -> None:
pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"])


def test_init_arrow_dupes() -> None:
tbl = pa.Table.from_arrays(
arrays=[
pa.array([1, 2, 3], type=pa.int32()),
pa.array([4, 5, 6], type=pa.int32()),
],
schema=pa.schema([("col", pa.int32()), ("col", pa.int32())]),
)
with pytest.raises(
DuplicateError,
match="column 'col' appears more than once; names must be unique",
):
pl.DataFrame(tbl)


def test_init_from_frame() -> None:
df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]})
assert_frame_equal(df1, pl.DataFrame(df1))
Expand Down

0 comments on commit 53a493f

Please sign in to comment.