From 929daf5fefcb4152496ea49443ac072cb568bf38 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 17 Dec 2024 15:28:20 -0800 Subject: [PATCH] chore(weave): refactor objects query into a simple query builder (#3223) --- tests/trace/test_objects_query_builder.py | 291 +++++++++++++++++ .../clickhouse_trace_server_batched.py | 274 ++++------------ weave/trace_server/objects_query_builder.py | 305 ++++++++++++++++++ 3 files changed, 656 insertions(+), 214 deletions(-) create mode 100644 tests/trace/test_objects_query_builder.py create mode 100644 weave/trace_server/objects_query_builder.py diff --git a/tests/trace/test_objects_query_builder.py b/tests/trace/test_objects_query_builder.py new file mode 100644 index 00000000000..77662655b19 --- /dev/null +++ b/tests/trace/test_objects_query_builder.py @@ -0,0 +1,291 @@ +import pytest + +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.objects_query_builder import ( + ObjectMetadataQueryBuilder, + _make_conditions_part, + _make_limit_part, + _make_object_id_conditions_part, + _make_offset_part, + _make_sort_part, + make_objects_val_query_and_parameters, +) + + +def test_make_limit_part(): + assert _make_limit_part(None) == "" + assert _make_limit_part(10) == "LIMIT 10" + assert _make_limit_part(0) == "LIMIT 0" + + +def test_make_offset_part(): + assert _make_offset_part(None) == "" + assert _make_offset_part(5) == "OFFSET 5" + assert _make_offset_part(0) == "OFFSET 0" + + +def test_make_sort_part(): + assert _make_sort_part(None) == "" + assert _make_sort_part([]) == "" + + sort_by = [tsi.SortBy(field="created_at", direction="asc")] + assert _make_sort_part(sort_by) == "ORDER BY created_at ASC" + + sort_by = [ + tsi.SortBy(field="created_at", direction="desc"), + tsi.SortBy(field="object_id", direction="asc"), + ] + assert _make_sort_part(sort_by) == "ORDER BY created_at DESC, object_id ASC" + + # Invalid sort fields should be ignored + sort_by = 
[tsi.SortBy(field="invalid_field", direction="asc")] + assert _make_sort_part(sort_by) == "" + + +def test_make_conditions_part(): + assert _make_conditions_part(None) == "" + assert _make_conditions_part([]) == "" + assert _make_conditions_part(["condition1"]) == "WHERE condition1" + assert ( + _make_conditions_part(["condition1", "condition2"]) + == "WHERE ((condition1) AND (condition2))" + ) + + +def test_make_object_id_conditions_part(): + assert _make_object_id_conditions_part(None) == "" + assert _make_object_id_conditions_part([]) == "" + assert _make_object_id_conditions_part(["id = 1"]) == " AND id = 1" + assert ( + _make_object_id_conditions_part(["id = 1", "id = 2"]) + == " AND ((id = 1) AND (id = 2))" + ) + + +def test_object_query_builder_basic(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + assert "project_id = {project_id: String}" in builder.make_metadata_query() + assert builder.parameters["project_id"] == "test_project" + + +def test_object_query_builder_add_digest_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + # Test latest digest + builder.add_digests_conditions("latest") + assert "is_latest = 1" in builder.conditions_part + + # Test specific digest + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_digests_conditions("abc123") + assert "digest = {version_digest_0: String}" in builder.conditions_part + assert builder.parameters["version_digest_0"] == "abc123" + + +def test_object_query_builder_add_object_ids_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + # Test single object ID + builder.add_object_ids_condition(["obj1"]) + assert "object_id = {object_id: String}" in builder.object_id_conditions_part + assert builder.parameters["object_id"] == "obj1" + + # Test multiple object IDs + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_object_ids_condition(["obj1", "obj2"]) + assert ( + "object_id IN 
{object_ids: Array(String)}" in builder.object_id_conditions_part + ) + assert builder.parameters["object_ids"] == ["obj1", "obj2"] + + +def test_object_query_builder_add_is_op_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_is_op_condition(True) + assert "is_op = 1" in builder.conditions_part + + +def test_object_query_builder_limit_offset(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + assert builder.limit_part == "" + assert builder.offset_part == "" + + builder.set_limit(10) + builder.set_offset(5) + assert builder.limit_part == "LIMIT 10" + assert builder.offset_part == "OFFSET 5" + + # Test invalid values + with pytest.raises(ValueError): + builder.set_limit(-1) + with pytest.raises(ValueError): + builder.set_offset(-1) + with pytest.raises(ValueError): + builder.set_limit(5) # Limit already set + with pytest.raises(ValueError): + builder.set_offset(10) # Offset already set + + +def test_object_query_builder_sort(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_order("created_at", "DESC") + assert builder.sort_part == "ORDER BY created_at DESC" + + with pytest.raises(ValueError): + builder.add_order("created_at", "INVALID") + + +STATIC_METADATA_QUERY_PART = """ +SELECT + project_id, + object_id, + created_at, + refs, + kind, + base_object_class, + digest, + version_index, + is_latest, + version_count, + is_op +FROM ( + SELECT + project_id, + object_id, + created_at, + kind, + base_object_class, + refs, + digest, + is_op, + row_number() OVER ( + PARTITION BY project_id, + kind, + object_id + ORDER BY created_at ASC + ) - 1 AS version_index, + count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, + if(version_index + 1 = version_count, 1, 0) AS is_latest + FROM ( + SELECT + project_id, + object_id, + created_at, + kind, + base_object_class, + refs, + digest, + if (kind = 'op', 1, 0) AS is_op, + row_number() OVER ( + PARTITION BY project_id, 
+ kind, + object_id, + digest + ORDER BY created_at ASC + ) AS rn + FROM object_versions""" + + +def test_object_query_builder_metadata_query_basic(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_digests_conditions("latest") + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} + ) + WHERE rn = 1 +) +WHERE is_latest = 1""" + + assert query == expected_query + assert parameters == {"project_id": "test_project"} + + +def test_object_query_builder_metadata_query_with_limit_offset_sort(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + limit = 10 + offset = 5 + + builder.set_limit(limit) + builder.set_offset(offset) + builder.add_order("created_at", "desc") + builder.add_object_ids_condition(["object_1"]) + builder.add_digests_conditions("digestttttttttttttttt", "another-one", "v2") + builder.add_base_object_classes_condition(["Model", "Model2"]) + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}} + ) + WHERE rn = 1 +) +WHERE ((((digest = {{version_digest_0: String}}) OR (digest = {{version_digest_1: String}}) OR (version_index = {{version_index_2: Int64}}))) AND (base_object_class IN {{base_object_classes: Array(String)}})) +ORDER BY created_at DESC +LIMIT 10 +OFFSET 5""" + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_id": "object_1", + "version_digest_0": "digestttttttttttttttt", + "version_digest_1": "another-one", + "version_index_2": 2, + "base_object_classes": ["Model", "Model2"], + } + + +def test_objects_query_metadata_op(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_is_op_condition(True) + builder.add_object_ids_condition(["my_op"]) + 
builder.add_digests_conditions("v3") + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}} + ) + WHERE rn = 1 +) +WHERE ((is_op = 1) AND (version_index = {{version_index_0: Int64}}))""" + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_id": "my_op", + "version_index_0": 3, + } + + +def test_make_objects_val_query_and_parameters(): + project_id = "test_project" + object_ids = ["object_1"] + digests = ["digestttttttttttttttt", "digestttttttttttttttt2"] + + query, parameters = make_objects_val_query_and_parameters( + project_id, object_ids, digests + ) + + expected_query = """ + SELECT object_id, digest, any(val_dump) + FROM object_versions + WHERE project_id = {project_id: String} AND + object_id IN {object_ids: Array(String)} AND + digest IN {digests: Array(String)} + GROUP BY object_id, digest + """ + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_ids": ["object_1"], + "digests": ["digestttttttttttttttt", "digestttttttttttttttt2"], + } diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 31e974a3004..973d51080e2 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -81,6 +81,11 @@ read_model_to_provider_info_map, ) from weave.trace_server.object_class_util import process_incoming_object_val +from weave.trace_server.objects_query_builder import ( + ObjectMetadataQueryBuilder, + format_metadata_objects_from_query_result, + make_objects_val_query_and_parameters, +) from weave.trace_server.orm import ParamBuilder, Row from weave.trace_server.secret_fetcher_context import _secret_fetcher_context from weave.trace_server.table_query_builder import ( @@ -97,7 +102,6 
@@ from weave.trace_server.trace_server_common import ( DynamicBatchProcessor, LRUCache, - digest_is_version_like, empty_str_to_none, get_nested_key, hydrate_calls_with_feedback, @@ -525,40 +529,29 @@ def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: raise NotImplementedError() def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - conds = [ - "is_op = 1", - "digest = {digest: String}", - ] - object_id_conditions = ["object_id = {object_id: String}"] - parameters = {"name": req.name, "digest": req.digest} - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_is_op_condition(True) + object_query_builder.add_digests_conditions(req.digest) + object_query_builder.add_object_ids_condition([req.name], "op_name") + + objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.name}:{req.digest} not found") return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(objs[0])) def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: - parameters = {} - conds: list[str] = ["is_op = 1"] - object_id_conditions: list[str] = [] + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_is_op_condition(True) if req.filter: if req.filter.op_names: - object_id_conditions.append("object_id IN {op_names: Array(String)}") - parameters["op_names"] = req.filter.op_names + object_query_builder.add_object_ids_condition( + req.filter.op_names, "op_names" + ) if req.filter.latest_only: - conds.append("is_latest = 1") + object_query_builder.add_is_latest_condition() - ch_objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + ch_objs = self._select_objs_query(object_query_builder) objs = [_ch_obj_to_obj_schema(call) for call in 
ch_objs] return tsi.OpQueryRes(op_objs=objs) @@ -588,61 +581,44 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: return tsi.ObjCreateRes(digest=digest) def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - conds: list[str] = [] - object_id_conditions = ["object_id = {object_id: String}"] - parameters: dict[str, Union[str, int]] = {"object_id": req.object_id} - if req.digest == "latest": - conds.append("is_latest = 1") - else: - (is_version, version_index) = digest_is_version_like(req.digest) - if is_version: - conds.append("version_index = {version_index: UInt64}") - parameters["version_index"] = version_index - else: - conds.append("digest = {version_digest: String}") - parameters["version_digest"] = req.digest - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_digests_conditions(req.digest) + object_query_builder.add_object_ids_condition([req.object_id]) + + objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found") return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(objs[0])) def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - conds: list[str] = [] - object_id_conditions: list[str] = [] - parameters = {} + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) if req.filter: if req.filter.is_op is not None: if req.filter.is_op: - conds.append("is_op = 1") + object_query_builder.add_is_op_condition(True) else: - conds.append("is_op = 0") + object_query_builder.add_is_op_condition(False) if req.filter.object_ids: - object_id_conditions.append("object_id IN {object_ids: Array(String)}") - parameters["object_ids"] = req.filter.object_ids + object_query_builder.add_object_ids_condition( + req.filter.object_ids, "object_ids" + ) if req.filter.latest_only: - 
conds.append("is_latest = 1") + object_query_builder.add_is_latest_condition() if req.filter.base_object_classes: - conds.append( - "base_object_class IN {base_object_classes: Array(String)}" + object_query_builder.add_base_object_classes_condition( + req.filter.base_object_classes ) - parameters["base_object_classes"] = req.filter.base_object_classes + if req.limit is not None: + object_query_builder.set_limit(req.limit) + if req.offset is not None: + object_query_builder.set_offset(req.offset) + if req.sort_by: + for sort in req.sort_by: + object_query_builder.add_order(sort.field, sort.direction) + metadata_only = req.metadata_only or False - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - metadata_only=req.metadata_only, - limit=req.limit, - offset=req.offset, - sort_by=req.sort_by, - ) + objs = self._select_objs_query(object_query_builder, metadata_only) return tsi.ObjQueryRes(objs=[_ch_obj_to_obj_schema(obj) for obj in objs]) @@ -950,7 +926,7 @@ def get_object_refs_root_val( ) -> Any: conds: list[str] = [] object_id_conds: list[str] = [] - parameters = {} + parameters: dict[str, Union[str, int]] = {} for ref_index, ref in enumerate(refs): if ref.version == "latest": @@ -979,12 +955,13 @@ def get_object_refs_root_val( if len(conds) > 0: conditions = [combine_conditions(conds, "OR")] object_id_conditions = [combine_conditions(object_id_conds, "OR")] - objs = self._select_objs_query( + object_query_builder = ObjectMetadataQueryBuilder( project_id=project_id_scope, conditions=conditions, object_id_conditions=object_id_conditions, parameters=parameters, ) + objs = self._select_objs_query(object_query_builder) for obj in objs: root_val_cache[make_obj_cache_key(obj)] = json.loads(obj.val_dump) @@ -1562,14 +1539,8 @@ def _insert_call_batch(self, batch: list) -> None: def _select_objs_query( self, - project_id: str, - conditions: Optional[list[str]] = None, - 
object_id_conditions: Optional[list[str]] = None, - parameters: Optional[dict[str, Any]] = None, - metadata_only: Optional[bool] = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - sort_by: Optional[list[tsi.SortBy]] = None, + object_query_builder: ObjectMetadataQueryBuilder, + metadata_only: bool = False, ) -> list[SelectableCHObjSchema]: """ Main query for fetching objects. @@ -1587,146 +1558,21 @@ def _select_objs_query( if metadata_only is True, then we return early and dont grab the value. Otherwise, make a second query to grab the val_dump from the db """ - if not conditions: - conditions = ["1 = 1"] - if not object_id_conditions: - object_id_conditions = ["1 = 1"] - - conditions_part = combine_conditions(conditions, "AND") - object_id_conditions_part = combine_conditions(object_id_conditions, "AND") - - limit_part = "" - offset_part = "" - if limit is not None: - limit_part = f"LIMIT {int(limit)}" - if offset is not None: - offset_part = f" OFFSET {int(offset)}" - - sort_part = "" - if sort_by: - valid_sort_fields = {"object_id", "created_at"} - sort_clauses = [] - for sort in sort_by: - if sort.field in valid_sort_fields and sort.direction in { - "asc", - "desc", - }: - sort_clauses.append(f"{sort.field} {sort.direction.upper()}") - if sort_clauses: - sort_part = f"ORDER BY {', '.join(sort_clauses)}" - - if parameters is None: - parameters = {} - - select_without_val_dump_query = f""" - SELECT - project_id, - object_id, - created_at, - kind, - base_object_class, - refs, - digest, - is_op, - version_index, - version_count, - is_latest - FROM ( - SELECT project_id, - object_id, - created_at, - kind, - base_object_class, - refs, - digest, - is_op, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id - ORDER BY created_at ASC - ) - 1 AS version_index, - count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, - if(version_index + 1 = version_count, 1, 0) AS is_latest - FROM ( - SELECT project_id, - 
object_id, - created_at, - kind, - base_object_class, - refs, - digest, - if (kind = 'op', 1, 0) AS is_op, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id, - digest - ORDER BY created_at ASC - ) AS rn - FROM object_versions - WHERE project_id = {{project_id: String}} AND - {object_id_conditions_part} - ) - WHERE rn = 1 - ) - WHERE {conditions_part} - {sort_part} - {limit_part} - {offset_part} - """ - query_result = self._query_stream( - select_without_val_dump_query, - {"project_id": project_id, **parameters}, - ) - result: list[SelectableCHObjSchema] = [] - for row in query_result: - result.append( - SelectableCHObjSchema.model_validate( - dict( - zip( - [ - "project_id", - "object_id", - "created_at", - "kind", - "base_object_class", - "refs", - "digest", - "is_op", - "version_index", - "version_count", - "is_latest", - "val_dump", - ], - # Add an empty val_dump to the end of the row - list(row) + ["{}"], - ) - ) - ) - ) + obj_metadata_query = object_query_builder.make_metadata_query() + parameters = object_query_builder.parameters or {} + query_result = self._query_stream(obj_metadata_query, parameters) + metadata_result = format_metadata_objects_from_query_result(query_result) # -- Don't make second query for object values if metadata_only -- - if metadata_only: - return result + if metadata_only or len(metadata_result) == 0: + return metadata_result - # now get the val_dump for each object - object_ids = list({row.object_id for row in result}) - digests = list({row.digest for row in result}) - query = """ - SELECT object_id, digest, any(val_dump) - FROM object_versions - WHERE project_id = {project_id: String} AND - object_id IN {object_ids: Array(String)} AND - digest IN {digests: Array(String)} - GROUP BY object_id, digest - """ - parameters = { - "project_id": project_id, - "object_ids": object_ids, - "digests": digests, - } - query_result = self._query_stream(query, parameters) + value_query, value_parameters = 
make_objects_val_query_and_parameters( + project_id=object_query_builder.project_id, + object_ids=list({row.object_id for row in metadata_result}), + digests=list({row.digest for row in metadata_result}), + ) + query_result = self._query_stream(value_query, value_parameters) # Map (object_id, digest) to val_dump object_values: dict[tuple[str, str], Any] = {} for row in query_result: @@ -1734,9 +1580,9 @@ def _select_objs_query( object_values[(object_id, digest)] = val_dump # update the val_dump for each object - for obj in result: + for obj in metadata_result: obj.val_dump = object_values.get((obj.object_id, obj.digest), "{}") - return result + return metadata_result def _run_migrations(self) -> None: logger.info("Running migrations") diff --git a/weave/trace_server/objects_query_builder.py b/weave/trace_server/objects_query_builder.py new file mode 100644 index 00000000000..a5675c4b5c1 --- /dev/null +++ b/weave/trace_server/objects_query_builder.py @@ -0,0 +1,305 @@ +from collections.abc import Iterator +from typing import Any, Optional + +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.clickhouse_schema import SelectableCHObjSchema +from weave.trace_server.orm import combine_conditions +from weave.trace_server.trace_server_common import digest_is_version_like + +VALID_OBJECT_SORT_FIELDS = {"created_at", "object_id"} +VALID_SORT_DIRECTIONS = {"asc", "desc"} +OBJECT_METADATA_COLUMNS = [ + "project_id", + "object_id", + "created_at", + "refs", + "kind", + "base_object_class", + "digest", + "version_index", + "is_latest", + # columns not used in SelectableCHObjSchema(): + "version_count", + "is_op", +] + + +def _make_optional_part(query_keyword: str, part: Optional[str]) -> str: + if part is None or part == "": + return "" + return f"{query_keyword} {part}" + + +def _make_limit_part(limit: Optional[int]) -> str: + if limit is None: + return "" + return _make_optional_part("LIMIT", str(limit)) + + +def _make_offset_part(offset: 
def _make_offset_part(offset: Optional[int]) -> str:
    """Return an ``OFFSET <n>`` clause, or "" when *offset* is None."""
    if offset is None:
        return ""
    return _make_optional_part("OFFSET", str(offset))


def _make_sort_part(sort_by: Optional[list[tsi.SortBy]]) -> str:
    """Return an ``ORDER BY ...`` clause built from the given sort specs.

    Specs whose field is not in VALID_OBJECT_SORT_FIELDS or whose direction
    is not in VALID_SORT_DIRECTIONS are silently dropped; if none survive
    (or *sort_by* is empty/None), "" is returned.
    """
    if not sort_by:
        return ""

    sort_clauses = []
    for sort in sort_by:
        if (
            sort.field in VALID_OBJECT_SORT_FIELDS
            and sort.direction in VALID_SORT_DIRECTIONS
        ):
            sort_clause = f"{sort.field} {sort.direction.upper()}"
            sort_clauses.append(sort_clause)
    return _make_optional_part("ORDER BY", ", ".join(sort_clauses))


def _make_conditions_part(conditions: Optional[list[str]]) -> str:
    """Return a ``WHERE ...`` clause AND-combining *conditions*, or ""."""
    if not conditions:
        return ""
    conditions_str = combine_conditions(conditions, "AND")
    return _make_optional_part("WHERE", conditions_str)


def _make_object_id_conditions_part(
    object_id_conditions: Optional[list[str]], add_where_clause: bool = False
) -> str:
    """
    Formats object_id_conditions into a query string. In this file it is only
    used after the ``WHERE project_id...`` clause (hence the leading " AND "),
    but passing add_where_clause=True renders a standalone WHERE clause instead.
    """
    if not object_id_conditions:
        return ""
    conditions_str = combine_conditions(object_id_conditions, "AND")
    if add_where_clause:
        # Fix: the original returned "WHERE " + " AND <cond>", producing
        # malformed SQL ("WHERE  AND ..."). Render a proper WHERE clause.
        return _make_optional_part("WHERE", conditions_str)
    return " " + _make_optional_part("AND", conditions_str)


def format_metadata_objects_from_query_result(
    query_result: Iterator[tuple[Any, ...]],
) -> list[SelectableCHObjSchema]:
    """Convert raw metadata-query rows into SelectableCHObjSchema models.

    Rows are expected to be ordered as OBJECT_METADATA_COLUMNS; the metadata
    query does not select ``val_dump``, so an empty-dict placeholder is
    appended to satisfy the schema.
    """
    result = []
    for row in query_result:
        # Add an empty val_dump to the end of the row
        row_with_val_dump = row + ("{}",)
        columns_with_val_dump = OBJECT_METADATA_COLUMNS + ["val_dump"]
        row_dict = dict(zip(columns_with_val_dump, row_with_val_dump))
        row_model = SelectableCHObjSchema.model_validate(row_dict)
        result.append(row_model)
    return result


class ObjectMetadataQueryBuilder:
    """Accumulates filter conditions and ClickHouse query parameters, then
    renders the object-versions metadata query via make_metadata_query().

    Conditions on object_id are kept separate from the other conditions
    because they are applied in the innermost subquery (before window
    functions run), while the remaining conditions filter the final result.
    """

    def __init__(
        self,
        project_id: str,
        conditions: Optional[list[str]] = None,
        object_id_conditions: Optional[list[str]] = None,
        parameters: Optional[dict[str, Any]] = None,
    ):
        self.project_id = project_id
        self.parameters: dict[str, Any] = parameters or {}
        # Fix: look up the literal "project_id" key. The original checked
        # self.parameters.get(project_id) — keyed by the project-id VALUE —
        # so a caller-supplied "project_id" parameter was always overwritten.
        if not self.parameters.get("project_id"):
            self.parameters.update({"project_id": project_id})

        self._conditions: list[str] = conditions or []
        self._object_id_conditions: list[str] = object_id_conditions or []
        self._limit: Optional[int] = None
        self._offset: Optional[int] = None
        self._sort_by: list[tsi.SortBy] = []

    @property
    def conditions_part(self) -> str:
        # WHERE clause applied to the outermost (final) select.
        return _make_conditions_part(self._conditions)

    @property
    def object_id_conditions_part(self) -> str:
        # " AND ..." fragment appended after "WHERE project_id = ..." in the
        # innermost subquery.
        return _make_object_id_conditions_part(self._object_id_conditions)

    @property
    def sort_part(self) -> str:
        return _make_sort_part(self._sort_by)

    @property
    def limit_part(self) -> str:
        return _make_limit_part(self._limit)

    @property
    def offset_part(self) -> str:
        return _make_offset_part(self._offset)

    def _make_digest_condition(
        self, digest: str, param_key: Optional[str] = None, index: Optional[int] = None
    ) -> str:
        """
        If digest is "latest", return the condition for the latest version.
        Otherwise, return the condition for the version with the given digest.
        If digest is a version like "v123", return the condition for the version
        with the given version index.
        If digest is a hash like "sha256" return the hash.
        Use index to make the param_key unique if there are multiple digests.
        """
        if digest == "latest":
            return "is_latest = 1"

        (is_version, version_index) = digest_is_version_like(digest)
        if is_version:
            param_key = param_key or "version_index"
            return self._make_version_index_condition(version_index, param_key, index)
        else:
            param_key = param_key or "version_digest"
            return self._make_version_digest_condition(digest, param_key, index)

    def _make_version_digest_condition(
        self, digest: str, param_key: str, index: Optional[int] = None
    ) -> str:
        # Registers the digest as a query parameter and returns the condition.
        if index is not None:
            param_key = f"{param_key}_{index}"
        self.parameters.update({param_key: digest})
        return f"digest = {{{param_key}: String}}"

    def _make_version_index_condition(
        self, version_index: int, param_key: str, index: Optional[int] = None
    ) -> str:
        # Registers the version index as a query parameter and returns the condition.
        if index is not None:
            param_key = f"{param_key}_{index}"
        self.parameters.update({param_key: version_index})
        return f"version_index = {{{param_key}: Int64}}"

    def add_digests_conditions(self, *digests: str) -> None:
        """Add one OR-combined condition matching any of the given digests.

        Each digest may be "latest", a version alias like "v3", or a raw
        content digest; see _make_digest_condition.
        """
        digest_conditions = []
        for i, digest in enumerate(digests):
            condition = self._make_digest_condition(digest, None, i)
            digest_conditions.append(condition)

        digests_condition = combine_conditions(digest_conditions, "OR")
        self._conditions.append(digests_condition)

    def add_object_ids_condition(
        self, object_ids: list[str], param_key: Optional[str] = None
    ) -> None:
        """Restrict to the given object ids; uses ``=`` for a single id and
        ``IN`` for several. *param_key* lets callers pick a descriptive
        parameter name (e.g. "op_names")."""
        if len(object_ids) == 1:
            param_key = param_key or "object_id"
            self._object_id_conditions.append(f"object_id = {{{param_key}: String}}")
            self.parameters.update({param_key: object_ids[0]})
        else:
            param_key = param_key or "object_ids"
            self._object_id_conditions.append(
                f"object_id IN {{{param_key}: Array(String)}}"
            )
            self.parameters.update({param_key: object_ids})

    def add_is_latest_condition(self) -> None:
        self._conditions.append("is_latest = 1")

    def add_is_op_condition(self, is_op: bool) -> None:
        if is_op:
            self._conditions.append("is_op = 1")
        else:
            self._conditions.append("is_op = 0")

    def add_base_object_classes_condition(self, base_object_classes: list[str]) -> None:
        self._conditions.append(
            "base_object_class IN {base_object_classes: Array(String)}"
        )
        self.parameters.update({"base_object_classes": base_object_classes})

    def add_order(self, field: str, direction: str) -> None:
        """Append a sort spec; *direction* must be "asc" or "desc" (any case).

        Raises:
            ValueError: if *direction* is not a recognized direction.
        """
        direction = direction.lower()
        if direction not in ("asc", "desc"):
            raise ValueError(f"Direction {direction} is not allowed")
        self._sort_by.append(tsi.SortBy(field=field, direction=direction))

    def set_limit(self, limit: int) -> None:
        """Set the LIMIT (>= 0). May only be set once.

        Raises:
            ValueError: if *limit* is negative or a limit was already set.
        """
        if limit < 0:
            raise ValueError("Limit must be a positive integer")
        if self._limit is not None:
            raise ValueError("Limit can only be set once")
        self._limit = limit

    def set_offset(self, offset: int) -> None:
        """Set the OFFSET (>= 0). May only be set once.

        Raises:
            ValueError: if *offset* is negative or an offset was already set.
        """
        if offset < 0:
            raise ValueError("Offset must be a positive integer")
        if self._offset is not None:
            raise ValueError("Offset can only be set once")
        self._offset = offset

    def make_metadata_query(self) -> str:
        """Render the full metadata query.

        Structure (inside-out): the innermost select filters by project_id
        (plus any object_id conditions) and deduplicates identical
        (project, kind, object, digest) rows via row_number() = 1; the middle
        select computes version_index / version_count / is_latest with window
        functions; the outer select applies the remaining conditions, sort,
        limit and offset. Parameters referenced as {name: Type} are collected
        in self.parameters.
        """
        columns = ",\n    ".join(OBJECT_METADATA_COLUMNS)
        query = f"""
SELECT
    {columns}
FROM (
    SELECT
        project_id,
        object_id,
        created_at,
        kind,
        base_object_class,
        refs,
        digest,
        is_op,
        row_number() OVER (
            PARTITION BY project_id,
            kind,
            object_id
            ORDER BY created_at ASC
        ) - 1 AS version_index,
        count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count,
        if(version_index + 1 = version_count, 1, 0) AS is_latest
    FROM (
        SELECT
            project_id,
            object_id,
            created_at,
            kind,
            base_object_class,
            refs,
            digest,
            if (kind = 'op', 1, 0) AS is_op,
            row_number() OVER (
                PARTITION BY project_id,
                kind,
                object_id,
                digest
                ORDER BY created_at ASC
            ) AS rn
        FROM object_versions
        WHERE project_id = {{project_id: String}}{self.object_id_conditions_part}
    )
    WHERE rn = 1
)"""
        if self.conditions_part:
            query += f"\n{self.conditions_part}"
        if self.sort_part:
            query += f"\n{self.sort_part}"
        if self.limit_part:
            query += f"\n{self.limit_part}"
        if self.offset_part:
            query += f"\n{self.offset_part}"
        return query


def make_objects_val_query_and_parameters(
    project_id: str, object_ids: list[str], digests: list[str]
) -> tuple[str, dict[str, Any]]:
    """Build the follow-up query fetching val_dump for the given objects.

    Returns the query string and its parameter dict. any(val_dump) with
    GROUP BY collapses duplicate (object_id, digest) rows to one value.
    """
    query = """
    SELECT object_id, digest, any(val_dump)
    FROM object_versions
    WHERE project_id = {project_id: String} AND
        object_id IN {object_ids: Array(String)} AND
        digest IN {digests: Array(String)}
    GROUP BY object_id, digest
    """
    parameters = {
        "project_id": project_id,
        "object_ids": object_ids,
        "digests": digests,
    }
    return query, parameters