From b4467ad62d987f1a87d665ea383ef302e2b7d889 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 13 Nov 2023 13:58:42 -0800 Subject: [PATCH] [CHORE] Improve error messages when calling aggregation methods on dataframe without input columns (#1587) Fixes #1583. When a user does not specify columns in df aggregation methods, e.g. `df.count()`: - Default to running aggregation on all columns - Log warning messages with an example to pass in columns. --- daft/dataframe/dataframe.py | 43 +++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 467e9dfa8a..9c99652e78 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -5,6 +5,7 @@ # For technical details, see https://github.com/Eventual-Inc/Daft/pull/630 import pathlib +import warnings from dataclasses import dataclass from functools import reduce from typing import ( @@ -858,7 +859,11 @@ def sum(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated sums. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "sum") for c in cols]) @DataframePublicAPI @@ -870,7 +875,11 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated mean. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "mean") for c in cols]) @DataframePublicAPI @@ -882,7 +891,11 @@ def min(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated min. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "min") for c in cols]) @DataframePublicAPI @@ -894,7 +907,11 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated max. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "max") for c in cols]) @DataframePublicAPI @@ -906,7 +923,11 @@ def count(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated count. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts." + ) + cols = tuple(self.columns) return self._agg([(c, "count") for c in cols]) @DataframePublicAPI @@ -918,7 +939,11 @@ def agg_list(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated list. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "list") for c in cols]) @DataframePublicAPI @@ -930,7 +955,11 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated list. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + warnings.warn( + "No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "concat") for c in cols]) @DataframePublicAPI