From af6d7c5770cfea935c19adce472b6081c78fe105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Petti?= Date: Wed, 14 Feb 2024 15:41:59 +0100 Subject: [PATCH] Add stream.iter_polars #1503 (#1504) --- docs/unreleased.md | 3 ++- poetry.lock | 40 +++++++++++++++++++++++++++++- pyproject.toml | 1 + river/stream/__init__.py | 7 ++++++ river/stream/iter_polars.py | 49 +++++++++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 river/stream/iter_polars.py diff --git a/docs/unreleased.md b/docs/unreleased.md index 90b51a20ac..b2c83d0050 100644 --- a/docs/unreleased.md +++ b/docs/unreleased.md @@ -1,3 +1,4 @@ ## drift -- Added `FHDDM` drift detector. \ No newline at end of file +- Added `FHDDM` drift detector. +- Added an `iter_polars` function to iterate over the rows of a polars DataFrame. \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index ec8dbe14ff..9ca4786943 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2730,6 +2730,43 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "0.20.8" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "polars-0.20.8-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:73f1d369aeddda5f11411b6497f697f2471bbe6ae55fd936677a10a40995c83c"}, + {file = "polars-0.20.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:dc3a446fe606095b3ad6df3cf3dddd8ad54be7745f255fedb29f8bdf71a60760"}, + {file = "polars-0.20.8-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3d58ebc7a24d26930535d06b8772e125038a87a6abab4c5dfd87ea19bba61f3"}, + {file = "polars-0.20.8-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:5b733816ac61156c12bd0edd6d7c1a5e63859830ce0e425b6450b335024f0cd5"}, + {file = "polars-0.20.8-cp38-abi3-win_amd64.whl", hash = "sha256:2300f48ff7120eefe2cac2113990d0b0b5beedad93266b9fedfc8df133e7b13b"}, + {file = 
"polars-0.20.8.tar.gz", hash = "sha256:a34f6ce1c5469872b291aaf90467e632e81f92dec6c2e18136bc40cd92877411"}, +] + +[package.extras] +adbc = ["adbc_driver_sqlite"] +all = ["polars[adbc,cloudpickle,connectorx,deltalake,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +deltalake = ["deltalake (>=0.14.0)"] +fsspec = ["fsspec"] +gevent = ["gevent"] +matplotlib = ["matplotlib"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "pyarrow (>=7.0.0)"] +plot = ["hvplot (>=0.9.1)"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +pyiceberg = ["pyiceberg (>=0.5.0)"] +pyxlsb = ["pyxlsb (>=1.0)"] +sqlalchemy = ["pandas", "sqlalchemy"] +timezone = ["backports.zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "pre-commit" version = "3.5.0" @@ -3228,6 +3265,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = 
"sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4910,4 +4948,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "6176d641278998249cf285510d927bcb970d0940d43b00de6f48c28a794d7330" +content-hash = "aeaff58cc4d447bbb5c86a09998c1bf89291d5179805cde2785ac2471f907c02" diff --git a/pyproject.toml b/pyproject.toml index 560377e64a..bb6c7f07b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ python = ">=3.9,<3.13" numpy = "^1.23.0" scipy = "^1.8.1" pandas = "^2.1" +polars = "^0.20.8" [tool.poetry.group.dev.dependencies] graphviz = "^0.20.1" diff --git a/river/stream/__init__.py b/river/stream/__init__.py index 0433a06105..2a70f518f8 100644 --- a/river/stream/__init__.py +++ b/river/stream/__init__.py @@ -27,6 +27,13 @@ "TwitchChatStream", ] +try: + from .iter_polars import iter_polars + + __all__ += ["iter_polars"] +except ImportError: + pass + try: from .iter_pandas import iter_pandas diff --git a/river/stream/iter_polars.py b/river/stream/iter_polars.py new file mode 100644 index 0000000000..d36e628abb --- /dev/null +++ b/river/stream/iter_polars.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import polars as pl + +from river import base, stream + + +def iter_polars( + X: pl.DataFrame, y: pl.Series | pl.DataFrame | None = None, **kwargs +) -> base.typing.Stream: + """Iterates over the rows of a `polars.DataFrame`. + + Parameters + ---------- + X + A dataframe of features. + y + A series or a dataframe with one column per target. + kwargs + Extra keyword arguments are passed to the underlying call to `stream.iter_array`. + + Examples + -------- + + >>> import polars as pl + >>> from river import stream + + >>> X = pl.DataFrame({ + ... 'x1': [1, 2, 3, 4], + ... 'x2': ['blue', 'yellow', 'yellow', 'blue'], + ... 'y': [True, False, False, True] + ... 
}) + >>> y = X.get_column('y') + >>> X=X.drop("y") + + >>> for xi, yi in stream.iter_polars(X, y): + ... print(xi, yi) + {'x1': 1, 'x2': 'blue'} True + {'x1': 2, 'x2': 'yellow'} False + {'x1': 3, 'x2': 'yellow'} False + {'x1': 4, 'x2': 'blue'} True + + """ + + kwargs["feature_names"] = X.columns + if isinstance(y, pl.DataFrame): + kwargs["target_names"] = y.columns + + yield from stream.iter_array(X=X.to_numpy(), y=y if y is None else y.to_numpy(), **kwargs)