From 3c85c752d68a2c1b8cf99f066ac7925eb552a380 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 9 Nov 2023 14:27:59 -0800 Subject: [PATCH 1/5] modified df.count --- daft/dataframe/dataframe.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index ddf73b291a..584c70a5b2 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -4,6 +4,7 @@ # in order to support runtime typechecking across different Python versions. # For technical details, see https://github.com/Eventual-Inc/Daft/pull/630 +import logging import pathlib from dataclasses import dataclass from functools import reduce @@ -46,6 +47,8 @@ from daft.logical.schema import Schema +logger = logging.getLogger(__name__) + UDFReturnType = TypeVar("UDFReturnType", covariant=True) ColumnInputType = Union[Expression, str] @@ -898,7 +901,9 @@ def count(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated count. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning("No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts.") + cols = tuple(self.columns) return self._agg([(c, "count") for c in cols]) @DataframePublicAPI From e6cb93f11104ab552b8d4dabe9e766345fe773de Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 9 Nov 2023 14:54:57 -0800 Subject: [PATCH 2/5] formatting --- daft/dataframe/dataframe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 584c70a5b2..6d73429265 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -902,7 +902,9 @@ def count(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated count. Should be a single row. """ if len(cols) == 0: - logger.warning("No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts.") + logger.warning( + "No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts." + ) cols = tuple(self.columns) return self._agg([(c, "count") for c in cols]) From c038678bf784b49a1e3edb90ca670b245d4dbfbd Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 10 Nov 2023 12:53:05 -0800 Subject: [PATCH 3/5] modify other agg methods --- daft/dataframe/dataframe.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6d73429265..e701640b76 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -853,7 +853,11 @@ def sum(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated sums. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "sum") for c in cols]) @DataframePublicAPI @@ -865,7 +869,11 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated mean. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "mean") for c in cols]) @DataframePublicAPI @@ -877,7 +885,11 @@ def min(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated min. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "min") for c in cols]) @DataframePublicAPI @@ -889,7 +901,11 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated max. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "max") for c in cols]) @DataframePublicAPI @@ -917,7 +933,11 @@ def agg_list(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated list. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "list") for c in cols]) @DataframePublicAPI @@ -929,7 +949,11 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame": Returns: DataFrame: Globally aggregated list. Should be a single row. """ - assert len(cols) > 0, "no columns were passed in" + if len(cols) == 0: + logger.warning( + "No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)." + ) + cols = tuple(self.columns) return self._agg([(c, "concat") for c in cols]) @DataframePublicAPI From 2e8683d602dff49520f8dba575db9e4dd1b351cd Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 13 Nov 2023 12:50:17 -0800 Subject: [PATCH 4/5] use warnings.warn --- daft/dataframe/dataframe.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index e701640b76..36b70c5774 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -4,7 +4,6 @@ # in order to support runtime typechecking across different Python versions. # For technical details, see https://github.com/Eventual-Inc/Daft/pull/630 -import logging import pathlib from dataclasses import dataclass from functools import reduce @@ -21,6 +20,7 @@ TypeVar, Union, ) +import warnings from daft.api_annotations import DataframePublicAPI from daft.context import get_context @@ -47,8 +47,6 @@ from daft.logical.schema import Schema -logger = logging.getLogger(__name__) - UDFReturnType = TypeVar("UDFReturnType", covariant=True) ColumnInputType = Union[Expression, str] @@ -854,7 +852,7 @@ def sum(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated sums. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing sum on all columns. Specify columns using df.sum('col1', 'col2', ...)." ) cols = tuple(self.columns) @@ -870,7 +868,7 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated mean. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing mean on all columns. Specify columns using df.mean('col1', 'col2', ...)." ) cols = tuple(self.columns) @@ -886,7 +884,7 @@ def min(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated min. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing min on all columns. Specify columns using df.min('col1', 'col2', ...)." ) cols = tuple(self.columns) @@ -902,7 +900,7 @@ def max(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated max. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing max on all columns. Specify columns using df.max('col1', 'col2', ...)." ) cols = tuple(self.columns) @@ -918,7 +916,7 @@ def count(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated count. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing count on all columns. Specify columns using df.count('col1', 'col2', ...) or use df.count_rows() for row counts." ) cols = tuple(self.columns) @@ -934,7 +932,7 @@ def agg_list(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated list. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing agg_list on all columns. Specify columns using df.agg_list('col1', 'col2', ...)." ) cols = tuple(self.columns) @@ -950,7 +948,7 @@ def agg_concat(self, *cols: ColumnInputType) -> "DataFrame": DataFrame: Globally aggregated list. Should be a single row. """ if len(cols) == 0: - logger.warning( + warnings.warn( "No columns specified; performing agg_concat on all columns. Specify columns using df.agg_concat('col1', 'col2', ...)." ) cols = tuple(self.columns) From 5c787ae949513c7bb5c71247afcdcabfc24d7bab Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 13 Nov 2023 13:27:33 -0800 Subject: [PATCH 5/5] fix import order --- daft/dataframe/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 36b70c5774..ba7e05ff23 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -5,6 +5,7 @@ # For technical details, see https://github.com/Eventual-Inc/Daft/pull/630 import pathlib +import warnings from dataclasses import dataclass from functools import reduce from typing import ( @@ -20,7 +21,6 @@ TypeVar, Union, ) -import warnings from daft.api_annotations import DataframePublicAPI from daft.context import get_context