From cd2595b4b4c25828ce00e3a064524101a4e8f050 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Thu, 25 Aug 2022 10:30:48 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=96=200.8.6=20(#134)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datar/__init__.py | 2 +- datar/base/arithmetic.py | 45 +++++++++++++++++++------ datar/base/verbs.py | 2 +- datar/dplyr/distinct.py | 65 +++++++++++++++++++++++------------- docs/CHANGELOG.md | 6 ++++ docs/requirements.txt | 11 +++--- pyproject.toml | 2 +- tests/base/test_stats.py | 5 +++ tests/dplyr/test_distinct.py | 3 +- 9 files changed, 97 insertions(+), 44 deletions(-) diff --git a/datar/__init__.py b/datar/__init__.py index f5436804..ea10575f 100644 --- a/datar/__init__.py +++ b/datar/__init__.py @@ -13,7 +13,7 @@ ) __all__ = ("f", "get_versions") -__version__ = "0.8.5" +__version__ = "0.8.6" apply_init_callbacks() diff --git a/datar/base/arithmetic.py b/datar/base/arithmetic.py index 26bbe267..0549a9ae 100644 --- a/datar/base/arithmetic.py +++ b/datar/base/arithmetic.py @@ -1,5 +1,6 @@ """Arithmetic or math functions""" +from functools import singledispatch import inspect from typing import TYPE_CHECKING, Union @@ -883,18 +884,42 @@ def std( sd = std -@func_factory("transform", {"x", "w"}) -def weighted_mean( - x: Series, w: Series = 1, na_rm=True, __args_raw=None -) -> Series: - """Calculate weighted mean""" - if __args_raw["w"] is not None and np.nansum(w) == 0: +@singledispatch +def _weighted_mean( + df: DataFrame, + has_w: bool = True, + na_rm: bool = True, +) -> np.ndarray: + if not has_w: + return np.nanmean(df["x"]) if na_rm else np.mean(df["x"]) + + if np.nansum(df["w"]) == 0: return np.nan if na_rm: - na_mask = pd.isnull(x) - x = x[~na_mask.values] - w = w[~na_mask.values] + na_mask = pd.isnull(df["x"]) + x = df["x"][~na_mask.values] + w = df["w"][~na_mask.values] return np.average(x, weights=w) - return np.average(x, weights=w) + return np.average(df["x"], weights=df["w"]) + + +@_weighted_mean.register(TibbleGrouped) +def _( + df: TibbleGrouped, + has_w: bool = True, + na_rm: bool = True, +) -> Series: + return df._datar["grouped"].apply( + lambda subdf: _weighted_mean(subdf, has_w, na_rm) + ) + + +@func_factory(None, {"x", "w"}) +def weighted_mean( + x: Series, w: Series = 1, na_rm=True, __args_raw=None, __args_frame=None, +) -> Series: + """Calculate weighted mean""" + has_w = __args_raw["w"] is not None + return _weighted_mean(__args_frame, has_w, na_rm) diff --git a/datar/base/verbs.py b/datar/base/verbs.py index 731fb588..1dff52f6 100644 --- a/datar/base/verbs.py +++ b/datar/base/verbs.py @@ -234,7 +234,7 @@ def union(x, y): @register_verb(context=Context.EVAL) def unique(x): - """Union of two iterables""" + """Get unique elements from an iterable and keep their order""" # order not kept # return np.unique(x) if is_scalar(x): diff --git a/datar/dplyr/distinct.py b/datar/dplyr/distinct.py index 2095b38a..61728d65 100644 --- a/datar/dplyr/distinct.py +++ b/datar/dplyr/distinct.py @@ -3,6 +3,7 @@ See source https://github.com/tidyverse/dplyr/blob/master/R/distinct.R """ from pipda import register_verb +from pipda.symbolic import Reference from ..core.backends.pandas import DataFrame from ..core.backends.pandas.core.groupby import GroupBy @@ -11,7 +12,7 @@ from ..core.factory import func_factory from ..core.utils import regcall from ..core.tibble import Tibble, TibbleGrouped, reconstruct_tibble -from ..base import union, setdiff, intersect +from ..base import union, setdiff, intersect, unique from .mutate import mutate @@ -33,31 +34,49 @@ def distinct(_data, *args, _keep_all=False, **kwargs): A dataframe without duplicated rows in _data """ if not args and not kwargs: - uniq = _data.drop_duplicates() + out = _data.drop_duplicates() else: - # keep_none_prefers_new_order - uniq = ( - regcall( - mutate, - _data, - *args, - **kwargs, - _keep="none", + if ( + not kwargs + # optimize: + # iris >> distinct(f.Species, f.Sepal_Length) + # We don't need to do mutation + and all( + isinstance(expr, Reference) + and expr._pipda_level == 1 + and expr._pipda_ref in _data.columns + for expr in args ) - ).drop_duplicates() + ): + subset = [expr._pipda_ref for expr in args] + ucols = getattr(_data, "group_vars", []) + ucols.extend(subset) + ucols = regcall(unique, ucols) + uniq = _data.drop_duplicates(subset=subset)[ucols] + else: + # keep_none_prefers_new_order + uniq = ( + regcall( + mutate, + _data, + *args, + **kwargs, + _keep="none", + ) + ).drop_duplicates() - if not _keep_all: - # keep original order - out = uniq[ - regcall( - union, - regcall(intersect, _data.columns, uniq.columns), - regcall(setdiff, uniq.columns, _data.columns), - ) - ] - else: - out = _data.loc[uniq.index, :].copy() - out[uniq.columns.tolist()] = uniq + if not _keep_all: + # keep original order + out = uniq[ + regcall( + union, + regcall(intersect, _data.columns, uniq.columns), + regcall(setdiff, uniq.columns, _data.columns), + ) + ] + else: + out = _data.loc[uniq.index, :].copy() + out[uniq.columns.tolist()] = uniq return reconstruct_tibble(_data, Tibble(out, copy=False)) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index d1d35c50..cb65b522 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,9 @@ +## 0.8.6 + +- 🐛 Fix weighted_mean not working for grouped data (#133) +- ✅ Add tests for weighted_mean on grouped data +- ⚡️ Optimize distinct on existing columns (#128) + ## 0.8.5 - 🐛 Fix columns missing after Join by same columns using mapping (#122) diff --git a/docs/requirements.txt b/docs/requirements.txt index 7a408f8d..209f991c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,9 @@ # use_directory_urls doesn't work for newer versions -mkdocs==1.1.2 -# AttributeError: module 'jinja2' has no attribute 'contextfilter' -# jinja2==3.1.0 -jinja2==3.0.3 -mkdocs-material==7.2.3 -pymdown-extensions==8.2 +mkdocs +mkdocs-material +pymdown-extensions mkapi-fix -mkdocs-jupyter==0.17.3 +mkdocs-jupyter ipykernel ipython_genutils # to compile readme.ipynb diff --git a/pyproject.toml b/pyproject.toml index e153d97f..6daa2ff3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "datar" -version = "0.8.5" +version = "0.8.6" description = "Port of dplyr and other related R packages in python, using pipda." authors = ["pwwang "] readme = "README.md" diff --git a/tests/base/test_stats.py b/tests/base/test_stats.py index 03a27aa6..552f503a 100644 --- a/tests/base/test_stats.py +++ b/tests/base/test_stats.py @@ -16,6 +16,11 @@ def test_weighted_mean(): with pytest.raises(ValueError): weighted_mean([1,2], [1,2,3]) + df = tibble(g=[1, 1, 2, 2], x=[1, 2, 3, 4], w=[1, 3, 3, 3]).group_by('g') + assert weighted_mean(df.g.obj, w=None) == 1.5 + assert_iterable_equal(weighted_mean(df.g), [1, 2]) + assert_iterable_equal(weighted_mean(df.x, w=df.w), [1.75, 3.5]) + def test_quantile(): df = tibble(x=[1, 2, 3], g=[1, 2, 2]) diff --git a/tests/dplyr/test_distinct.py b/tests/dplyr/test_distinct.py index 10ba7563..bad90a39 100644 --- a/tests/dplyr/test_distinct.py +++ b/tests/dplyr/test_distinct.py @@ -23,6 +23,7 @@ ) from datar.tibble import tibble from datar.datasets import iris +from datar.testing import assert_frame_equal def test_single_column(): @@ -51,7 +52,7 @@ def test_keeps_only_specified_cols(): df = tibble(x=c(1, 1, 1), y=c(1, 1, 1)) expect = tibble(x=1) out = df >> distinct(f.x) - assert out.equals(expect) + assert_frame_equal(out, expect) def test_unless_keep_all_true():