From ce3d798e9f36ea35d29ab05bddc2b1e460e5e4d5 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Thu, 24 Oct 2024 13:11:25 -0700 Subject: [PATCH 1/7] perf(weave): split calls query when using a heavy condition --- weave/trace_server/calls_query_builder.py | 73 +++++++++++++------ .../clickhouse_trace_server_batched.py | 48 +++++++++--- 2 files changed, 88 insertions(+), 33 deletions(-) diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 76e94d27d61..3fec0c3c8fb 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -326,7 +326,7 @@ def set_include_costs(self, include_costs: bool) -> "CallsQuery": self.include_costs = include_costs return self - def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> str: + def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[str]: """ This is the main entry point for building the query. This method will determine the optimal query to build based on the fields and conditions @@ -456,7 +456,7 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> str: # If we should not optimize, then just build the base query if not should_optimize and not self.include_costs: - return self._as_sql_base_format(pb, table_alias) + return [self._as_sql_base_format(pb, table_alias)] # If so, build the two queries filter_query = CallsQuery(project_id=self.project_id) @@ -489,35 +489,60 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> str: outer_query.limit = self.limit outer_query.offset = self.offset - raw_sql = f""" - WITH filtered_calls AS ({filter_query._as_sql_base_format(pb, table_alias)}) - """ - - if self.include_costs: - # TODO: We should unify the calls query order by fields to be orm sort by fields - order_by_fields = [ - tsi.SortBy( - field=sort_by.field.field, direction=sort_by.direction.lower() - ) - for sort_by in self.order_fields - ] - raw_sql += f""", - all_calls AS ({outer_query._as_sql_base_format(pb, table_alias, id_subquery_name="filtered_calls")}), - {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} - """ + raw_sql_queries = [] + if has_heavy_fields: + # Make two part query + raw_sql_queries += [filter_query._as_sql_base_format(pb, table_alias)] + # we will now expect to receive the filtered_calls as a parameter + outer_raw_sql = outer_query._as_sql_base_format( + pb, + table_alias, + ids_param_slot=_param_slot("filtered_calls", "Array(String)"), + ) + if self.include_costs: + order_by_fields = [ + tsi.SortBy( + field=sort_by.field.field, direction=sort_by.direction.lower() + ) + for sort_by in self.order_fields + ] + outer_raw_sql = f""" + WITH all_calls AS ({outer_raw_sql}), + {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} + """ + raw_sql_queries.append(outer_raw_sql) else: - raw_sql += f""" - {outer_query._as_sql_base_format(pb, table_alias, id_subquery_name="filtered_calls")} + raw_sql = f""" + WITH filtered_calls as ({filter_query._as_sql_base_format(pb, table_alias)}) """ + if self.include_costs: + # TODO: We should unify the calls query order by fields to be orm sort by fields + order_by_fields = [ + tsi.SortBy( + field=sort_by.field.field, direction=sort_by.direction.lower() + ) + for sort_by in self.order_fields + ] - return _safely_format_sql(raw_sql) + raw_sql += f""" + , + all_calls AS ({outer_query._as_sql_base_format(pb, table_alias, ids_param_slot="filtered_calls")}), + {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} + """ + else: + raw_sql += f""" + {outer_query._as_sql_base_format(pb, table_alias, ids_param_slot="filtered_calls")} + """ + raw_sql_queries.append(raw_sql) + + return [_safely_format_sql(raw_sql) for raw_sql in raw_sql_queries] def _as_sql_base_format( self, pb: ParamBuilder, table_alias: str, - id_subquery_name: typing.Optional[str] = None, + ids_param_slot: typing.Optional[str] = None, ) -> str: select_fields_sql = ", ".join( field.as_select_sql(pb, table_alias) for field in self.select_fields @@ -555,8 +580,8 @@ def _as_sql_base_format( offset_sql = f"OFFSET {self.offset}" id_subquery_sql = "" - if id_subquery_name is not None: - id_subquery_sql = f"AND (id IN {id_subquery_name})" + if ids_param_slot is not None: + id_subquery_sql = f"AND (id IN {ids_param_slot})" project_param = pb.add_param(self.project_id) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index bc6f834f662..dad1e99f0ff 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -289,11 +289,26 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR cq.add_condition(req.query.expr_) pb = ParamBuilder() - inner_query = cq.as_sql(pb) - raw_res = self._query( - f"SELECT count() FROM ({inner_query})", - pb.get_params(), - ) + queries = cq.as_sql(pb) + if len(queries) == 1: + raw_res = self._query( + f"SELECT count() FROM ({queries[0]})", + pb.get_params(), + ) + else: + filtered_calls_res = self._query( + queries[0], + pb.get_params(), + ) + filtered_calls = [x[0] for x in filtered_calls_res.result_rows] + pb.add( + filtered_calls, param_name="filtered_calls", param_type="Array(String)" + ) + raw_res = self._query( + f"SELECT count() FROM ({queries[1]})", + pb.get_params(), + ) + rows = raw_res.result_rows count = 0 if rows and len(rows) == 1 and len(rows[0]) == 1: @@ -341,10 +356,25 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] cq.set_offset(req.offset) pb = ParamBuilder() - raw_res = self._query_stream( - cq.as_sql(pb), - pb.get_params(), - ) + queries = cq.as_sql(pb) + if len(queries) == 1: + raw_res = self._query_stream( + queries[0], + pb.get_params(), + ) + else: + filtered_calls_res = self._query( + queries[0], + pb.get_params(), + ) + filtered_calls = [x[0] for x in filtered_calls_res.result_rows] + pb.add( + filtered_calls, param_name="filtered_calls", param_type="Array(String)" + ) + raw_res = self._query_stream( + queries[1], + pb.get_params(), + ) select_columns = [c.field for c in cq.select_fields] From c3e0ee1de714f032266add276c0bcf80a7eda24c Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 09:50:49 -0700 Subject: [PATCH 2/7] cleaner --- weave/trace_server/calls_query_builder.py | 164 +++++++++--------- .../clickhouse_trace_server_batched.py | 69 ++++---- 2 files changed, 114 insertions(+), 119 deletions(-) diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 3fec0c3c8fb..2928490533c 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -266,6 +266,23 @@ class CallsQuery(BaseModel): offset: typing.Optional[int] = None include_costs: bool = False + # Optional param name for the output of a filtered query. + # This is used for two-step queries. If CallsQuery is used + # to build a two-step query, this value must be set. + _filtered_output_param: typing.Optional[str] = None + + @property + def _has_heavy_select(self) -> bool: + return any(field.is_heavy() for field in self.select_fields) + + @property + def _has_heavy_filter(self) -> bool: + return any(condition.is_heavy() for condition in self.query_conditions) + + @property + def _has_heavy_order(self) -> bool: + return any(order_field.field.is_heavy() for order_field in self.order_fields) + def add_field(self, field: str) -> "CallsQuery": self.select_fields.append(get_field_by_name(field)) return self @@ -326,6 +343,18 @@ def set_include_costs(self, include_costs: bool) -> "CallsQuery": self.include_costs = include_costs return self + def should_do_two_step_query(self) -> bool: + """Returns True if the query should be a forced two step query. + When heavy fields are filtered or ordered, we can't push down to subquery. + Two-step query ensures that the subquery to return the filtered call IDs + is executed first, and then the main query can be executed with the + filtered IDs. + """ + return self._has_heavy_filter or self._has_heavy_order + + def set_filtered_output_param(self, param_name: str) -> None: + self._filtered_output_param = param_name + def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[str]: """ This is the main entry point for building the query. This method will @@ -401,42 +430,6 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st if not self.select_fields: raise ValueError("Missing select columns") - # Determine if the query `has_heavy_fields` by checking - # if it `has_heavy_select or has_heavy_filter or has_heavy_order` - has_heavy_select = any(field.is_heavy() for field in self.select_fields) - - has_heavy_filter = any( - condition.is_heavy() for condition in self.query_conditions - ) - - has_heavy_order = any( - order_field.field.is_heavy() for order_field in self.order_fields - ) - - has_heavy_fields = has_heavy_select or has_heavy_filter or has_heavy_order - - # Determine if `predicate_pushdown_possible` which is - # if it `has_light_filter or has_light_query or has_light_order_filter` - has_light_filter = self.hardcoded_filter and self.hardcoded_filter.is_useful() - - has_light_query = any( - not condition.is_heavy() for condition in self.query_conditions - ) - - has_light_order_filter = ( - self.order_fields - and self.limit - and not has_heavy_filter - and not has_heavy_order - ) - - predicate_pushdown_possible = ( - has_light_filter or has_light_query or has_light_order_filter - ) - - # Determine if we should optimize! - should_optimize = has_heavy_fields and predicate_pushdown_possible - # Important: Always inject deleted_at into the query. # Note: it might be better to make this configurable. self.add_condition( @@ -454,11 +447,28 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) ) - # If we should not optimize, then just build the base query - if not should_optimize and not self.include_costs: + has_heavy_fields = ( + self._has_heavy_select or self._has_heavy_filter or self._has_heavy_order + ) + has_light_filter = self.hardcoded_filter and self.hardcoded_filter.is_useful() + has_light_query = any( + not condition.is_heavy() for condition in self.query_conditions + ) + has_light_order_filter = ( + self.order_fields + and self.limit + and not self._has_heavy_filter + and not self._has_heavy_order + ) + do_predicate_pushdown = ( + has_light_filter or has_light_query or has_light_order_filter + ) + + # If we don't need to optimize, return the base query + if not (has_heavy_fields and do_predicate_pushdown and not self.include_costs): return [self._as_sql_base_format(pb, table_alias)] - # If so, build the two queries + # Build the two queries filter_query = CallsQuery(project_id=self.project_id) outer_query = CallsQuery(project_id=self.project_id) @@ -489,54 +499,46 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st outer_query.limit = self.limit outer_query.offset = self.offset - raw_sql_queries = [] - if has_heavy_fields: - # Make two part query - raw_sql_queries += [filter_query._as_sql_base_format(pb, table_alias)] - - # we will now expect to receive the filtered_calls as a parameter - outer_raw_sql = outer_query._as_sql_base_format( - pb, - table_alias, - ids_param_slot=_param_slot("filtered_calls", "Array(String)"), + # If we have heavy fields in filter/order, and we have a filtered output param, + # we should do a two-step query. + two_step_query = ( + self.should_do_two_step_query() and self._filtered_output_param is not None + ) + if two_step_query: + ids_param_slot = _param_slot( + pb.add_param(self._filtered_output_param), "Array(String)" ) - if self.include_costs: - order_by_fields = [ - tsi.SortBy( - field=sort_by.field.field, direction=sort_by.direction.lower() - ) - for sort_by in self.order_fields - ] - outer_raw_sql = f""" + filter_query_sql = filter_query._as_sql_base_format(pb, table_alias) + else: + ids_param_slot = "filtered_calls" + filter_query_sql = f""" + WITH {ids_param_slot} as ({filter_query._as_sql_base_format(pb, table_alias)}) + """ + + outer_raw_sql = outer_query._as_sql_base_format( + pb, + table_alias, + ids_param_slot=ids_param_slot, + ) + + if self.include_costs: + order_by_fields = [ + tsi.SortBy( + field=sort_by.field.field, direction=sort_by.direction.lower() + ) + for sort_by in self.order_fields + ] + prefix = "" if self.requires_two_step_query() else "," + outer_raw_sql = f"""{outer_raw_sql}{prefix} WITH all_calls AS ({outer_raw_sql}), {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} - """ - raw_sql_queries.append(outer_raw_sql) - else: - raw_sql = f""" - WITH filtered_calls as ({filter_query._as_sql_base_format(pb, table_alias)}) """ - if self.include_costs: - # TODO: We should unify the calls query order by fields to be orm sort by fields - order_by_fields = [ - tsi.SortBy( - field=sort_by.field.field, direction=sort_by.direction.lower() - ) - for sort_by in self.order_fields - ] - raw_sql += f""" - , - all_calls AS ({outer_query._as_sql_base_format(pb, table_alias, ids_param_slot="filtered_calls")}), - {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} - """ - else: - raw_sql += f""" - {outer_query._as_sql_base_format(pb, table_alias, ids_param_slot="filtered_calls")} - """ - raw_sql_queries.append(raw_sql) + if not two_step_query: + # Join the two queries together, return a single query + return [_safely_format_sql("".join([filter_query_sql, outer_raw_sql]))] - return [_safely_format_sql(raw_sql) for raw_sql in raw_sql_queries] + return [_safely_format_sql(filter_query_sql), _safely_format_sql(outer_raw_sql)] def _as_sql_base_format( self, diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index dad1e99f0ff..93382fb3385 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -276,6 +276,32 @@ def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: stream = self.calls_query_stream(req) return tsi.CallsQueryRes(calls=list(stream)) + def _construct_query_str_maybe_do_pre_query( + self, calls_query: CallsQuery, pb: ParamBuilder + ) -> str: + """Helper to query the calls table if heavy. Immediately returns a formatted sql string + if not heavy, otherwise does a pre-query to get the filtered call ids, injects them + into the main query, and returns the main query str. + """ + if not calls_query.should_do_two_step_query(): + query = calls_query.as_sql(pb)[0] + return query + + # We have to do a two-step query. Set the param for the output + # of the first query + output_param_name = "call_ids" + calls_query.set_filtered_output_param(output_param_name) + ids_query, filtered_query = calls_query.as_sql(pb) + + # actually make the ids query + ids_res = self._query(ids_query, pb.get_params()) + + ids = [x[0] for x in ids_res.result_rows] + # add the ids param to the param builder + pb.add(ids, param_name=output_param_name, param_type="Array(String)") + + return filtered_query + def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: """Returns a stats object for the given query. This is useful for counts or other aggregate statistics that are not directly queryable from the calls themselves. @@ -289,25 +315,9 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR cq.add_condition(req.query.expr_) pb = ParamBuilder() - queries = cq.as_sql(pb) - if len(queries) == 1: - raw_res = self._query( - f"SELECT count() FROM ({queries[0]})", - pb.get_params(), - ) - else: - filtered_calls_res = self._query( - queries[0], - pb.get_params(), - ) - filtered_calls = [x[0] for x in filtered_calls_res.result_rows] - pb.add( - filtered_calls, param_name="filtered_calls", param_type="Array(String)" - ) - raw_res = self._query( - f"SELECT count() FROM ({queries[1]})", - pb.get_params(), - ) + query = self._construct_query_str_maybe_do_pre_query(cq, pb) + count_query_str = f"SELECT count() FROM ({query})" + raw_res = self._query(count_query_str, pb.get_params()) rows = raw_res.result_rows count = 0 @@ -356,25 +366,8 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] cq.set_offset(req.offset) pb = ParamBuilder() - queries = cq.as_sql(pb) - if len(queries) == 1: - raw_res = self._query_stream( - queries[0], - pb.get_params(), - ) - else: - filtered_calls_res = self._query( - queries[0], - pb.get_params(), - ) - filtered_calls = [x[0] for x in filtered_calls_res.result_rows] - pb.add( - filtered_calls, param_name="filtered_calls", param_type="Array(String)" - ) - raw_res = self._query_stream( - queries[1], - pb.get_params(), - ) + query = self._construct_query_str_maybe_do_pre_query(cq, pb) + raw_res = self._query_stream(query, pb.get_params()) select_columns = [c.field for c in cq.select_fields] From 555d85e9c2e7210fe2d0c42474a474dcd971ecc4 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 10:46:45 -0700 Subject: [PATCH 3/7] bug --- weave/trace_server/calls_query_builder.py | 11 +++++------ weave/trace_server/clickhouse_trace_server_batched.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 2928490533c..7a6c13e4b88 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -465,7 +465,7 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) # If we don't need to optimize, return the base query - if not (has_heavy_fields and do_predicate_pushdown and not self.include_costs): + if not (has_heavy_fields or do_predicate_pushdown or self.include_costs): return [self._as_sql_base_format(pb, table_alias)] # Build the two queries @@ -505,9 +505,8 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st self.should_do_two_step_query() and self._filtered_output_param is not None ) if two_step_query: - ids_param_slot = _param_slot( - pb.add_param(self._filtered_output_param), "Array(String)" - ) + assert self._filtered_output_param is not None + ids_param_slot = _param_slot(self._filtered_output_param, "Array(String)") filter_query_sql = filter_query._as_sql_base_format(pb, table_alias) else: ids_param_slot = "filtered_calls" @@ -528,8 +527,8 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) for sort_by in self.order_fields ] - prefix = "" if self.requires_two_step_query() else "," - outer_raw_sql = f"""{outer_raw_sql}{prefix} + prefix = "" if two_step_query else "," + outer_raw_sql = f"""{prefix} WITH all_calls AS ({outer_raw_sql}), {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} """ diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 93382fb3385..fd7afc0e2d2 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -293,7 +293,7 @@ def _construct_query_str_maybe_do_pre_query( calls_query.set_filtered_output_param(output_param_name) ids_query, filtered_query = calls_query.as_sql(pb) - # actually make the ids query + # Hit the db to get the ids ids_res = self._query(ids_query, pb.get_params()) ids = [x[0] for x in ids_res.result_rows] From 67709f1a07d12142e490e4d1e1c2453b92462651 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 11:01:15 -0700 Subject: [PATCH 4/7] bug --- weave/trace_server/calls_query_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 7a6c13e4b88..aa6ee0fd543 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -527,9 +527,9 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) for sort_by in self.order_fields ] - prefix = "" if two_step_query else "," - outer_raw_sql = f"""{prefix} - WITH all_calls AS ({outer_raw_sql}), + prefix = "WITH" if two_step_query else "," + outer_raw_sql = f""" + {prefix} all_calls AS ({outer_raw_sql}), {cost_query(pb, "all_calls", self.project_id, [field.field for field in self.select_fields], order_by_fields)} """ From 331484d1ab44aa8c9b25e058ad253ecadaa29970 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 11:44:13 -0700 Subject: [PATCH 5/7] fix trace tests --- .../trace_server/test_calls_query_builder.py | 153 ++++++++++++++---- weave/trace_server/calls_query_builder.py | 42 ++--- 2 files changed, 148 insertions(+), 47 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index 23716a13a68..22933e2d83f 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -6,12 +6,30 @@ from weave.trace_server.orm import ParamBuilder +def assert_sql(cq: CallsQuery, exp_queries, exp_params): + pb = ParamBuilder("pb") + queries = cq.as_sql(pb) + params = pb.get_params() + + print("PARAMS", params) + + assert exp_params == params + assert len(queries) == len(exp_queries) + + for qr, qe in zip(queries, exp_queries): + exp_formatted = sqlparse.format(qe, reindent=True) + found_formatted = sqlparse.format(qr, reindent=True) + + assert exp_formatted == found_formatted + + def test_query_baseline() -> None: cq = CallsQuery(project_id="project") cq.add_field("id") assert_sql( cq, - """ + [ + """ SELECT calls_merged.id AS id FROM calls_merged WHERE project_id = {pb_0:String} @@ -27,7 +45,8 @@ def test_query_baseline() -> None: )) )) ) - """, + """ + ], {"pb_0": "project"}, ) @@ -38,7 +57,8 @@ def test_query_light_column() -> None: cq.add_field("started_at") assert_sql( cq, - """ + [ + """ SELECT calls_merged.id AS id, any(calls_merged.started_at) AS started_at @@ -56,7 +76,8 @@ def test_query_light_column() -> None: )) )) ) - """, + """ + ], {"pb_0": "project"}, ) @@ -67,7 +88,8 @@ def test_query_heavy_column() -> None: cq.add_field("inputs") assert_sql( cq, - """ + [ + """ SELECT calls_merged.id AS id, any(calls_merged.inputs_dump) AS inputs_dump @@ -85,7 +107,8 @@ def test_query_heavy_column() -> None: )) )) ) - """, + """ + ], {"pb_0": "project"}, ) @@ -103,7 +126,8 @@ def test_query_heavy_column_simple_filter() -> None: ) assert_sql( cq, - """ + [ + """ WITH filtered_calls AS ( SELECT calls_merged.id AS id @@ -125,7 +149,8 @@ def test_query_heavy_column_simple_filter() -> None: AND (id IN filtered_calls) GROUP BY (project_id,id) - """, + """ + ], {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, ) @@ -144,7 +169,8 @@ def test_query_heavy_column_simple_filter_with_order() -> None: ) assert_sql( cq, - """ + [ + """ WITH filtered_calls AS ( SELECT calls_merged.id AS id @@ -167,7 +193,8 @@ def test_query_heavy_column_simple_filter_with_order() -> None: (id IN filtered_calls) GROUP BY (project_id,id) ORDER BY any(calls_merged.started_at) DESC - """, + """ + ], {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, ) @@ -187,7 +214,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None: ) assert_sql( cq, - """ + [ + """ WITH filtered_calls AS ( SELECT calls_merged.id AS id @@ -214,7 +242,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None: (id IN filtered_calls) GROUP BY (project_id,id) ORDER BY any(calls_merged.started_at) DESC - """, + """ + ], {"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"}, ) @@ -253,7 +282,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c ) assert_sql( cq, - """ + [ + """ WITH filtered_calls AS ( SELECT calls_merged.id AS id @@ -284,7 +314,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c ) ORDER BY any(calls_merged.started_at) DESC LIMIT 10 - """, + """ + ], { "pb_0": "my_user_id", "pb_1": ["a", "b"], @@ -296,17 +327,83 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c ) -def assert_sql(cq: CallsQuery, exp_query, exp_params): - pb = ParamBuilder("pb") - query = cq.as_sql(pb) - params = pb.get_params() - - assert exp_params == params - - exp_formatted = sqlparse.format(exp_query, reindent=True) - found_formatted = sqlparse.format(query, reindent=True) - - assert exp_formatted == found_formatted +def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_conditions_two_step() -> ( + None +): + cq = CallsQuery(project_id="project") + cq.add_field("id") + cq.add_field("inputs") + cq.add_order("started_at", "desc") + cq.set_limit(10) + cq.set_hardcoded_filter( + HardCodedFilter( + filter=tsi.CallsFilter( + op_names=["a", "b"], + ) + ) + ) + cq.add_condition( + tsi_query.AndOperation.model_validate( + { + "$and": [ + { + "$eq": [ + {"$getField": "inputs.param.val"}, + {"$literal": "hello"}, + ] + }, # <-- heavy condition + { + "$eq": [{"$getField": "wb_user_id"}, {"$literal": "my_user_id"}] + }, # <-- light condition + ] + } + ) + ) + cq.set_filtered_output_param("filtered_calls") + assert_sql( + cq, + [ + """ + SELECT + calls_merged.id AS id + FROM calls_merged + WHERE project_id = {pb_2:String} + GROUP BY (project_id,id) + HAVING ( + ((any(calls_merged.wb_user_id) = {pb_0:String})) + AND + ((any(calls_merged.deleted_at) IS NULL)) + AND + ((NOT ((any(calls_merged.started_at) IS NULL)))) + AND + (any(calls_merged.op_name) IN {pb_1:Array(String)}) + )""", + """ + SELECT + calls_merged.id AS id, + any(calls_merged.inputs_dump) AS inputs_dump + FROM calls_merged + WHERE + project_id = {pb_5:String} + AND + (id IN {filtered_calls:Array(String)}) + GROUP BY (project_id,id) + HAVING ( + JSON_VALUE(any(calls_merged.inputs_dump), {pb_3:String}) = {pb_4:String} + ) + ORDER BY any(calls_merged.started_at) DESC + LIMIT 10 + """, + ], + { + "pb_0": "my_user_id", + "pb_1": ["a", "b"], + "pb_2": "project", + "pb_3": '$."param"."val"', + "pb_4": "hello", + "pb_5": "project", + }, + ) def test_query_light_column_with_costs() -> None: @@ -324,7 +421,8 @@ def test_query_light_column_with_costs() -> None: ) assert_sql( cq, - """ + [ + """ WITH filtered_calls AS ( SELECT calls_merged.id AS id @@ -436,7 +534,8 @@ def test_query_light_column_with_costs() -> None: FROM ranked_prices WHERE (rank = {pb_3:UInt64}) GROUP BY id, started_at - """, + """ + ], { "pb_0": ["a", "b"], "pb_1": "UHJvamVjdEludGVybmFsSWQ6Mzk1NDg2Mjc=", diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index aa6ee0fd543..1029b9e44c0 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -430,23 +430,6 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st if not self.select_fields: raise ValueError("Missing select columns") - # Important: Always inject deleted_at into the query. - # Note: it might be better to make this configurable. - self.add_condition( - tsi_query.EqOperation.model_validate( - {"$eq": [{"$getField": "deleted_at"}, {"$literal": None}]} - ) - ) - - # Important: We must always filter out calls that have not been started - # This can occur when there is an out of order call part insertion or worse, - # when such occurance happens and the client terminates early. - self.add_condition( - tsi_query.NotOperation.model_validate( - {"$not": [{"$eq": [{"$getField": "started_at"}, {"$literal": None}]}]} - ) - ) - has_heavy_fields = ( self._has_heavy_select or self._has_heavy_filter or self._has_heavy_order ) @@ -464,8 +447,27 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st has_light_filter or has_light_query or has_light_order_filter ) - # If we don't need to optimize, return the base query - if not (has_heavy_fields or do_predicate_pushdown or self.include_costs): + # Important: Always inject deleted_at into the query. + # Note: it might be better to make this configurable. + self.add_condition( + tsi_query.EqOperation.model_validate( + {"$eq": [{"$getField": "deleted_at"}, {"$literal": None}]} + ) + ) + + # Important: We must always filter out calls that have not been started + # This can occur when there is an out of order call part insertion or worse, + # when such occurance happens and the client terminates early. + self.add_condition( + tsi_query.NotOperation.model_validate( + {"$not": [{"$eq": [{"$getField": "started_at"}, {"$literal": None}]}]} + ) + ) + + should_optimize = ( + has_heavy_fields and do_predicate_pushdown + ) or self.include_costs + if not should_optimize: return [self._as_sql_base_format(pb, table_alias)] # Build the two queries @@ -511,7 +513,7 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st else: ids_param_slot = "filtered_calls" filter_query_sql = f""" - WITH {ids_param_slot} as ({filter_query._as_sql_base_format(pb, table_alias)}) + WITH {ids_param_slot} AS ({filter_query._as_sql_base_format(pb, table_alias)}) """ outer_raw_sql = outer_query._as_sql_base_format( From 3cfa0ec8ebe19d9311ebecd956a6e998e42b0115 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 11:48:17 -0700 Subject: [PATCH 6/7] trax --- tests/trace_server/test_calls_query_builder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index 22933e2d83f..f608fb13b20 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -11,8 +11,6 @@ def assert_sql(cq: CallsQuery, exp_queries, exp_params): queries = cq.as_sql(pb) params = pb.get_params() - print("PARAMS", params) - assert exp_params == params assert len(queries) == len(exp_queries) From 11e6aa5616cea00f7013fc9cb78a8623a6a2e524 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Fri, 25 Oct 2024 12:36:48 -0700 Subject: [PATCH 7/7] maybe --- tests/trace_server/test_calls_query_builder.py | 5 ++--- weave/trace_server/calls_query_builder.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index f608fb13b20..d9a92c6201e 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -11,15 +11,14 @@ def assert_sql(cq: CallsQuery, exp_queries, exp_params): queries = cq.as_sql(pb) params = pb.get_params() - assert exp_params == params - assert len(queries) == len(exp_queries) - for qr, qe in zip(queries, exp_queries): exp_formatted = sqlparse.format(qe, reindent=True) found_formatted = sqlparse.format(qr, reindent=True) assert exp_formatted == found_formatted + assert exp_params == params + def test_query_baseline() -> None: cq = CallsQuery(project_id="project") diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 1029b9e44c0..7d659433a61 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -464,10 +464,10 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) ) - should_optimize = ( + should_optimize = self.should_do_two_step_query() or ( has_heavy_fields and do_predicate_pushdown - ) or self.include_costs - if not should_optimize: + ) + if not should_optimize and not self.include_costs: return [self._as_sql_base_format(pb, table_alias)] # Build the two queries