diff --git a/tests/trace_server/test_calls_query_builder.py b/tests/trace_server/test_calls_query_builder.py index f608fb13b20..d9a92c6201e 100644 --- a/tests/trace_server/test_calls_query_builder.py +++ b/tests/trace_server/test_calls_query_builder.py @@ -11,15 +11,14 @@ def assert_sql(cq: CallsQuery, exp_queries, exp_params): queries = cq.as_sql(pb) params = pb.get_params() - assert exp_params == params - assert len(queries) == len(exp_queries) - for qr, qe in zip(queries, exp_queries): exp_formatted = sqlparse.format(qe, reindent=True) found_formatted = sqlparse.format(qr, reindent=True) assert exp_formatted == found_formatted + assert exp_params == params + def test_query_baseline() -> None: cq = CallsQuery(project_id="project") diff --git a/weave/trace_server/calls_query_builder.py b/weave/trace_server/calls_query_builder.py index 1029b9e44c0..7d659433a61 100644 --- a/weave/trace_server/calls_query_builder.py +++ b/weave/trace_server/calls_query_builder.py @@ -464,10 +464,10 @@ def as_sql(self, pb: ParamBuilder, table_alias: str = "calls_merged") -> list[st ) ) - should_optimize = ( + should_optimize = self.should_do_two_step_query() or ( has_heavy_fields and do_predicate_pushdown - ) or self.include_costs - if not should_optimize: + ) + if not should_optimize and not self.include_costs: return [self._as_sql_base_format(pb, table_alias)] # Build the two queries