Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(weave): split calls query when using a heavy condition #2779

Merged
merged 8 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 123 additions & 27 deletions tests/trace_server/test_calls_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,27 @@
from weave.trace_server.orm import ParamBuilder


def assert_sql(cq: CallsQuery, exp_queries, exp_params):
pb = ParamBuilder("pb")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

every time I see this I think protobuf 🙃

queries = cq.as_sql(pb)
params = pb.get_params()

for qr, qe in zip(queries, exp_queries):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are qr, qe?

Copy link
Member Author

@gtarpenning gtarpenning Oct 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

queryRequest, queryExpected

queryReal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idk

exp_formatted = sqlparse.format(qe, reindent=True)
found_formatted = sqlparse.format(qr, reindent=True)

assert exp_formatted == found_formatted

assert exp_params == params


def test_query_baseline() -> None:
cq = CallsQuery(project_id="project")
cq.add_field("id")
assert_sql(
cq,
"""
[
"""
SELECT calls_merged.id AS id
FROM calls_merged
WHERE project_id = {pb_0:String}
Expand All @@ -27,7 +42,8 @@ def test_query_baseline() -> None:
))
))
)
""",
"""
],
{"pb_0": "project"},
)

Expand All @@ -38,7 +54,8 @@ def test_query_light_column() -> None:
cq.add_field("started_at")
assert_sql(
cq,
"""
[
"""
SELECT
calls_merged.id AS id,
any(calls_merged.started_at) AS started_at
Expand All @@ -56,7 +73,8 @@ def test_query_light_column() -> None:
))
))
)
""",
"""
],
{"pb_0": "project"},
)

Expand All @@ -67,7 +85,8 @@ def test_query_heavy_column() -> None:
cq.add_field("inputs")
assert_sql(
cq,
"""
[
"""
SELECT
calls_merged.id AS id,
any(calls_merged.inputs_dump) AS inputs_dump
Expand All @@ -85,7 +104,8 @@ def test_query_heavy_column() -> None:
))
))
)
""",
"""
],
{"pb_0": "project"},
)

Expand All @@ -103,7 +123,8 @@ def test_query_heavy_column_simple_filter() -> None:
)
assert_sql(
cq,
"""
[
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
Expand All @@ -125,7 +146,8 @@ def test_query_heavy_column_simple_filter() -> None:
AND
(id IN filtered_calls)
GROUP BY (project_id,id)
""",
"""
],
{"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"},
)

Expand All @@ -144,7 +166,8 @@ def test_query_heavy_column_simple_filter_with_order() -> None:
)
assert_sql(
cq,
"""
[
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
Expand All @@ -167,7 +190,8 @@ def test_query_heavy_column_simple_filter_with_order() -> None:
(id IN filtered_calls)
GROUP BY (project_id,id)
ORDER BY any(calls_merged.started_at) DESC
""",
"""
],
{"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"},
)

Expand All @@ -187,7 +211,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None:
)
assert_sql(
cq,
"""
[
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
Expand All @@ -214,7 +239,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit() -> None:
(id IN filtered_calls)
GROUP BY (project_id,id)
ORDER BY any(calls_merged.started_at) DESC
""",
"""
],
{"pb_0": ["a", "b"], "pb_1": "project", "pb_2": "project"},
)

Expand Down Expand Up @@ -253,7 +279,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c
)
assert_sql(
cq,
"""
[
"""
WITH filtered_calls AS (
SELECT
calls_merged.id AS id
Expand Down Expand Up @@ -284,7 +311,8 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c
)
ORDER BY any(calls_merged.started_at) DESC
LIMIT 10
""",
"""
],
{
"pb_0": "my_user_id",
"pb_1": ["a", "b"],
Expand All @@ -296,17 +324,83 @@ def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_c
)


def assert_sql(cq: CallsQuery, exp_query, exp_params):
pb = ParamBuilder("pb")
query = cq.as_sql(pb)
params = pb.get_params()

assert exp_params == params

exp_formatted = sqlparse.format(exp_query, reindent=True)
found_formatted = sqlparse.format(query, reindent=True)

assert exp_formatted == found_formatted
def test_query_heavy_column_simple_filter_with_order_and_limit_and_mixed_query_conditions_two_step() -> (
None
):
cq = CallsQuery(project_id="project")
cq.add_field("id")
cq.add_field("inputs")
cq.add_order("started_at", "desc")
cq.set_limit(10)
cq.set_hardcoded_filter(
HardCodedFilter(
filter=tsi.CallsFilter(
op_names=["a", "b"],
)
)
)
cq.add_condition(
tsi_query.AndOperation.model_validate(
{
"$and": [
{
"$eq": [
{"$getField": "inputs.param.val"},
{"$literal": "hello"},
]
}, # <-- heavy condition
{
"$eq": [{"$getField": "wb_user_id"}, {"$literal": "my_user_id"}]
}, # <-- light condition
]
}
)
)
cq.set_filtered_output_param("filtered_calls")
assert_sql(
cq,
[
"""
SELECT
calls_merged.id AS id
FROM calls_merged
WHERE project_id = {pb_2:String}
GROUP BY (project_id,id)
HAVING (
((any(calls_merged.wb_user_id) = {pb_0:String}))
AND
((any(calls_merged.deleted_at) IS NULL))
AND
((NOT ((any(calls_merged.started_at) IS NULL))))
AND
(any(calls_merged.op_name) IN {pb_1:Array(String)})
)""",
"""
SELECT
calls_merged.id AS id,
any(calls_merged.inputs_dump) AS inputs_dump
FROM calls_merged
WHERE
project_id = {pb_5:String}
AND
(id IN {filtered_calls:Array(String)})
GROUP BY (project_id,id)
HAVING (
JSON_VALUE(any(calls_merged.inputs_dump), {pb_3:String}) = {pb_4:String}
)
ORDER BY any(calls_merged.started_at) DESC
LIMIT 10
""",
],
{
"pb_0": "my_user_id",
"pb_1": ["a", "b"],
"pb_2": "project",
"pb_3": '$."param"."val"',
"pb_4": "hello",
"pb_5": "project",
},
)


def test_query_light_column_with_costs() -> None:
Expand All @@ -324,7 +418,8 @@ def test_query_light_column_with_costs() -> None:
)
assert_sql(
cq,
"""
[
"""
WITH
filtered_calls AS (
SELECT calls_merged.id AS id
Expand Down Expand Up @@ -436,7 +531,8 @@ def test_query_light_column_with_costs() -> None:
FROM ranked_prices
WHERE (rank = {pb_3:UInt64})
GROUP BY id, started_at
""",
"""
],
{
"pb_0": ["a", "b"],
"pb_1": "UHJvamVjdEludGVybmFsSWQ6Mzk1NDg2Mjc=",
Expand Down
Loading
Loading