Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Improve error messages when calling aggregation methods on dataframe without input columns #1587

Merged
merged 5 commits into from
Nov 13, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# in order to support runtime typechecking across different Python versions.
# For technical details, see https://github.com/Eventual-Inc/Daft/pull/630

import logging
import pathlib
from dataclasses import dataclass
from functools import reduce
Expand Down Expand Up @@ -46,6 +47,8 @@

from daft.logical.schema import Schema

logger = logging.getLogger(__name__)

UDFReturnType = TypeVar("UDFReturnType", covariant=True)

ColumnInputType = Union[Expression, str]
Expand Down Expand Up @@ -850,7 +853,11 @@
Returns:
DataFrame: Globally aggregated sums. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 857 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L857

Added line #L857 was not covered by tests
"No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 860 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L860

Added line #L860 was not covered by tests
return self._agg([(c, "sum") for c in cols])

@DataframePublicAPI
Expand All @@ -862,7 +869,11 @@
Returns:
DataFrame: Globally aggregated mean. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 873 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L873

Added line #L873 was not covered by tests
"No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 876 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L876

Added line #L876 was not covered by tests
return self._agg([(c, "mean") for c in cols])

@DataframePublicAPI
Expand All @@ -874,7 +885,11 @@
Returns:
DataFrame: Globally aggregated min. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 889 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L889

Added line #L889 was not covered by tests
"No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 892 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L892

Added line #L892 was not covered by tests
return self._agg([(c, "min") for c in cols])

@DataframePublicAPI
Expand All @@ -886,7 +901,11 @@
Returns:
DataFrame: Globally aggregated max. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 905 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L905

Added line #L905 was not covered by tests
"No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 908 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L908

Added line #L908 was not covered by tests
return self._agg([(c, "max") for c in cols])

@DataframePublicAPI
Expand All @@ -898,7 +917,11 @@
Returns:
DataFrame: Globally aggregated count. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 921 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L921

Added line #L921 was not covered by tests
"No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts."
)
cols = tuple(self.columns)

Check warning on line 924 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L924

Added line #L924 was not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @samster25 ,

Could you take a look at this and lmk if this aligns with what you expected? Once it's all good with you I'll modify the rest of the agg methods.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me! Though I wonder if we should corner-case the .count() aggregation method to just forward the call to .count_rows() instead 😛 . I believe that in PostgreSQL for example, COUNT(*) is actually a row count of the entire result set vs COUNT(c) which does a null-aware count of a column.

These semantics sound good for all the other aggregation-type methods though!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, though .count_rows() returns an int whereas .count() should return a dataframe, so instead of simply forwarding the call we could do return DataFrame(self._builder.count())? which would result in a dataframe that looks like:

+--------+
| count  |
| UInt64 |
+--------+
| 3      |
+--------+

Copy link
Contributor

@jaychia jaychia Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes good point. The current semantics of this PR (broadcasting the count operation on all columns) makes perfect sense then 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sweet, thanks! just made changes for the rest of the aggregation methods in latest commit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, some of the integration tests are failing with Error: Credentials could not be loaded, please check your action inputs: Could not load credentials from any providers, example. Do you know what's the cause 😅 ?

the release drafter label test too: [Error: Resource not accessible by integration](https://github.com/Eventual-Inc/Daft/actions/runs/6829508921/job/18575800562?pr=1587#step:2:25)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes that's security related because these CI tests are actually running on your fork, not on the main repo 😛

We need to figure out a (secure) way to run our integration tests on incoming PRs, or maybe a better policy here around accepting external contributions.

return self._agg([(c, "count") for c in cols])

@DataframePublicAPI
Expand All @@ -910,7 +933,11 @@
Returns:
DataFrame: Globally aggregated list. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 937 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L937

Added line #L937 was not covered by tests
"No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 940 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L940

Added line #L940 was not covered by tests
return self._agg([(c, "list") for c in cols])

@DataframePublicAPI
Expand All @@ -922,7 +949,11 @@
Returns:
DataFrame: Globally aggregated list. Should be a single row.
"""
assert len(cols) > 0, "no columns were passed in"
if len(cols) == 0:
logger.warning(

Check warning on line 953 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L952-L953

Added lines #L952 - L953 were not covered by tests
"No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)."
)
cols = tuple(self.columns)

Check warning on line 956 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L956

Added line #L956 was not covered by tests
return self._agg([(c, "concat") for c in cols])

@DataframePublicAPI
Expand Down
Loading