Skip to content

Commit

Permalink
add tests and rename -- review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Dec 13, 2024
1 parent 5acdf01 commit 1857df2
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 85 deletions.
177 changes: 166 additions & 11 deletions tests/trace/test_objects_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.objects_query_builder import (
ObjectQueryBuilder,
ObjectMetadataQueryBuilder,
_make_conditions_part,
_make_limit_part,
_make_object_id_conditions_part,
_make_offset_part,
_make_sort_part,
make_objects_val_query_and_parameters,
)


Expand Down Expand Up @@ -62,35 +63,35 @@ def test_make_object_id_conditions_part():


def test_object_query_builder_basic():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
assert "project_id = {project_id: String}" in builder.make_metadata_query()
assert builder.parameters["project_id"] == "test_project"


def test_object_query_builder_add_digest_condition():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")

# Test latest digest
builder.add_digest_condition("latest")
assert "is_latest = 1" in builder.conditions_part

# Test specific digest
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_digest_condition("abc123")
assert "digest = {version_digest: String}" in builder.conditions_part
assert builder.parameters["version_digest"] == "abc123"


def test_object_query_builder_add_object_ids_condition():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")

# Test single object ID
builder.add_object_ids_condition(["obj1"])
assert "object_id = {object_ids: String}" in builder.object_id_conditions_part
assert builder.parameters["object_ids"] == "obj1"
assert "object_id = {object_id: String}" in builder.object_id_conditions_part
assert builder.parameters["object_id"] == "obj1"

# Test multiple object IDs
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_object_ids_condition(["obj1", "obj2"])
assert (
"object_id IN {object_ids: Array(String)}" in builder.object_id_conditions_part
Expand All @@ -99,13 +100,13 @@ def test_object_query_builder_add_object_ids_condition():


def test_object_query_builder_add_is_op_condition():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_is_op_condition(True)
assert "is_op = 1" in builder.conditions_part


def test_object_query_builder_limit_offset():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
assert builder.limit_part == ""
assert builder.offset_part == ""

Expand All @@ -126,9 +127,163 @@ def test_object_query_builder_limit_offset():


def test_object_query_builder_sort():
builder = ObjectQueryBuilder(project_id="test_project")
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_order("created_at", "DESC")
assert builder.sort_part == "ORDER BY created_at DESC"

with pytest.raises(ValueError):
builder.add_order("created_at", "INVALID")


STATIC_METADATA_QUERY_PART = """
SELECT
project_id,
object_id,
created_at,
refs,
kind,
base_object_class,
digest,
version_index,
is_latest,
version_count,
is_op
FROM (
SELECT
project_id,
object_id,
created_at,
kind,
base_object_class,
refs,
digest,
is_op,
row_number() OVER (
PARTITION BY project_id,
kind,
object_id
ORDER BY created_at ASC
) - 1 AS version_index,
count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count,
if(version_index + 1 = version_count, 1, 0) AS is_latest
FROM (
SELECT
project_id,
object_id,
created_at,
kind,
base_object_class,
refs,
digest,
if (kind = 'op', 1, 0) AS is_op,
row_number() OVER (
PARTITION BY project_id,
kind,
object_id,
digest
ORDER BY created_at ASC
) AS rn
FROM object_versions"""


def test_object_query_builder_metadata_query_basic():
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_digest_condition("latest")

query = builder.make_metadata_query()
parameters = builder.parameters

expected_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}}
)
WHERE rn = 1
)
WHERE is_latest = 1"""

assert query == expected_query
assert parameters == {"project_id": "test_project"}


def test_object_query_builder_metadata_query_with_limit_offset_sort():
builder = ObjectMetadataQueryBuilder(project_id="test_project")

limit = 10
offset = 5

builder.set_limit(limit)
builder.set_offset(offset)
builder.add_order("created_at", "desc")
builder.add_object_ids_condition(["object_1"])
builder.add_digest_condition("digestttttttttttttttt")
builder.add_base_object_classes_condition(["Model", "Model2"])

query = builder.make_metadata_query()
parameters = builder.parameters

expected_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}}
)
WHERE rn = 1
)
WHERE ((digest = {{version_digest: String}}) AND (base_object_class IN {{base_object_classes: Array(String)}}))
ORDER BY created_at DESC
LIMIT 10
OFFSET 5"""

assert query == expected_query
assert parameters == {
"project_id": "test_project",
"object_id": "object_1",
"version_digest": "digestttttttttttttttt",
"base_object_classes": ["Model", "Model2"],
}


def test_objects_query_metadata_op():
builder = ObjectMetadataQueryBuilder(project_id="test_project")
builder.add_is_op_condition(True)
builder.add_object_ids_condition(["my_op"])
builder.add_digest_condition("v3", "vvvvvversion")

query = builder.make_metadata_query()
parameters = builder.parameters

expected_query = f"""{STATIC_METADATA_QUERY_PART}
WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}}
)
WHERE rn = 1
)
WHERE ((is_op = 1) AND (version_index = {{vvvvvversion: Int64}}))"""

assert query == expected_query
assert parameters == {
"project_id": "test_project",
"object_id": "my_op",
"vvvvvversion": 3,
}


def test_make_objects_val_query_and_parameters():
project_id = "test_project"
object_ids = ["object_1"]
digests = ["digestttttttttttttttt", "digestttttttttttttttt2"]

query, parameters = make_objects_val_query_and_parameters(
project_id, object_ids, digests
)

expected_query = """
SELECT object_id, digest, any(val_dump)
FROM object_versions
WHERE project_id = {project_id: String} AND
object_id IN {object_ids: Array(String)} AND
digest IN {digests: Array(String)}
GROUP BY object_id, digest
"""

assert query == expected_query
assert parameters == {
"project_id": "test_project",
"object_ids": ["object_1"],
"digests": ["digestttttttttttttttt", "digestttttttttttttttt2"],
}
23 changes: 12 additions & 11 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
read_model_to_provider_info_map,
)
from weave.trace_server.objects_query_builder import (
ObjectQueryBuilder,
ObjectMetadataQueryBuilder,
format_metadata_objects_from_query_result,
make_objects_val_query_and_parameters,
)
Expand Down Expand Up @@ -529,7 +529,7 @@ def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes:
raise NotImplementedError()

def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
object_query_builder = ObjectMetadataQueryBuilder(req.project_id)
object_query_builder.add_is_op_condition(True)
object_query_builder.add_digest_condition(req.digest)
object_query_builder.add_object_ids_condition([req.name], "op_name")
Expand All @@ -541,7 +541,7 @@ def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(objs[0]))

def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
object_query_builder = ObjectMetadataQueryBuilder(req.project_id)
object_query_builder.add_is_op_condition(True)
if req.filter:
if req.filter.op_names:
Expand Down Expand Up @@ -581,7 +581,7 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
return tsi.ObjCreateRes(digest=digest)

def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
object_query_builder = ObjectMetadataQueryBuilder(req.project_id)
object_query_builder.add_digest_condition(req.digest)
object_query_builder.add_object_ids_condition([req.object_id])

Expand All @@ -592,7 +592,7 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes:
return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(objs[0]))

def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
object_query_builder = ObjectMetadataQueryBuilder(req.project_id)
if req.filter:
if req.filter.is_op is not None:
if req.filter.is_op:
Expand All @@ -609,17 +609,16 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
object_query_builder.add_base_object_classes_condition(
req.filter.base_object_classes
)
if req.metadata_only:
object_query_builder.set_metadata_only(True)
if req.limit is not None:
object_query_builder.set_limit(req.limit)
if req.offset is not None:
object_query_builder.set_offset(req.offset)
if req.sort_by:
for sort in req.sort_by:
object_query_builder.add_order(sort.field, sort.direction)
metadata_only = req.metadata_only or False

objs = self._select_objs_query(object_query_builder)
objs = self._select_objs_query(object_query_builder, metadata_only)

return tsi.ObjQueryRes(objs=[_ch_obj_to_obj_schema(obj) for obj in objs])

Expand Down Expand Up @@ -956,7 +955,7 @@ def get_object_refs_root_val(
if len(conds) > 0:
conditions = [combine_conditions(conds, "OR")]
object_id_conditions = [combine_conditions(object_id_conds, "OR")]
object_query_builder = ObjectQueryBuilder(
object_query_builder = ObjectMetadataQueryBuilder(
project_id=project_id_scope,
conditions=conditions,
object_id_conditions=object_id_conditions,
Expand Down Expand Up @@ -1539,7 +1538,9 @@ def _insert_call_batch(self, batch: list) -> None:
)

def _select_objs_query(
self, object_query_builder: ObjectQueryBuilder
self,
object_query_builder: ObjectMetadataQueryBuilder,
metadata_only: bool = False,
) -> list[SelectableCHObjSchema]:
"""
Main query for fetching objects.
Expand All @@ -1563,7 +1564,7 @@ def _select_objs_query(
metadata_result = format_metadata_objects_from_query_result(query_result)

# -- Don't make second query for object values if metadata_only --
if object_query_builder.metadata_only:
if metadata_only or len(metadata_result) == 0:
return metadata_result

value_query, value_parameters = make_objects_val_query_and_parameters(
Expand Down
Loading

0 comments on commit 1857df2

Please sign in to comment.