diff --git a/README.rst b/README.rst index f36f34b5..31bb94ff 100644 --- a/README.rst +++ b/README.rst @@ -1368,6 +1368,58 @@ This produces: DROP INDEX IF EXISTS my_index + +Chaining Functions +^^^^^^^^^^^^^^^^^^ + +The ``QueryBuilder.pipe`` method gives a more readable alternative while chaining functions. + +.. code-block:: python + + # This + ( + query + .pipe(func1, *args) + .pipe(func2, **kwargs) + .pipe(func3) + ) + + # Is equivalent to this + func3(func2(func1(query, *args), **kwargs)) + +Or for a more concrete example: + +.. code-block:: python + + from pypika import Field, Query, functions as fn + from pypika.queries import QueryBuilder + + def filter_days(query: QueryBuilder, col, num_days: int) -> QueryBuilder: + if isinstance(col, str): + col = Field(col) + + return query.where(col > fn.Now() - num_days) + + def count_groups(query: QueryBuilder, *groups) -> QueryBuilder: + return query.groupby(*groups).select(*groups, fn.Count("*").as_("n_rows")) + + base_query = Query.from_("table") + + query = ( + base_query + .pipe(filter_days, "date", num_days=7) + .pipe(count_groups, "col1", "col2") + ) + +This produces: + +.. code-block:: sql + + SELECT "col1","col2",COUNT(*) n_rows + FROM "table" + WHERE "date">NOW()-7 + GROUP BY "col1","col2" + .. _tutorial_end: .. _contributing_start: diff --git a/pypika/queries.py b/pypika/queries.py index 223c3c95..c51c6b2b 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -1560,6 +1560,54 @@ def _set_sql(self, **kwargs: Any) -> str: ) ) + def pipe(self, func, *args, **kwargs): + """Call a function on the current object and return the result. + + Example usage: + + .. code-block:: python + + from pypika import Query, functions as fn + from pypika.queries import QueryBuilder + + def rows_by_group(query: QueryBuilder, *groups) -> QueryBuilder: + return ( + query + .select(*groups, fn.Count("*").as_("n_rows")) + .groupby(*groups) + ) + + base_query = Query.from_("table") + + col1_agg = base_query.pipe(rows_by_group, "col1") + col2_agg = base_query.pipe(rows_by_group, "col2") + col1_col2_agg = base_query.pipe(rows_by_group, "col1", "col2") + + Makes chaining functions together easier, especially when the functions are + defined elsewhere. For example, you could define a function that filters + rows by a date range and then group by a set of columns: + + + .. code-block:: python + + from datetime import datetime, timedelta + + from pypika import Field + + def days_since(query: QueryBuilder, n_days: int) -> QueryBuilder: + return ( + query + .where("date" > fn.Date(datetime.now().date() - timedelta(days=n_days))) + ) + + ( + base_query + .pipe(days_since, n_days=7) + .pipe(rows_by_group, "col1", "col2") + ) + """ + return func(self, *args, **kwargs) + class Joiner: def __init__( diff --git a/pypika/tests/test_query.py b/pypika/tests/test_query.py index 460b7d77..ae8dc015 100644 --- a/pypika/tests/test_query.py +++ b/pypika/tests/test_query.py @@ -1,6 +1,6 @@ import unittest -from pypika import Case, Query, Tables, Tuple, functions +from pypika import Case, Query, Tables, Tuple, functions, Field from pypika.dialects import ( ClickHouseQuery, ClickHouseQueryBuilder, @@ -204,3 +204,39 @@ def test_query_builders_have_reference_to_correct_query_class(self): with self.subTest('OracleQueryBuilder'): self.assertEqual(OracleQuery, OracleQueryBuilder.QUERY_CLS) + + def test_pipe(self) -> None: + base_query = Query.from_("test") + + def select(query: QueryBuilder) -> QueryBuilder: + return query.select("test1", "test2") + + def count_group(query: QueryBuilder, *groups) -> QueryBuilder: + return query.groupby(*groups).select(*groups, functions.Count("*")) + + for func, args, kwargs, expected_str in [ + (select, [], {}, 'SELECT "test1","test2" FROM "test"'), + ( + count_group, + ["test1", "test2"], + {}, + 'SELECT "test1","test2",COUNT(*) FROM "test" GROUP BY "test1","test2"', + ), + (count_group, ["test1"], {}, 'SELECT "test1",COUNT(*) FROM "test" GROUP BY "test1"'), + ]: + result_str = str(base_query.pipe(func, *args, **kwargs)) + self.assertEqual(result_str, str(func(base_query, *args, **kwargs))) + self.assertEqual(result_str, expected_str) + + def where_clause(query: QueryBuilder, num_days: int) -> QueryBuilder: + return query.where(Field("date") > functions.Now() - num_days) + + result_str = str(base_query.pipe(select).pipe(where_clause, num_days=1)) + self.assertEqual( + result_str, + str(select(where_clause(base_query, num_days=1))), + ) + self.assertEqual( + result_str, + 'SELECT "test1","test2" FROM "test" WHERE "date">NOW()-1', + )