Skip to content

Commit

Permalink
chore!: upgrade Ray pins and pyarrow pins (#3612)
Browse files Browse the repository at this point in the history
Updates the lower bound of pyarrow to `pyarrow>=8.0.0`.

This allows us to flatten some code checks.

However, it turns out that our tests aren't being properly skipped -- I
had to update the tests to just skip based on our lower bound (skip if
version < 9.0.0), which is very loose, but otherwise searching for the
individual versions for each suite of tests was quite difficult.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Dec 19, 2024
1 parent 063de4d commit a76f800
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 36 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,26 @@ jobs:
matrix:
python-version: ['3.9', '3.10']
daft-runner: [py, ray, native]
pyarrow-version: [7.0.0, 16.0.0]
pyarrow-version: [8.0.0, 16.0.0]
os: [ubuntu-20.04, windows-latest]
exclude:
- daft-runner: ray
pyarrow-version: 7.0.0
pyarrow-version: 8.0.0
os: ubuntu-20.04
- daft-runner: py
python-version: '3.10'
pyarrow-version: 7.0.0
pyarrow-version: 8.0.0
os: ubuntu-20.04
- daft-runner: native
python-version: '3.10'
pyarrow-version: 7.0.0
pyarrow-version: 8.0.0
os: ubuntu-20.04
- python-version: '3.9'
pyarrow-version: 16.0.0
- os: windows-latest
python-version: '3.9'
- os: windows-latest
pyarrow-version: 7.0.0
pyarrow-version: 8.0.0
steps:
- uses: actions/checkout@v4
- uses: moonrepo/setup-rust@v1
Expand Down Expand Up @@ -93,7 +93,7 @@ jobs:
run: uv pip install pyarrow==${{ matrix.pyarrow-version }}

- name: Override deltalake for pyarrow
if: ${{ (matrix.pyarrow-version == '7.0.0') }}
if: ${{ (matrix.pyarrow-version == '8.0.0') }}
run: uv pip install deltalake==0.10.0

- name: Build library and Test with pytest (unix)
Expand Down
13 changes: 4 additions & 9 deletions daft/table/table_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,16 +554,11 @@ def _write_tabular_arrow_table(
):
kwargs = dict()

from daft.utils import get_arrow_version
kwargs["max_rows_per_file"] = rows_per_file
kwargs["min_rows_per_group"] = rows_per_row_group
kwargs["max_rows_per_group"] = rows_per_row_group

arrow_version = get_arrow_version()

if arrow_version >= (7, 0, 0):
kwargs["max_rows_per_file"] = rows_per_file
kwargs["min_rows_per_group"] = rows_per_row_group
kwargs["max_rows_per_group"] = rows_per_row_group

if arrow_version >= (8, 0, 0) and not create_dir:
if not create_dir:
kwargs["create_dir"] = False

basename_template = _generate_basename_template(format.default_extname, version)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = ["maturin>=1.5.0,<2.0.0"]
[project]
authors = [{name = "Eventual Inc", email = "[email protected]"}]
dependencies = [
"pyarrow >= 7.0.0",
"pyarrow >= 8.0.0",
"fsspec",
"tqdm",
"typing-extensions >= 4.0.0; python_version < '3.10'"
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/iceberg/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

pyiceberg = pytest.importorskip("pyiceberg")

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="iceberg writes only supported if pyarrow >= 8.0.0")
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(
PYARROW_LOWER_BOUND_SKIP, reason="iceberg writes not supported on old versions of pyarrow"
)

import tenacity
from pyiceberg.catalog import Catalog, load_catalog
Expand Down
6 changes: 3 additions & 3 deletions tests/io/delta_lake/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from daft.logical.schema import Schema
from tests.utils import assert_pyarrow_tables_equal

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(
PYARROW_LE_8_0_0,
reason="deltalake only supported if pyarrow >= 8.0.0",
PYARROW_LOWER_BOUND_SKIP,
reason="deltalake not supported on older versions of pyarrow",
)


Expand Down
6 changes: 3 additions & 3 deletions tests/io/delta_lake/test_table_read_pushdowns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from daft.logical.schema import Schema
from tests.utils import assert_pyarrow_tables_equal

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(
PYARROW_LE_8_0_0,
reason="deltalake only supported if pyarrow >= 8.0.0",
PYARROW_LOWER_BOUND_SKIP,
reason="deltalake not supported on older versions of pyarrow",
)


Expand Down
6 changes: 3 additions & 3 deletions tests/io/delta_lake/test_table_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from daft.logical.schema import Schema
from tests.conftest import get_tests_daft_runner_name

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(
PYARROW_LE_8_0_0,
reason="deltalake only supported if pyarrow >= 8.0.0",
PYARROW_LOWER_BOUND_SKIP,
reason="deltalake not supported on older versions of pyarrow",
)


Expand Down
4 changes: 2 additions & 2 deletions tests/io/hudi/test_table_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import daft

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="hudi only supported if pyarrow >= 8.0.0")
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LOWER_BOUND_SKIP, reason="hudi not supported on old versions of pyarrow")


def test_read_table(get_testing_table_for_supported_cases):
Expand Down
4 changes: 2 additions & 2 deletions tests/io/iceberg/test_iceberg_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

pyiceberg = pytest.importorskip("pyiceberg")

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="iceberg only supported if pyarrow >= 8.0.0")
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LOWER_BOUND_SKIP, reason="iceberg not supported on old versions of pyarrow")


from pyiceberg.catalog.sql import SqlCatalog
Expand Down
4 changes: 2 additions & 2 deletions tests/io/lancedb/test_lancedb_reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
"long": [-122.7, -74.1],
}

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LOWER_BOUND_SKIP, reason="lance not supported on old versions of pyarrow")


@pytest.fixture(scope="function")
Expand Down
5 changes: 2 additions & 3 deletions tests/io/lancedb/test_lancedb_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
"long": [-122.7, -74.1],
}

PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)

pytestmark = pytest.mark.skipif(PYARROW_LE_8_0_0, reason="lance only supported if pyarrow >= 8.0.0")
PYARROW_LOWER_BOUND_SKIP = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (9, 0, 0)
pytestmark = pytest.mark.skipif(PYARROW_LOWER_BOUND_SKIP, reason="lance not supported on old versions of pyarrow")


@pytest.fixture(scope="function")
Expand Down

0 comments on commit a76f800

Please sign in to comment.