[CHORE] Improve error messages when calling aggregation methods on dataframe without input columns #1587

Merged (5 commits) on Nov 13, 2023. Changes shown are from 4 commits.
43 changes: 36 additions & 7 deletions daft/dataframe/dataframe.py
@@ -20,6 +20,7 @@
     TypeVar,
     Union,
 )
+import warnings

 from daft.api_annotations import DataframePublicAPI
 from daft.context import get_context
@@ -850,7 +851,11 @@ def sum(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated sums. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "sum") for c in cols])

     @DataframePublicAPI
@@ -862,7 +867,11 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated mean. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "mean") for c in cols])

     @DataframePublicAPI
@@ -874,7 +883,11 @@ def min(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated min. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "min") for c in cols])

     @DataframePublicAPI
@@ -886,7 +899,11 @@ def max(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated max. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "max") for c in cols])

     @DataframePublicAPI
@@ -898,7 +915,11 @@ def count(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated count. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts."
+            )
+            cols = tuple(self.columns)
Contributor Author

Hi @samster25 ,

Could you take a look at this and let me know whether it matches what you expected? Once you're happy with it, I'll update the rest of the aggregation methods.

Contributor

This looks good to me! Though I wonder if we should special-case the .count() aggregation method to just forward the call to .count_rows() instead 😛. I believe that in PostgreSQL, for example, COUNT(*) is a row count of the entire result set, whereas COUNT(c) does a null-aware count of column c.

These semantics sound good for all the other aggregation-type methods though!
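
The COUNT(*) versus COUNT(c) distinction mentioned above can be checked with a minimal sketch. It uses Python's built-in sqlite3 module as a stand-in for PostgreSQL (an assumption made so the example is self-contained; both follow the same rule for these two aggregates). The table name `t`, column name `c`, and helper `count_star_vs_column` are hypothetical.

```python
import sqlite3


def count_star_vs_column(values):
    """Return (COUNT(*), COUNT(c)) for a one-column table holding `values`."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE t (c INTEGER)")
        # Python None is stored as SQL NULL.
        conn.executemany("INSERT INTO t (c) VALUES (?)", [(v,) for v in values])
        # COUNT(*) counts rows; COUNT(c) counts rows where c IS NOT NULL.
        row = conn.execute("SELECT COUNT(*), COUNT(c) FROM t").fetchone()
        return row[0], row[1]
    finally:
        conn.close()


if __name__ == "__main__":
    total_rows, non_null = count_star_vs_column([1, None, 3])
    print(total_rows, non_null)  # 3 2
```

With one NULL among three rows, the row count is 3 while the column count is 2, which is the gap between a row count and a per-column count discussed in this thread.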

Contributor Author

Makes sense, though .count_rows() returns an int whereas .count() should return a DataFrame. So instead of simply forwarding the call, could we do return DataFrame(self._builder.count())? That would result in a DataFrame that looks like:

+--------+
| count  |
| UInt64 |
+--------+
| 3      |
+--------+

Contributor

@jaychia, Nov 10, 2023

Oh yes, good point. The current semantics of this PR (broadcasting the count operation on all columns) make perfect sense then 👍
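
To make the return-type difference concrete, here is a rough sketch in plain Python rather than Daft. A "table" is assumed to be a dict mapping column names to equal-length lists, and both helpers are hypothetical illustrations, not Daft's implementation.

```python
def count_rows(table):
    """Row count as a plain int (columns are assumed to have equal length)."""
    first_column = next(iter(table.values()), [])
    return len(first_column)


def count(table, *cols):
    """Null-aware count per column, returned as a single 'row' (a dict).

    With no columns given, the count is broadcast over every column.
    """
    cols = cols or tuple(table)
    return {c: sum(v is not None for v in table[c]) for c in cols}


if __name__ == "__main__":
    table = {"a": [1, None, 3], "b": ["x", "y", "z"]}
    print(count_rows(table))   # 3
    print(count(table))        # {'a': 2, 'b': 3}
    print(count(table, "a"))   # {'a': 2}
```

The row count is a scalar, while the broadcast count keeps one value per column, so the two cannot be swapped without changing the return type.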

Contributor Author

Sweet, thanks! I just made the changes for the rest of the aggregation methods in the latest commit.

Contributor Author

Also, some of the integration tests are failing with Error: Credentials could not be loaded, please check your action inputs: Could not load credentials from any providers (example). Do you know what the cause is 😅?

The release drafter label test is failing too: [Error: Resource not accessible by integration](https://github.com/Eventual-Inc/Daft/actions/runs/6829508921/job/18575800562?pr=1587#step:2:25)

Contributor

Oh yes, that's security-related: these CI tests are actually running on your fork, not on the main repo 😛

We need to figure out a (secure) way to run our integration tests on incoming PRs, or maybe adopt a better policy for accepting external contributions.

         return self._agg([(c, "count") for c in cols])

     @DataframePublicAPI
@@ -910,7 +931,11 @@ def agg_list(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated list. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "list") for c in cols])

     @DataframePublicAPI
@@ -922,7 +947,11 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame":
         Returns:
             DataFrame: Globally aggregated list. Should be a single row.
         """
-        assert len(cols) > 0, "no columns were passed in"
+        if len(cols) == 0:
+            warnings.warn(
+                "No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)."
+            )
+            cols = tuple(self.columns)
         return self._agg([(c, "concat") for c in cols])

     @DataframePublicAPI
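
The warn-and-fall-back pattern used in every hunk above can be tried in isolation. The following is a minimal sketch under stated assumptions: `TinyFrame` is a hypothetical stand-in (a dict of column name to list), not Daft's DataFrame, and `sum_all_with_warnings` is a hypothetical helper that records warnings with the standard `warnings.catch_warnings` context manager.

```python
import warnings


class TinyFrame:
    """Hypothetical stand-in for a dataframe: column name -> list of numbers."""

    def __init__(self, data):
        self._data = data

    @property
    def columns(self):
        return list(self._data)

    def sum(self, *cols):
        # Same shape as the diff: warn, then default to every column.
        if len(cols) == 0:
            warnings.warn(
                "No columns specified; performing sum on all columns. "
                "Specify columns using df.sum('col1', 'col2', ...)."
            )
            cols = tuple(self.columns)
        return {c: sum(self._data[c]) for c in cols}


def sum_all_with_warnings(frame):
    """Call frame.sum() with no columns; return (result, warning messages)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure the warning is not filtered out
        result = frame.sum()
    return result, [str(w.message) for w in caught]


if __name__ == "__main__":
    df = TinyFrame({"a": [1, 2, 3], "b": [10, 20, 30]})
    result, messages = sum_all_with_warnings(df)
    print(result)         # {'a': 6, 'b': 60}
    print(len(messages))  # 1
    print(df.sum("a"))    # {'a': 6}
```

One practical difference from the removed assert: assert statements are skipped when Python runs with -O, whereas the explicit check always runs and the call now succeeds with a warning instead of failing.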