Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): refactor objects query into a simple query builder #3223

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
289 changes: 289 additions & 0 deletions tests/trace/test_objects_query_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import pytest

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.objects_query_builder import (
ObjectMetadataQueryBuilder,
_make_conditions_part,
_make_limit_part,
_make_object_id_conditions_part,
_make_offset_part,
_make_sort_part,
make_objects_val_query_and_parameters,
)


def test_make_limit_part():
    """A limit renders as ``LIMIT n``; ``None`` renders as an empty string."""
    # 0 is a real limit and must still be rendered, unlike None.
    cases = [
        (None, ""),
        (10, "LIMIT 10"),
        (0, "LIMIT 0"),
    ]
    for limit, rendered in cases:
        assert _make_limit_part(limit) == rendered


def test_make_offset_part():
    """An offset renders as ``OFFSET n``; ``None`` renders as an empty string."""
    # 0 is a real offset and must still be rendered, unlike None.
    cases = [
        (None, ""),
        (5, "OFFSET 5"),
        (0, "OFFSET 0"),
    ]
    for offset, rendered in cases:
        assert _make_offset_part(offset) == rendered


def test_make_sort_part():
    """Sort specs render as a comma-separated ``ORDER BY`` clause."""
    # Nothing to sort by: no clause at all.
    for nothing in (None, []):
        assert _make_sort_part(nothing) == ""

    # A single field; the direction is upper-cased in the output.
    single = [tsi.SortBy(field="created_at", direction="asc")]
    assert _make_sort_part(single) == "ORDER BY created_at ASC"

    # Several fields keep the order in which they were given.
    compound = [
        tsi.SortBy(field="created_at", direction="desc"),
        tsi.SortBy(field="object_id", direction="asc"),
    ]
    assert _make_sort_part(compound) == "ORDER BY created_at DESC, object_id ASC"

    # Invalid sort fields should be ignored
    unknown = [tsi.SortBy(field="invalid_field", direction="asc")]
    assert _make_sort_part(unknown) == ""


def test_make_conditions_part():
    """Conditions render as a ``WHERE`` clause, AND-ed together when several."""
    # No conditions: no clause.
    for nothing in (None, []):
        assert _make_conditions_part(nothing) == ""

    # A lone condition is emitted without extra parentheses.
    assert _make_conditions_part(["condition1"]) == "WHERE condition1"

    # Multiple conditions are each parenthesised and joined with AND.
    combined = _make_conditions_part(["condition1", "condition2"])
    assert combined == "WHERE ((condition1) AND (condition2))"


def test_make_object_id_conditions_part():
    """Object-id conditions render as an `` AND ...`` suffix, or nothing if empty."""
    # NOTE(review): the leading " AND " means the rendered text only makes
    # sense when appended to an existing, non-empty WHERE clause. This was
    # raised in review; the author planned to rename the helper and add some
    # logic to address it.
    for nothing in (None, []):
        assert _make_object_id_conditions_part(nothing) == ""

    # A lone condition is appended without extra parentheses.
    assert _make_object_id_conditions_part(["id = 1"]) == " AND id = 1"

    # Multiple conditions are each parenthesised and joined with AND.
    both = _make_object_id_conditions_part(["id = 1", "id = 2"])
    assert both == " AND ((id = 1) AND (id = 2))"


def test_object_query_builder_basic():
    """A bare builder filters on project_id and binds it as a query parameter."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")

    rendered = qb.make_metadata_query()
    assert "project_id = {project_id: String}" in rendered

    bound = qb.parameters
    assert bound["project_id"] == "test_project"


def test_object_query_builder_add_digest_condition():
    """"latest" yields an is_latest filter; a concrete digest is bound as a parameter."""
    # Test latest digest
    latest_qb = ObjectMetadataQueryBuilder(project_id="test_project")
    latest_qb.add_digest_condition("latest")
    assert "is_latest = 1" in latest_qb.conditions_part

    # Test specific digest
    digest_qb = ObjectMetadataQueryBuilder(project_id="test_project")
    digest_qb.add_digest_condition("abc123")
    assert "digest = {version_digest: String}" in digest_qb.conditions_part
    assert digest_qb.parameters["version_digest"] == "abc123"


def test_object_query_builder_add_object_ids_condition():
    """One object id uses an equality filter; several use an ``IN`` array filter."""
    # Test single object ID
    single_qb = ObjectMetadataQueryBuilder(project_id="test_project")
    single_qb.add_object_ids_condition(["obj1"])
    single_clause = single_qb.object_id_conditions_part
    assert "object_id = {object_id: String}" in single_clause
    assert single_qb.parameters["object_id"] == "obj1"

    # Test multiple object IDs
    multi_qb = ObjectMetadataQueryBuilder(project_id="test_project")
    multi_qb.add_object_ids_condition(["obj1", "obj2"])
    multi_clause = multi_qb.object_id_conditions_part
    assert "object_id IN {object_ids: Array(String)}" in multi_clause
    assert multi_qb.parameters["object_ids"] == ["obj1", "obj2"]


def test_object_query_builder_add_is_op_condition():
    """Requesting ops only adds an ``is_op = 1`` filter to the conditions."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")
    qb.add_is_op_condition(True)

    rendered_conditions = qb.conditions_part
    assert "is_op = 1" in rendered_conditions


def test_object_query_builder_limit_offset():
    """Limit/offset render only once set, reject negatives, and are set-once."""
    builder = ObjectMetadataQueryBuilder(project_id="test_project")
    # Nothing is rendered until a value has been set.
    assert builder.limit_part == ""
    assert builder.offset_part == ""

    builder.set_limit(10)
    builder.set_offset(5)
    assert builder.limit_part == "LIMIT 10"
    assert builder.offset_part == "OFFSET 5"

    # Test invalid values.
    # These run on a fresh builder: `builder` already has a limit and an
    # offset, so a ValueError raised there could come from the set-once guard
    # alone and would not prove that negative values are rejected.
    fresh_builder = ObjectMetadataQueryBuilder(project_id="test_project")
    with pytest.raises(ValueError):
        fresh_builder.set_limit(-1)
    with pytest.raises(ValueError):
        fresh_builder.set_offset(-1)

    # Limit and offset may each be set only once per builder.
    with pytest.raises(ValueError):
        builder.set_limit(5)  # Limit already set
    with pytest.raises(ValueError):
        builder.set_offset(10)  # Offset already set


def test_object_query_builder_sort():
    """``add_order`` renders an ORDER BY clause and rejects unknown directions."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")

    qb.add_order("created_at", "DESC")
    rendered_sort = qb.sort_part
    assert rendered_sort == "ORDER BY created_at DESC"

    # A direction that is not a valid sort direction must raise.
    with pytest.raises(ValueError):
        qb.add_order("created_at", "INVALID")


# Shared prefix of the metadata query that the tests below expect from
# `ObjectMetadataQueryBuilder.make_metadata_query()`. It ends at
# `FROM object_versions` with two nested subqueries still open; each test
# appends the innermost WHERE clause, the two closing parentheses, and any
# outer WHERE / ORDER BY / LIMIT / OFFSET parts.
# The tests compare the full query with `==`, so all whitespace inside this
# literal is significant.
STATIC_METADATA_QUERY_PART = """
SELECT
project_id,
object_id,
created_at,
refs,
kind,
base_object_class,
digest,
version_index,
is_latest,
version_count,
is_op
FROM (
SELECT
project_id,
object_id,
created_at,
kind,
base_object_class,
refs,
digest,
is_op,
row_number() OVER (
PARTITION BY project_id,
kind,
object_id
ORDER BY created_at ASC
) - 1 AS version_index,
count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count,
if(version_index + 1 = version_count, 1, 0) AS is_latest
FROM (
SELECT
project_id,
object_id,
created_at,
kind,
base_object_class,
refs,
digest,
if (kind = 'op', 1, 0) AS is_op,
row_number() OVER (
PARTITION BY project_id,
kind,
object_id,
digest
ORDER BY created_at ASC
) AS rn
FROM object_versions"""


def test_object_query_builder_metadata_query_basic():
    """A "latest" digest adds only the outer is_latest filter and no extra parameters."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")
    qb.add_digest_condition("latest")

    # Render first, then read the bound parameters.
    actual_query = qb.make_metadata_query()
    actual_parameters = qb.parameters

    # Doubled braces are literal braces of the query-parameter placeholders.
    wanted_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}}
)
WHERE rn = 1
)
WHERE is_latest = 1"""

    assert actual_query == wanted_query
    assert actual_parameters == {"project_id": "test_project"}


def test_object_query_builder_metadata_query_with_limit_offset_sort():
    """Object-id, digest, class, sort, limit and offset all land in the right place."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")

    qb.set_limit(10)
    qb.set_offset(5)
    qb.add_order("created_at", "desc")
    qb.add_object_ids_condition(["object_1"])
    qb.add_digest_condition("digestttttttttttttttt")
    qb.add_base_object_classes_condition(["Model", "Model2"])

    # Render first, then read the bound parameters.
    actual_query = qb.make_metadata_query()
    actual_parameters = qb.parameters

    # The object-id filter sits in the innermost WHERE; the digest and class
    # filters sit in the outermost WHERE, followed by ORDER BY / LIMIT / OFFSET.
    wanted_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}}
)
WHERE rn = 1
)
WHERE ((digest = {{version_digest: String}}) AND (base_object_class IN {{base_object_classes: Array(String)}}))
ORDER BY created_at DESC
LIMIT 10
OFFSET 5"""
    wanted_parameters = {
        "project_id": "test_project",
        "object_id": "object_1",
        "version_digest": "digestttttttttttttttt",
        "base_object_classes": ["Model", "Model2"],
    }

    assert actual_query == wanted_query
    assert actual_parameters == wanted_parameters


def test_objects_query_metadata_op():
    """An op lookup by "v3" filters on is_op and on version_index = 3."""
    qb = ObjectMetadataQueryBuilder(project_id="test_project")
    qb.add_is_op_condition(True)
    qb.add_object_ids_condition(["my_op"])
    # The second argument shows up as the name of the bound parameter.
    qb.add_digest_condition("v3", "vvvvvversion")

    # Render first, then read the bound parameters.
    actual_query = qb.make_metadata_query()
    actual_parameters = qb.parameters

    wanted_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}}
)
WHERE rn = 1
)
WHERE ((is_op = 1) AND (version_index = {{vvvvvversion: Int64}}))"""
    wanted_parameters = {
        "project_id": "test_project",
        "object_id": "my_op",
        "vvvvvversion": 3,
    }

    assert actual_query == wanted_query
    assert actual_parameters == wanted_parameters


def test_make_objects_val_query_and_parameters():
    """The value query selects val_dump by project, object ids and digests."""
    project = "test_project"
    wanted_object_ids = ["object_1"]
    wanted_digests = ["digestttttttttttttttt", "digestttttttttttttttt2"]

    result = make_objects_val_query_and_parameters(
        project, wanted_object_ids, wanted_digests
    )
    actual_query, actual_parameters = result

    wanted_query = """
SELECT object_id, digest, any(val_dump)
FROM object_versions
WHERE project_id = {project_id: String} AND
object_id IN {object_ids: Array(String)} AND
digest IN {digests: Array(String)}
GROUP BY object_id, digest
"""

    assert actual_query == wanted_query
    # Compared against fresh literals, not the input lists, so any in-place
    # change to the inputs would be visible here.
    assert actual_parameters == {
        "project_id": "test_project",
        "object_ids": ["object_1"],
        "digests": ["digestttttttttttttttt", "digestttttttttttttttt2"],
    }
Loading
Loading