Skip to content

Commit

Permalink
feat: Raise DuplicateError if given a pyarrow Table object with duplicate column names
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 8, 2025
1 parent 92fd75d commit 8995e81
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
12 changes: 9 additions & 3 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
from collections import Counter
from collections.abc import Generator, Mapping, Sequence
from datetime import date, datetime, time, timedelta
from functools import singledispatch
Expand Down Expand Up @@ -52,7 +53,7 @@
from polars.dependencies import numpy as np
from polars.dependencies import pandas as pd
from polars.dependencies import pyarrow as pa
from polars.exceptions import DataOrientationWarning, ShapeError
from polars.exceptions import DataOrientationWarning, DuplicateError, ShapeError
from polars.meta import thread_pool_size

with contextlib.suppress(ImportError): # Module not available when building docs
Expand Down Expand Up @@ -209,7 +210,7 @@ def _parse_schema_overrides(

schema_overrides = _parse_schema_overrides(schema_overrides)

# Fast path for empty schema
# fast path for empty schema
if not schema:
columns = (
[f"column_{i}" for i in range(n_expected)] if n_expected is not None else []
Expand Down Expand Up @@ -1163,14 +1164,19 @@ def arrow_to_pydf(
column_names, schema_overrides = _unpack_schema(
(schema or data.schema.names), schema_overrides=schema_overrides
)

try:
if column_names != data.schema.names:
data = data.rename_columns(column_names)
except pa.lib.ArrowInvalid as e:
msg = "dimensions of columns arg must match data dimensions"
raise ValueError(msg) from e

# arrow tables allow duplicate names; we don't
if len(column_names) != len(set(column_names)):
col_name, col_count = Counter(column_names).most_common(1)[0]
msg = f"column {col_name!r} appears {col_count} times; names must be unique"
raise DuplicateError(msg)

batches: list[pa.RecordBatch]
if isinstance(data, pa.RecordBatch):
batches = [data]
Expand Down
17 changes: 16 additions & 1 deletion py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from polars._utils.construction.utils import try_get_type_hints
from polars.datatypes import numpy_char_code_to_dtype
from polars.dependencies import dataclasses, pydantic
from polars.exceptions import ShapeError
from polars.exceptions import DuplicateError, ShapeError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -723,6 +723,21 @@ def test_init_arrow() -> None:
pl.DataFrame(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}), schema=["c", "d", "e"])


def test_init_arrow_dupes() -> None:
    """Check that a pyarrow Table with repeated column names is rejected."""
    # build a two-column arrow table in which both columns are named "col"
    int32 = pa.int32()
    dupe_schema = pa.schema([("col", int32), ("col", int32)])
    columns = [
        pa.array([1, 2, 3], type=int32),
        pa.array([4, 5, 6], type=int32),
    ]
    tbl = pa.Table.from_arrays(arrays=columns, schema=dupe_schema)

    # frame construction should fail, naming the repeated column and its count
    expected_msg = "column 'col' appears 2 times; names must be unique"
    with pytest.raises(DuplicateError, match=expected_msg):
        pl.DataFrame(tbl)


def test_init_from_frame() -> None:
df1 = pl.DataFrame({"id": [0, 1], "misc": ["a", "b"], "val": [-10, 10]})
assert_frame_equal(df1, pl.DataFrame(df1))
Expand Down

0 comments on commit 8995e81

Please sign in to comment.