From b959fcd5ac046de186977f7c99cb9cffec6e0cd2 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 14:16:52 -0800 Subject: [PATCH 1/5] chore(weave): soft deletion for objects server --- tests/trace/test_obj_delete.py | 174 ++++++++ tests/trace/test_objects_query_builder.py | 296 ++++++++++++++ weave/trace_server/clickhouse_schema.py | 6 + .../clickhouse_trace_server_batched.py | 375 +++++++----------- weave/trace_server/errors.py | 12 + ...ternal_to_internal_trace_server_adapter.py | 4 + weave/trace_server/objects_query_builder.py | 292 ++++++++++++++ weave/trace_server/sqlite_trace_server.py | 153 +++++-- weave/trace_server/trace_server_interface.py | 14 + .../remote_http_trace_server.py | 5 + 10 files changed, 1082 insertions(+), 249 deletions(-) create mode 100644 tests/trace/test_obj_delete.py create mode 100644 tests/trace/test_objects_query_builder.py create mode 100644 weave/trace_server/objects_query_builder.py diff --git a/tests/trace/test_obj_delete.py b/tests/trace/test_obj_delete.py new file mode 100644 index 00000000000..34f3ebb369b --- /dev/null +++ b/tests/trace/test_obj_delete.py @@ -0,0 +1,174 @@ +import pytest + +import weave +from weave.trace.weave_client import WeaveClient +from weave.trace_server import trace_server_interface as tsi + + +def _objs_query(client: WeaveClient, object_id: str) -> list[tsi.ObjSchema]: + objs = client.server.objs_query( + tsi.ObjQueryReq( + project_id=client._project_id(), + filter=tsi.ObjectVersionFilter(object_ids=[object_id]), + sort_by=[tsi.SortBy(field="created_at", direction="asc")], + ) + ) + return objs.objs + + +def _obj_delete(client: WeaveClient, object_id: str, digests: list[str]) -> int: + return client.server.obj_delete( + tsi.ObjDeleteReq( + project_id=client._project_id(), + object_id=object_id, + digests=digests, + ) + ).num_deleted + + +def test_delete_object_versions(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + v1 = weave.publish({"i": 2}, name="obj_1") + v2 = weave.publish({"i": 3}, name="obj_1") + + objs = _objs_query(client, "obj_1") + assert len(objs) == 3 + + num_deleted = _obj_delete(client, "obj_1", [v0.digest]) + assert num_deleted == 1 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + + # test deleting an already deleted digest + with pytest.raises(weave.trace_server.errors.NotFoundError): + _obj_delete(client, "obj_1", [v0.digest]) + + # test deleting multiple digests + digests = [v1.digest, v2.digest] + num_deleted = _obj_delete(client, "obj_1", digests) + assert num_deleted == 2 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 0 + + +def test_delete_all_object_versions(client: WeaveClient): + weave.publish({"i": 1}, name="obj_1") + weave.publish({"i": 2}, name="obj_1") + weave.publish({"i": 3}, name="obj_1") + + num_deleted = _obj_delete(client, "obj_1", None) + assert num_deleted == 3 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 0 + + with pytest.raises(weave.trace_server.errors.NotFoundError): + _obj_delete(client, "obj_1", None) + + +def test_delete_version_correctness(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + v1 = weave.publish({"i": 2}, name="obj_1") + v2 = weave.publish({"i": 3}, name="obj_1") + + _obj_delete(client, "obj_1", [v1.digest]) + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 + + v3 = weave.publish({"i": 4}, name="obj_1") + objs = _objs_query(client, "obj_1") + assert len(objs) == 3 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 + assert objs[2].digest == v3.digest + assert objs[2].val == {"i": 4} + assert objs[2].version_index == 3 + + _obj_delete(client, "obj_1", [v3.digest]) + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 + + +def test_delete_object_max_limit(client: WeaveClient): + # Create more than MAX_OBJECTS_TO_DELETE objects + max_objs = 100 + digests = [] + for i in range(max_objs + 1): + digests.append(f"test_{i}") + + with pytest.raises( + ValueError, match=f"Please delete {max_objs} or fewer objects at a time" + ): + _obj_delete(client, "obj_1", digests) + + +def test_delete_nonexistent_object_id(client: WeaveClient): + with pytest.raises(weave.trace_server.errors.NotFoundError): + _obj_delete(client, "nonexistent_obj", None) + + +def test_delete_mixed_valid_invalid_digests(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + v1 = weave.publish({"i": 2}, name="obj_1") + + invalid_digests = [v0.digest, "invalid-digest", v1.digest] + with pytest.raises(weave.trace_server.errors.NotFoundError): + _obj_delete(client, "obj_1", invalid_digests) + + +def test_delete_duplicate_digests(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + + num_deleted = _obj_delete(client, "obj_1", [v0.digest, v0.digest]) + assert num_deleted == 1 + + +def test_delete_with_digest_aliases(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + weave.publish({"i": 2}, name="obj_1") + + num_deleted = _obj_delete(client, "obj_1", ["latest"]) + assert num_deleted == 1 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 1 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + + num_deleted = _obj_delete(client, "obj_1", ["v0"]) + assert num_deleted == 1 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 0 + + +def test_delete_and_recreate_object(client: WeaveClient): + # Create and delete initial object + v0 = weave.publish({"i": 1}, name="obj_1") + _obj_delete(client, "obj_1", [v0.digest]) + + # Create new object with same ID + v1 = weave.publish({"i": 2}, name="obj_1") + + objs = _objs_query(client, "obj_1") + assert len(objs) == 1 + assert objs[0].digest == v1.digest + assert objs[0].val == {"i": 2} diff --git a/tests/trace/test_objects_query_builder.py b/tests/trace/test_objects_query_builder.py new file mode 100644 index 00000000000..70e4e40e536 --- /dev/null +++ b/tests/trace/test_objects_query_builder.py @@ -0,0 +1,296 @@ +import pytest + +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.objects_query_builder import ( + ObjectMetadataQueryBuilder, + _make_conditions_part, + _make_limit_part, + _make_object_id_conditions_part, + _make_offset_part, + _make_sort_part, + make_objects_val_query_and_parameters, +) + + +def test_make_limit_part(): + assert _make_limit_part(None) == "" + assert _make_limit_part(10) == "LIMIT 10" + assert _make_limit_part(0) == "LIMIT 0" + + +def test_make_offset_part(): + assert _make_offset_part(None) == "" + assert _make_offset_part(5) == "OFFSET 5" + assert _make_offset_part(0) == "OFFSET 0" + + +def test_make_sort_part(): + assert _make_sort_part(None) == "" + assert _make_sort_part([]) == "" + + sort_by = [tsi.SortBy(field="created_at", direction="asc")] + assert _make_sort_part(sort_by) == "ORDER BY created_at ASC" + + sort_by = [ + tsi.SortBy(field="created_at", direction="desc"), + tsi.SortBy(field="object_id", direction="asc"), + ] + assert _make_sort_part(sort_by) == "ORDER BY created_at DESC, object_id ASC" + + # Invalid sort fields should be ignored + sort_by = [tsi.SortBy(field="invalid_field", direction="asc")] + assert _make_sort_part(sort_by) == "" + + +def test_make_conditions_part(): + assert _make_conditions_part(None) == "" + assert _make_conditions_part([]) == "" + assert _make_conditions_part(["condition1"]) == "WHERE condition1" + assert ( + _make_conditions_part(["condition1", "condition2"]) + == "WHERE ((condition1) AND (condition2))" + ) + + +def test_make_object_id_conditions_part(): + assert _make_object_id_conditions_part(None) == "" + assert _make_object_id_conditions_part([]) == "" + assert _make_object_id_conditions_part(["id = 1"]) == " AND id = 1" + assert ( + _make_object_id_conditions_part(["id = 1", "id = 2"]) + == " AND ((id = 1) AND (id = 2))" + ) + + +def test_object_query_builder_basic(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + assert "project_id = {project_id: String}" in builder.make_metadata_query() + assert builder.parameters["project_id"] == "test_project" + + +def test_object_query_builder_add_digest_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + # Test latest digest + builder.add_digest_condition("latest") + assert "is_latest = 1" in builder.conditions_part + + # Test specific digest + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_digest_condition("abc123") + assert "digest = {version_digest: String}" in builder.conditions_part + assert builder.parameters["version_digest"] == "abc123" + + +def test_object_query_builder_add_object_ids_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + # Test single object ID + builder.add_object_ids_condition(["obj1"]) + assert "object_id = {object_id: String}" in builder.object_id_conditions_part + assert builder.parameters["object_id"] == "obj1" + + # Test multiple object IDs + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_object_ids_condition(["obj1", "obj2"]) + assert ( + "object_id IN {object_ids: Array(String)}" in builder.object_id_conditions_part + ) + assert builder.parameters["object_ids"] == ["obj1", "obj2"] + + +def test_object_query_builder_add_is_op_condition(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_is_op_condition(True) + assert "is_op = 1" in builder.conditions_part + + +def test_object_query_builder_limit_offset(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + assert builder.limit_part == "" + assert builder.offset_part == "" + + builder.set_limit(10) + builder.set_offset(5) + assert builder.limit_part == "LIMIT 10" + assert builder.offset_part == "OFFSET 5" + + # Test invalid values + with pytest.raises(ValueError): + builder.set_limit(-1) + with pytest.raises(ValueError): + builder.set_offset(-1) + with pytest.raises(ValueError): + builder.set_limit(5) # Limit already set + with pytest.raises(ValueError): + builder.set_offset(10) # Offset already set + + +def test_object_query_builder_sort(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_order("created_at", "DESC") + assert builder.sort_part == "ORDER BY created_at DESC" + + with pytest.raises(ValueError): + builder.add_order("created_at", "INVALID") + + +STATIC_METADATA_QUERY_PART = """ +SELECT + project_id, + object_id, + created_at, + refs, + kind, + base_object_class, + digest, + version_index, + is_latest, + deleted_at, + version_count, + is_op +FROM ( + SELECT + project_id, + object_id, + created_at, + deleted_at, + kind, + base_object_class, + refs, + digest, + is_op, + row_number() OVER ( + PARTITION BY project_id, + kind, + object_id + ORDER BY created_at ASC + ) - 1 AS version_index, + count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, + row_number() OVER ( + PARTITION BY project_id, kind, object_id + ORDER BY (deleted_at IS NULL) DESC, created_at DESC + ) AS row_num, + if (row_num = 1, 1, 0) AS is_latest + FROM ( + SELECT + project_id, + object_id, + created_at, + deleted_at, + kind, + base_object_class, + refs, + digest, + if (kind = 'op', 1, 0) AS is_op, + row_number() OVER ( + PARTITION BY project_id, + kind, + object_id, + digest + ORDER BY created_at ASC + ) AS rn + FROM object_versions""" + + +def test_object_query_builder_metadata_query_basic(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_digest_condition("latest") + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} + ) + WHERE rn = 1 +) +WHERE is_latest = 1""" + + assert query == expected_query + assert parameters == {"project_id": "test_project"} + + +def test_object_query_builder_metadata_query_with_limit_offset_sort(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + + limit = 10 + offset = 5 + + builder.set_limit(limit) + builder.set_offset(offset) + builder.add_order("created_at", "desc") + builder.add_object_ids_condition(["object_1"]) + builder.add_digest_condition("digestttttttttttttttt") + builder.add_base_object_classes_condition(["Model", "Model2"]) + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}} + ) + WHERE rn = 1 +) +WHERE ((digest = {{version_digest: String}}) AND (base_object_class IN {{base_object_classes: Array(String)}})) +ORDER BY created_at DESC +LIMIT 10 +OFFSET 5""" + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_id": "object_1", + "version_digest": "digestttttttttttttttt", + "base_object_classes": ["Model", "Model2"], + } + + +def test_objects_query_metadata_op(): + builder = ObjectMetadataQueryBuilder(project_id="test_project") + builder.add_is_op_condition(True) + builder.add_object_ids_condition(["my_op"]) + builder.add_digest_condition("v3", "vvvvvversion") + + query = builder.make_metadata_query() + parameters = builder.parameters + + expected_query = f"""{STATIC_METADATA_QUERY_PART} + WHERE project_id = {{project_id: String}} AND object_id = {{object_id: String}} + ) + WHERE rn = 1 +) +WHERE ((is_op = 1) AND (version_index = {{vvvvvversion: Int64}}))""" + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_id": "my_op", + "vvvvvversion": 3, + } + + +def test_make_objects_val_query_and_parameters(): + project_id = "test_project" + object_ids = ["object_1"] + digests = ["digestttttttttttttttt", "digestttttttttttttttt2"] + + query, parameters = make_objects_val_query_and_parameters( + project_id, object_ids, digests + ) + + expected_query = """ + SELECT object_id, digest, any(val_dump) + FROM object_versions + WHERE project_id = {project_id: String} AND + object_id IN {object_ids: Array(String)} AND + digest IN {digests: Array(String)} + GROUP BY object_id, digest + """ + + assert query == expected_query + assert parameters == { + "project_id": "test_project", + "object_ids": ["object_1"], + "digests": ["digestttttttttttttttt", "digestttttttttttttttt2"], + } diff --git a/weave/trace_server/clickhouse_schema.py b/weave/trace_server/clickhouse_schema.py index c02b5f7b332..ac6507059fe 100644 --- a/weave/trace_server/clickhouse_schema.py +++ b/weave/trace_server/clickhouse_schema.py @@ -136,6 +136,11 @@ class ObjCHInsertable(BaseModel): _refs = field_validator("refs")(validation.refs_list_validator) +class ObjDeleteCHInsertable(ObjCHInsertable): + deleted_at: datetime.datetime + created_at: datetime.datetime + + class SelectableCHObjSchema(BaseModel): project_id: str object_id: str @@ -147,3 +152,4 @@ class SelectableCHObjSchema(BaseModel): digest: str version_index: int is_latest: int + deleted_at: Optional[datetime.datetime] = None diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 31e974a3004..f27b5b84e1e 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -59,6 +59,7 @@ CallStartCHInsertable, CallUpdateCHInsertable, ObjCHInsertable, + ObjDeleteCHInsertable, SelectableCHCallSchema, SelectableCHObjSchema, ) @@ -68,6 +69,8 @@ InsertTooLarge, InvalidRequest, MissingLLMApiKeyError, + NotFoundError, + ObjectDeletedError, RequestTooLarge, ) from weave.trace_server.feedback import ( @@ -81,6 +84,11 @@ read_model_to_provider_info_map, ) from weave.trace_server.object_class_util import process_incoming_object_val +from weave.trace_server.objects_query_builder import ( + ObjectMetadataQueryBuilder, + format_metadata_objects_from_query_result, + make_objects_val_query_and_parameters, +) from weave.trace_server.orm import ParamBuilder, Row from weave.trace_server.secret_fetcher_context import _secret_fetcher_context from weave.trace_server.table_query_builder import ( @@ -97,7 +105,6 @@ from weave.trace_server.trace_server_common import ( DynamicBatchProcessor, LRUCache, - digest_is_version_like, empty_str_to_none, get_nested_key, hydrate_calls_with_feedback, @@ -125,10 +132,6 @@ MAX_CALLS_STREAM_BATCH_SIZE = 500 -class NotFoundError(Exception): - pass - - CallCHInsertable = Union[ CallStartCHInsertable, CallEndCHInsertable, @@ -525,40 +528,36 @@ def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: raise NotImplementedError() def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - conds = [ - "is_op = 1", - "digest = {digest: String}", - ] - object_id_conditions = ["object_id = {object_id: String}"] - parameters = {"name": req.name, "digest": req.digest} - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_is_op_condition(True) + object_query_builder.add_digests_conditions([req.digest]) + object_query_builder.add_object_ids_condition([req.name], "op_name") + object_query_builder.set_include_deleted(include_deleted=True) + + objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.name}:{req.digest} not found") - return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(objs[0])) + op = objs[0] + if op.deleted_at is not None: + raise ObjectDeletedError( + f"Op {req.name}:v{op.version_index} was deleted at {op.deleted_at}" + ) + + return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(op)) def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: - parameters = {} - conds: list[str] = ["is_op = 1"] - object_id_conditions: list[str] = [] + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_is_op_condition(True) if req.filter: if req.filter.op_names: - object_id_conditions.append("object_id IN {op_names: Array(String)}") - parameters["op_names"] = req.filter.op_names + object_query_builder.add_object_ids_condition( + req.filter.op_names, "op_names" + ) if req.filter.latest_only: - conds.append("is_latest = 1") + object_query_builder.add_is_latest_condition() - ch_objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + ch_objs = self._select_objs_query(object_query_builder) objs = [_ch_obj_to_obj_schema(call) for call in ch_objs] return tsi.OpQueryRes(op_objs=objs) @@ -588,64 +587,119 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: return tsi.ObjCreateRes(digest=digest) def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - conds: list[str] = [] - object_id_conditions = ["object_id = {object_id: String}"] - parameters: dict[str, Union[str, int]] = {"object_id": req.object_id} - if req.digest == "latest": - conds.append("is_latest = 1") - else: - (is_version, version_index) = digest_is_version_like(req.digest) - if is_version: - conds.append("version_index = {version_index: UInt64}") - parameters["version_index"] = version_index - else: - conds.append("digest = {version_digest: String}") - parameters["version_digest"] = req.digest - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - ) + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_digests_conditions([req.digest]) + object_query_builder.add_object_ids_condition([req.object_id]) + object_query_builder.set_include_deleted(include_deleted=True) + + objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found") - return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(objs[0])) + obj = objs[0] + if obj.deleted_at is not None: + raise ObjectDeletedError( + f"Obj {req.object_id}:v{obj.version_index} was deleted at {obj.deleted_at}" + ) + + return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(obj)) def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - conds: list[str] = [] - object_id_conditions: list[str] = [] - parameters = {} + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) if req.filter: if req.filter.is_op is not None: if req.filter.is_op: - conds.append("is_op = 1") + object_query_builder.add_is_op_condition(True) else: - conds.append("is_op = 0") + object_query_builder.add_is_op_condition(False) if req.filter.object_ids: - object_id_conditions.append("object_id IN {object_ids: Array(String)}") - parameters["object_ids"] = req.filter.object_ids + object_query_builder.add_object_ids_condition( + req.filter.object_ids, "object_ids" + ) if req.filter.latest_only: - conds.append("is_latest = 1") + object_query_builder.add_is_latest_condition() if req.filter.base_object_classes: - conds.append( - "base_object_class IN {base_object_classes: Array(String)}" + object_query_builder.add_base_object_classes_condition( + req.filter.base_object_classes ) - parameters["base_object_classes"] = req.filter.base_object_classes - - objs = self._select_objs_query( - req.project_id, - conditions=conds, - object_id_conditions=object_id_conditions, - parameters=parameters, - metadata_only=req.metadata_only, - limit=req.limit, - offset=req.offset, - sort_by=req.sort_by, - ) + if req.limit is not None: + object_query_builder.set_limit(req.limit) + if req.offset is not None: + object_query_builder.set_offset(req.offset) + if req.sort_by: + for sort in req.sort_by: + object_query_builder.add_order(sort.field, sort.direction) + metadata_only = req.metadata_only or False + object_query_builder.set_include_deleted(include_deleted=False) + objs = self._select_objs_query(object_query_builder, metadata_only) return tsi.ObjQueryRes(objs=[_ch_obj_to_obj_schema(obj) for obj in objs]) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + """ + Delete object versions by digest, belonging to given object_id. + All deletion in this method is "soft". Deletion occurs by inserting + a new row into the object_versions table with the deleted_at field set. + Inserted rows share identical primary keys (order by) with original rows, + and will be combined by the ReplacingMergeTree engine at database merge + time. + If no digests are provided all versions will have their data overwritten with + an empty val_dump. + """ + MAX_OBJECTS_TO_DELETE = 100 + if req.digests and len(req.digests) > MAX_OBJECTS_TO_DELETE: + raise ValueError( + f"Object delete request contains {len(req.digests)} objects. Please delete {MAX_OBJECTS_TO_DELETE} or fewer objects at a time." + ) + + object_query_builder = ObjectMetadataQueryBuilder(req.project_id) + object_query_builder.add_object_ids_condition([req.object_id]) + metadata_only = True + if req.digests: + object_query_builder.add_digests_conditions(req.digests) + metadata_only = False + + object_versions = self._select_objs_query(object_query_builder, metadata_only) + + delete_insertables = [] + now = datetime.datetime.now(datetime.timezone.utc) + for obj in object_versions: + original_created_at = _ensure_datetimes_have_tz_strict(obj.created_at) + delete_insertables.append( + ObjDeleteCHInsertable( + project_id=obj.project_id, + object_id=obj.object_id, + digest=obj.digest, + kind=obj.kind, + val_dump=obj.val_dump, + refs=obj.refs, + base_object_class=obj.base_object_class, + deleted_at=now, + created_at=original_created_at, + ) + ) + + if len(delete_insertables) == 0: + raise NotFoundError( + f"Object {req.object_id} ({req.digests}) not found when deleting." + ) + + if req.digests: + given_digests = set(req.digests) + found_digests = {obj.digest for obj in delete_insertables} + if len(given_digests) != len(found_digests): + raise NotFoundError( + f"Delete request contains {len(req.digests)} digests, but found {len(found_digests)} objects to delete. Diff digests: {given_digests - found_digests}" + ) + + data = [list(obj.model_dump().values()) for obj in delete_insertables] + column_names = list(delete_insertables[0].model_fields.keys()) + + self._insert("object_versions", data=data, column_names=column_names) + num_deleted = len(delete_insertables) + + return tsi.ObjDeleteRes(num_deleted=num_deleted) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: insert_rows = [] for r in req.table.rows: @@ -950,7 +1004,7 @@ def get_object_refs_root_val( ) -> Any: conds: list[str] = [] object_id_conds: list[str] = [] - parameters = {} + parameters: dict[str, Union[str, int]] = {} for ref_index, ref in enumerate(refs): if ref.version == "latest": @@ -979,12 +1033,13 @@ def get_object_refs_root_val( if len(conds) > 0: conditions = [combine_conditions(conds, "OR")] object_id_conditions = [combine_conditions(object_id_conds, "OR")] - objs = self._select_objs_query( + object_query_builder = ObjectMetadataQueryBuilder( project_id=project_id_scope, conditions=conditions, object_id_conditions=object_id_conditions, parameters=parameters, ) + objs = self._select_objs_query(object_query_builder) for obj in objs: root_val_cache[make_obj_cache_key(obj)] = json.loads(obj.val_dump) @@ -1562,14 +1617,8 @@ def _insert_call_batch(self, batch: list) -> None: def _select_objs_query( self, - project_id: str, - conditions: Optional[list[str]] = None, - object_id_conditions: Optional[list[str]] = None, - parameters: Optional[dict[str, Any]] = None, - metadata_only: Optional[bool] = False, - limit: Optional[int] = None, - offset: Optional[int] = None, - sort_by: Optional[list[tsi.SortBy]] = None, + object_query_builder: ObjectMetadataQueryBuilder, + metadata_only: bool = False, ) -> list[SelectableCHObjSchema]: """ Main query for fetching objects. @@ -1587,146 +1636,21 @@ def _select_objs_query( if metadata_only is True, then we return early and dont grab the value. Otherwise, make a second query to grab the val_dump from the db """ - if not conditions: - conditions = ["1 = 1"] - if not object_id_conditions: - object_id_conditions = ["1 = 1"] - - conditions_part = combine_conditions(conditions, "AND") - object_id_conditions_part = combine_conditions(object_id_conditions, "AND") - - limit_part = "" - offset_part = "" - if limit is not None: - limit_part = f"LIMIT {int(limit)}" - if offset is not None: - offset_part = f" OFFSET {int(offset)}" - - sort_part = "" - if sort_by: - valid_sort_fields = {"object_id", "created_at"} - sort_clauses = [] - for sort in sort_by: - if sort.field in valid_sort_fields and sort.direction in { - "asc", - "desc", - }: - sort_clauses.append(f"{sort.field} {sort.direction.upper()}") - if sort_clauses: - sort_part = f"ORDER BY {', '.join(sort_clauses)}" - - if parameters is None: - parameters = {} - - select_without_val_dump_query = f""" - SELECT - project_id, - object_id, - created_at, - kind, - base_object_class, - refs, - digest, - is_op, - version_index, - version_count, - is_latest - FROM ( - SELECT project_id, - object_id, - created_at, - kind, - base_object_class, - refs, - digest, - is_op, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id - ORDER BY created_at ASC - ) - 1 AS version_index, - count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, - if(version_index + 1 = version_count, 1, 0) AS is_latest - FROM ( - SELECT project_id, - object_id, - created_at, - kind, - base_object_class, - refs, - digest, - if (kind = 'op', 1, 0) AS is_op, - row_number() OVER ( - PARTITION BY project_id, - kind, - object_id, - digest - ORDER BY created_at ASC - ) AS rn - FROM object_versions - WHERE project_id = {{project_id: String}} AND - {object_id_conditions_part} - ) - WHERE rn = 1 - ) - WHERE {conditions_part} - {sort_part} - {limit_part} - {offset_part} - """ - query_result = self._query_stream( - select_without_val_dump_query, - {"project_id": project_id, **parameters}, - ) - result: list[SelectableCHObjSchema] = [] - for row in query_result: - result.append( - SelectableCHObjSchema.model_validate( - dict( - zip( - [ - "project_id", - "object_id", - "created_at", - "kind", - "base_object_class", - "refs", - "digest", - "is_op", - "version_index", - "version_count", - "is_latest", - "val_dump", - ], - # Add an empty val_dump to the end of the row - list(row) + ["{}"], - ) - ) - ) - ) + obj_metadata_query = object_query_builder.make_metadata_query() + parameters = object_query_builder.parameters or {} + query_result = self._query_stream(obj_metadata_query, parameters) + metadata_result = format_metadata_objects_from_query_result(query_result) # -- Don't make second query for object values if metadata_only -- - if metadata_only: - return result + if metadata_only or len(metadata_result) == 0: + return metadata_result - # now get the val_dump for each object - object_ids = list({row.object_id for row in result}) - digests = list({row.digest for row in result}) - query = """ - SELECT object_id, digest, any(val_dump) - FROM object_versions - WHERE project_id = {project_id: String} AND - object_id IN {object_ids: Array(String)} AND - digest IN {digests: Array(String)} - GROUP BY object_id, digest - """ - parameters = { - "project_id": project_id, - "object_ids": object_ids, - "digests": digests, - } - query_result = self._query_stream(query, parameters) + value_query, value_parameters = make_objects_val_query_and_parameters( + project_id=object_query_builder.project_id, + object_ids=list({row.object_id for row in metadata_result}), + digests=list({row.digest for row in metadata_result}), + ) + query_result = self._query_stream(value_query, value_parameters) # Map (object_id, digest) to val_dump object_values: dict[tuple[str, str], Any] = {} for row in query_result: @@ -1734,9 +1658,9 @@ def _select_objs_query( object_values[(object_id, digest)] = val_dump # update the val_dump for each object - for obj in result: + for obj in metadata_result: obj.val_dump = object_values.get((obj.object_id, obj.digest), "{}") - return result + return metadata_result def _run_migrations(self) -> None: logger.info("Running migrations") @@ -1931,6 +1855,15 @@ def _ensure_datetimes_have_tz( return dt +def _ensure_datetimes_have_tz_strict( + dt: datetime.datetime, +) -> datetime.datetime: + res = _ensure_datetimes_have_tz(dt) + if res is None: + raise ValueError(f"Datetime is None: {dt}") + return res + + def _nullable_dict_dump_to_dict( val: Optional[str], ) -> Optional[dict[str, Any]]: diff --git a/weave/trace_server/errors.py b/weave/trace_server/errors.py index b2014fc1a48..06ffebe1fea 100644 --- a/weave/trace_server/errors.py +++ b/weave/trace_server/errors.py @@ -34,3 +34,15 @@ class MissingLLMApiKeyError(Error): def __init__(self, message: str, api_key_name: str): self.api_key_name = api_key_name super().__init__(message) + + +class NotFoundError(Error): + """Raised when a general not found error occurs.""" + + pass + + +class ObjectDeletedError(Error): + """Raised when an object has been deleted.""" + + pass diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 1df739adbcd..7718295d9d3 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -257,6 +257,10 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: obj.project_id = original_project_id return res + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + return self._ref_apply(self._internal_trace_server.obj_delete, req) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: req.table.project_id = self._idc.ext_to_int_project_id(req.table.project_id) return self._ref_apply(self._internal_trace_server.table_create, req) diff --git a/weave/trace_server/objects_query_builder.py b/weave/trace_server/objects_query_builder.py new file mode 100644 index 00000000000..d72ba5d5125 --- /dev/null +++ b/weave/trace_server/objects_query_builder.py @@ -0,0 +1,292 @@ +from collections.abc import Iterator +from typing import Any, Optional + +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.clickhouse_schema import SelectableCHObjSchema +from weave.trace_server.orm import combine_conditions +from weave.trace_server.trace_server_common import digest_is_version_like + +VALID_OBJECT_SORT_FIELDS = {"created_at", "object_id"} +VALID_SORT_DIRECTIONS = {"asc", "desc"} +OBJECT_METADATA_COLUMNS = [ + "project_id", + "object_id", + "created_at", + "refs", + "kind", + "base_object_class", + "digest", + "version_index", + "is_latest", + "deleted_at", + # columns not used in SelectableCHObjSchema: + "version_count", + "is_op", +] + + +def _make_optional_part(query_keyword: str, part: Optional[str]) -> str: + if part is None or part == "": + return "" + return f"{query_keyword} {part}" + + +def _make_limit_part(limit: Optional[int]) -> str: + if limit is None: + return "" + return _make_optional_part("LIMIT", str(limit)) + + +def _make_offset_part(offset: Optional[int]) -> str: + if offset is None: + return "" + return _make_optional_part("OFFSET", str(offset)) + + +def _make_sort_part(sort_by: Optional[list[tsi.SortBy]]) -> str: + if not sort_by: + return "ORDER BY created_at ASC" + + sort_clauses = [] + for sort in sort_by: + if ( + sort.field in VALID_OBJECT_SORT_FIELDS + and sort.direction in VALID_SORT_DIRECTIONS + ): + sort_clause = f"{sort.field} {sort.direction.upper()}" + sort_clauses.append(sort_clause) + return _make_optional_part("ORDER BY", ", ".join(sort_clauses)) + + +def _make_conditions_part(conditions: Optional[list[str]]) -> str: + if not conditions: + return "" + conditions_str = combine_conditions(conditions, "AND") + return _make_optional_part("WHERE", conditions_str) + + +def _make_object_id_conditions_part(object_id_conditions: Optional[list[str]]) -> str: + if not object_id_conditions: + return "" + conditions_str = combine_conditions(object_id_conditions, "AND") + return " " + _make_optional_part("AND", conditions_str) + + +def format_metadata_objects_from_query_result( + query_result: Iterator[tuple[Any, ...]], +) -> list[SelectableCHObjSchema]: + result = [] + for row in query_result: + # Add an empty val_dump to the end of the row + row_with_val_dump = row + ("{}",) + columns_with_val_dump = OBJECT_METADATA_COLUMNS + ["val_dump"] + row_dict = dict(zip(columns_with_val_dump, row_with_val_dump)) + row_model = SelectableCHObjSchema.model_validate(row_dict) + result.append(row_model) + return result + + +class ObjectMetadataQueryBuilder: + def __init__( + self, + project_id: str, + conditions: Optional[list[str]] = None, + object_id_conditions: Optional[list[str]] = None, + parameters: Optional[dict[str, Any]] = None, + include_deleted: bool = False, + ): + self.project_id = project_id + self.parameters: dict[str, Any] = parameters or {} + if not self.parameters.get(project_id): + self.parameters.update({"project_id": project_id}) + self._conditions: list[str] = conditions or [] + self._object_id_conditions: list[str] = object_id_conditions or [] + self._limit: Optional[int] = None + self._offset: Optional[int] = None + self._sort_by: list[tsi.SortBy] = [] + self._include_deleted: bool = include_deleted + + @property + def conditions_part(self) -> str: + _conditions = list(self._conditions) + if not self._include_deleted: + _conditions.append("deleted_at IS NULL") + return _make_conditions_part(_conditions) + + @property + def object_id_conditions_part(self) -> str: + return _make_object_id_conditions_part(self._object_id_conditions) + + @property + def sort_part(self) -> str: + return _make_sort_part(self._sort_by) + + @property + def limit_part(self) -> str: + return _make_limit_part(self._limit) + + @property + def offset_part(self) -> str: + return _make_offset_part(self._offset) + + def _make_digest_condition( + self, digest: str, param_key: Optional[str] = None + ) -> str: + if digest == "latest": + return "is_latest = 1" + + param_key = param_key or "version_digest" + (is_version, version_index) = digest_is_version_like(digest) + if is_version: + self.parameters.update({param_key: version_index}) + return self._make_version_index_condition(version_index, param_key) + else: + self.parameters.update({param_key: digest}) + return self._make_version_digest_condition(digest, param_key) + + def _make_version_digest_condition(self, digest: str, param_key: str) -> str: + return f"digest = {{{param_key}: String}}" + + def _make_version_index_condition(self, version_index: int, param_key: str) -> str: + return f"version_index = {{{param_key}: Int64}}" + + def add_digests_conditions(self, digests: list[str]) -> None: + digest_conditions = [] + for i, digest in enumerate(digests): + condition = self._make_digest_condition(digest, f"version_digest_{i}") + digest_conditions.append(condition) + + digests_condition = combine_conditions(digest_conditions, "OR") + self._conditions.append(digests_condition) + + def add_object_ids_condition( + self, object_ids: list[str], param_key: Optional[str] = None + ) -> None: + if len(object_ids) == 1: + param_key = param_key or "object_id" + self._object_id_conditions.append(f"object_id = {{{param_key}: String}}") + self.parameters.update({param_key: object_ids[0]}) + else: + param_key = param_key or "object_ids" + self._object_id_conditions.append( + f"object_id IN {{{param_key}: Array(String)}}" + ) + self.parameters.update({param_key: object_ids}) + + def add_is_latest_condition(self) -> None: + self._conditions.append("is_latest = 1") + + def add_is_op_condition(self, is_op: bool) -> None: + if is_op: + self._conditions.append("is_op = 1") + else: + self._conditions.append("is_op = 0") + + def add_base_object_classes_condition(self, base_object_classes: list[str]) -> None: + self._conditions.append( + "base_object_class IN {base_object_classes: Array(String)}" + ) + self.parameters.update({"base_object_classes": base_object_classes}) + + def add_order(self, field: str, direction: str) -> None: + direction = direction.lower() + if direction not in ("asc", "desc"): + raise ValueError(f"Direction {direction} is not allowed") + self._sort_by.append(tsi.SortBy(field=field, direction=direction)) + + def set_limit(self, limit: int) -> None: + if limit < 0: + raise ValueError("Limit must be a positive integer") + if self._limit is not None: + raise ValueError("Limit can only be set once") + self._limit = limit + + def set_offset(self, offset: int) -> None: + if offset < 0: + raise ValueError("Offset must be a positive integer") + if self._offset is not None: + raise ValueError("Offset can only be set once") + self._offset = offset + + def set_include_deleted(self, include_deleted: bool) -> None: + self._include_deleted = include_deleted + + def make_metadata_query(self) -> str: + columns = ",\n ".join(OBJECT_METADATA_COLUMNS) + query = f""" +SELECT + {columns} +FROM ( + SELECT + project_id, + object_id, + created_at, + deleted_at, + kind, + base_object_class, + refs, + digest, + is_op, + row_number() OVER ( + PARTITION BY project_id, + kind, + object_id + ORDER BY created_at ASC + ) - 1 AS version_index, + count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, + row_number() OVER ( + PARTITION BY project_id, kind, object_id + ORDER BY (deleted_at IS NULL) DESC, created_at DESC + ) AS row_num, + if (row_num = 1, 1, 0) AS is_latest + FROM ( + SELECT + project_id, + object_id, + created_at, + deleted_at, + kind, + base_object_class, + refs, + digest, + if (kind = 'op', 1, 0) AS is_op, + row_number() OVER ( + PARTITION BY project_id, + kind, + object_id, + digest + ORDER BY created_at ASC + ) AS rn + FROM object_versions + WHERE project_id = {{project_id: String}}{self.object_id_conditions_part} + ) + WHERE rn = 1 +)""" + if self.conditions_part: + query += f"\n{self.conditions_part}" + if self.sort_part: + query += f"\n{self.sort_part}" + if self.limit_part: + query += f"\n{self.limit_part}" + if self.offset_part: + query += f"\n{self.offset_part}" + return query + + +def make_objects_val_query_and_parameters( + project_id: str, object_ids: list[str], digests: list[str] +) -> tuple[str, dict[str, Any]]: + query = """ + SELECT object_id, digest, any(val_dump) + FROM object_versions + WHERE project_id = {project_id: String} AND + object_id IN {object_ids: Array(String)} AND + digest IN {digests: Array(String)} + GROUP BY object_id, digest + """ + parameters = { + "project_id": project_id, + "object_ids": object_ids, + "digests": digests, + } + return query, parameters diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 41d0e581390..8b4330afcae 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -15,7 +15,11 @@ from weave.trace_server import refs_internal as ri from weave.trace_server import trace_server_interface as tsi from weave.trace_server.emoji_util import detone_emojis -from weave.trace_server.errors import InvalidRequest +from weave.trace_server.errors import ( + InvalidRequest, + NotFoundError, + ObjectDeletedError, +) from weave.trace_server.feedback import ( TABLE_FEEDBACK, validate_feedback_create_req, @@ -47,10 +51,6 @@ MAX_FLUSH_AGE = 15 -class NotFoundError(Exception): - pass - - _conn_cursor: contextvars.ContextVar[ Optional[tuple[sqlite3.Connection, sqlite3.Cursor]] ] = contextvars.ContextVar("conn_cursor", default=None) @@ -617,25 +617,18 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: processed_val = processed_result["val"] json_val = json.dumps(processed_val) digest = str_digest(json_val) + project_id, object_id = req.obj.project_id, req.obj.object_id # Validate - object_id_validator(req.obj.object_id) + object_id_validator(object_id) - # TODO: version index isn't right here, what if we delete stuff? with self.lock: - cursor.execute("BEGIN TRANSACTION") - # Mark all existing objects with such id as not latest - cursor.execute( - """UPDATE objects SET is_latest = 0 WHERE project_id = ? AND object_id = ?""", - (req.obj.project_id, req.obj.object_id), - ) - # first get version count - cursor.execute( - """SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ?""", - (req.obj.project_id, req.obj.object_id), - ) - version_index = cursor.fetchone()[0] + if self._obj_exists(cursor, project_id, object_id, digest): + return tsi.ObjCreateRes(digest=digest) + cursor.execute("BEGIN TRANSACTION") + self._mark_existing_objects_as_not_latest(cursor, project_id, object_id) + version_index = self._get_obj_version_index(cursor, project_id, object_id) cursor.execute( """INSERT OR IGNORE INTO objects ( project_id, @@ -647,11 +640,22 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: val_dump, digest, version_index, - is_latest - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + is_latest, + deleted_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(project_id, kind, object_id, digest) DO UPDATE SET + created_at = excluded.created_at, + kind = excluded.kind, + base_object_class = excluded.base_object_class, + refs = excluded.refs, + val_dump = excluded.val_dump, + version_index = excluded.version_index, + is_latest = excluded.is_latest, + deleted_at = excluded.deleted_at + """, ( - req.obj.project_id, - req.obj.object_id, + project_id, + object_id, datetime.datetime.now().isoformat(), get_kind(processed_val), processed_result["base_object_class"], @@ -660,11 +664,48 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: digest, version_index, 1, + None, ), ) conn.commit() return tsi.ObjCreateRes(digest=digest) + def _obj_exists( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str, digest: str + ) -> bool: + cursor.execute( + "SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ? AND digest = ? AND deleted_at IS NULL", + (project_id, object_id, digest), + ) + return_row = cursor.fetchone() + if return_row is None: + return False + return return_row[0] > 0 + + def _mark_existing_objects_as_not_latest( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str + ) -> None: + """Mark all existing objects with such id as not latest. + We are creating a new object with the same id, all existing ones are no longer latest. + """ + cursor.execute( + "UPDATE objects SET is_latest = 0 WHERE project_id = ? AND object_id = ?", + (project_id, object_id), + ) + + def _get_obj_version_index( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str + ) -> int: + """Get the version index for a new object with such id.""" + cursor.execute( + "SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ?", + (project_id, object_id), + ) + return_row = cursor.fetchone() + if return_row is None: + return 0 + return return_row[0] + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: conds = [f"object_id = '{req.object_id}'"] if req.digest == "latest": @@ -678,10 +719,14 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: objs = self._select_objs_query( req.project_id, conditions=conds, + include_deleted=True, ) if len(objs) == 0: raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found") - + if objs[0].deleted_at is not None: + raise ObjectDeletedError( + f"Obj {req.object_id}:v{objs[0].version_index} was deleted at {objs[0].deleted_at}" + ) return tsi.ObjReadRes(obj=objs[0]) def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: @@ -716,6 +761,53 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: return tsi.ObjQueryRes(objs=objs) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + # First, select the objects that match the query + select_query = """ + SELECT digest FROM objects + WHERE project_id = ? AND + object_id = ? AND + deleted_at IS NULL + """ + parameters = [req.project_id, req.object_id] + if req.digests: + num_digests = len(req.digests) + select_query += "AND digest IN ({})".format(", ".join("?" * num_digests)) + + parameters.extend(req.digests) + + conn, cursor = get_conn_cursor(self.db_path) + cursor.execute(select_query, parameters) + matching_objects = cursor.fetchall() + + if len(matching_objects) == 0: + raise NotFoundError( + f"Object {req.object_id} ({req.digests}) not found when deleting." + ) + found_digests = {obj[0] for obj in matching_objects} + if req.digests: + given_digests = set(req.digests) + if len(given_digests) != len(found_digests): + raise NotFoundError( + f"Delete request contains {len(req.digests)} digests, but found {len(found_digests)} objects to delete. Diff digests: {given_digests - found_digests}" + ) + + # Create a delete query that will set the deleted_at field to now + delete_query = """ + UPDATE objects SET deleted_at = CURRENT_TIMESTAMP + WHERE project_id = ? AND + object_id = ? AND + digest IN ({}) + """.format(", ".join("?" * len(found_digests))) + delete_parameters = [req.project_id, req.object_id] + list(found_digests) + + with self.lock: + cursor.execute("BEGIN TRANSACTION") + cursor.execute(delete_query, delete_parameters) + conn.commit() + + return tsi.ObjDeleteRes(num_deleted=len(matching_objects)) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: conn, cursor = get_conn_cursor(self.db_path) insert_rows = [] @@ -1151,11 +1243,15 @@ def _select_objs_query( parameters: Optional[dict[str, Any]] = None, metadata_only: Optional[bool] = False, limit: Optional[int] = None, + include_deleted: bool = False, offset: Optional[int] = None, sort_by: Optional[list[tsi.SortBy]] = None, ) -> list[tsi.ObjSchema]: conn, cursor = get_conn_cursor(self.db_path) - pred = " AND ".join(conditions or ["1 = 1"]) + conditions = conditions or ["1 = 1"] + if not include_deleted: + conditions.append("deleted_at IS NULL") + pred = " AND ".join(conditions) val_dump_part = "'{}' as val_dump" if metadata_only else "val_dump" query = f""" SELECT @@ -1167,10 +1263,10 @@ def _select_objs_query( {val_dump_part}, digest, version_index, - is_latest + is_latest, + deleted_at FROM objects - WHERE deleted_at IS NULL AND - project_id = ? AND {pred} + WHERE project_id = ? AND {pred} """ if sort_by: @@ -1212,6 +1308,7 @@ def _select_objs_query( digest=row[6], version_index=row[7], is_latest=row[8], + deleted_at=row[9], ) ) return result diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index c419a9e23e5..699e1e128c6 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -476,6 +476,19 @@ class ObjQueryReq(BaseModel): ) +class ObjDeleteReq(BaseModel): + project_id: str + object_id: str + digests: Optional[list[str]] = Field( + default=None, + description="List of digests to delete. If not provided, all digests for the object will be deleted.", + ) + + +class ObjDeleteRes(BaseModel): + num_deleted: int + + class ObjQueryRes(BaseModel): objs: list[ObjSchema] @@ -906,6 +919,7 @@ def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes: ... def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ... def obj_read(self, req: ObjReadReq) -> ObjReadRes: ... def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ... + def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes: ... def table_create(self, req: TableCreateReq) -> TableCreateRes: ... def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ... def table_query(self, req: TableQueryReq) -> TableQueryRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index b749af60d24..966bf92a815 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -352,6 +352,11 @@ def objs_query( "/objs/query", req, tsi.ObjQueryReq, tsi.ObjQueryRes ) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + return self._generic_request( + "/obj/delete", req, tsi.ObjDeleteReq, tsi.ObjDeleteRes + ) + def table_create( self, req: Union[tsi.TableCreateReq, dict[str, Any]] ) -> tsi.TableCreateRes: From 9b562560cc9356df3541096357c3ab765f981317 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 14:31:41 -0800 Subject: [PATCH 2/5] fix tests --- tests/trace/test_objects_query_builder.py | 16 +++++------ weave/trace_server/sqlite_trace_server.py | 34 +++++++++++++++-------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/tests/trace/test_objects_query_builder.py b/tests/trace/test_objects_query_builder.py index 70e4e40e536..b074fa00e08 100644 --- a/tests/trace/test_objects_query_builder.py +++ b/tests/trace/test_objects_query_builder.py @@ -68,16 +68,16 @@ def test_object_query_builder_basic(): assert builder.parameters["project_id"] == "test_project" -def test_object_query_builder_add_digest_condition(): +def test_object_query_builder_add_digests_condition(): builder = ObjectMetadataQueryBuilder(project_id="test_project") # Test latest digest - builder.add_digest_condition("latest") + builder.add_digests_condition(["latest"]) assert "is_latest = 1" in builder.conditions_part # Test specific digest builder = ObjectMetadataQueryBuilder(project_id="test_project") - builder.add_digest_condition("abc123") + builder.add_digests_condition(["abc123"]) assert "digest = {version_digest: String}" in builder.conditions_part assert builder.parameters["version_digest"] == "abc123" @@ -195,7 +195,7 @@ def test_object_query_builder_sort(): def test_object_query_builder_metadata_query_basic(): builder = ObjectMetadataQueryBuilder(project_id="test_project") - builder.add_digest_condition("latest") + builder.add_digests_condition(["latest"]) query = builder.make_metadata_query() parameters = builder.parameters @@ -221,7 +221,7 @@ def test_object_query_builder_metadata_query_with_limit_offset_sort(): builder.set_offset(offset) builder.add_order("created_at", "desc") builder.add_object_ids_condition(["object_1"]) - builder.add_digest_condition("digestttttttttttttttt") + builder.add_digests_condition(["digestttttttttttttttt"]) builder.add_base_object_classes_condition(["Model", "Model2"]) query = builder.make_metadata_query() @@ -250,7 +250,7 @@ def test_objects_query_metadata_op(): builder = ObjectMetadataQueryBuilder(project_id="test_project") builder.add_is_op_condition(True) builder.add_object_ids_condition(["my_op"]) - builder.add_digest_condition("v3", "vvvvvversion") + builder.add_digests_condition(["v3"]) query = builder.make_metadata_query() parameters = builder.parameters @@ -260,13 +260,13 @@ def test_objects_query_metadata_op(): ) WHERE rn = 1 ) -WHERE ((is_op = 1) AND (version_index = {{vvvvvversion: Int64}}))""" +WHERE ((is_op = 1) AND (version_index = {{version_index_0: Int64}}))""" assert query == expected_query assert parameters == { "project_id": "test_project", "object_id": "my_op", - "vvvvvversion": 3, + "version_index_0": 3, } diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 8b4330afcae..be1ea26bb2f 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -706,16 +706,21 @@ def _get_obj_version_index( return 0 return return_row[0] - def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - conds = [f"object_id = '{req.object_id}'"] - if req.digest == "latest": - conds.append("is_latest = 1") + @staticmethod + def _make_digest_condition(digest: str) -> str: + if digest == "latest": + return "is_latest = 1" else: - (is_version, version_index) = digest_is_version_like(req.digest) + (is_version, version_index) = digest_is_version_like(digest) if is_version: - conds.append(f"version_index = '{version_index}'") + return f"version_index = '{version_index}'" else: - conds.append(f"digest = '{req.digest}'") + return f"digest = '{digest}'" + + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: + conds = [f"object_id = '{req.object_id}'"] + digest_condition = self._make_digest_condition(req.digest) + conds.append(digest_condition) objs = self._select_objs_query( req.project_id, conditions=conds, @@ -762,6 +767,12 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: return tsi.ObjQueryRes(objs=objs) def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + MAX_OBJECTS_TO_DELETE = 100 + if req.digests and len(req.digests) > MAX_OBJECTS_TO_DELETE: + raise ValueError( + f"Object delete request contains {len(req.digests)} objects. Please delete {MAX_OBJECTS_TO_DELETE} or fewer objects at a time." + ) + # First, select the objects that match the query select_query = """ SELECT digest FROM objects @@ -771,10 +782,11 @@ def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: """ parameters = [req.project_id, req.object_id] if req.digests: - num_digests = len(req.digests) - select_query += "AND digest IN ({})".format(", ".join("?" * num_digests)) - - parameters.extend(req.digests) + digest_conditions = [ + self._make_digest_condition(digest) for digest in req.digests + ] + digest_conditions_str = " AND ".join(digest_conditions) + select_query += f"AND ({digest_conditions_str})" conn, cursor = get_conn_cursor(self.db_path) cursor.execute(select_query, parameters) From 93ed6c1f9ee5045cab5ef9fae8a75c9aeae793e8 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 14:41:37 -0800 Subject: [PATCH 3/5] fix fix test --- tests/trace/test_objects_query_builder.py | 28 +++++++++++---------- weave/trace_server/objects_query_builder.py | 6 +++-- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/trace/test_objects_query_builder.py b/tests/trace/test_objects_query_builder.py index b074fa00e08..28dbe9aac21 100644 --- a/tests/trace/test_objects_query_builder.py +++ b/tests/trace/test_objects_query_builder.py @@ -68,18 +68,18 @@ def test_object_query_builder_basic(): assert builder.parameters["project_id"] == "test_project" -def test_object_query_builder_add_digests_condition(): +def test_object_query_builder_add_digests_conditions(): builder = ObjectMetadataQueryBuilder(project_id="test_project") # Test latest digest - builder.add_digests_condition(["latest"]) + builder.add_digests_conditions(["latest"]) assert "is_latest = 1" in builder.conditions_part # Test specific digest builder = ObjectMetadataQueryBuilder(project_id="test_project") - builder.add_digests_condition(["abc123"]) - assert "digest = {version_digest: String}" in builder.conditions_part - assert builder.parameters["version_digest"] == "abc123" + builder.add_digests_conditions(["abc123"]) + assert "digest = {version_0: String}" in builder.conditions_part + assert builder.parameters["version_0"] == "abc123" def test_object_query_builder_add_object_ids_condition(): @@ -195,7 +195,7 @@ def test_object_query_builder_sort(): def test_object_query_builder_metadata_query_basic(): builder = ObjectMetadataQueryBuilder(project_id="test_project") - builder.add_digests_condition(["latest"]) + builder.add_digests_conditions(["latest"]) query = builder.make_metadata_query() parameters = builder.parameters @@ -205,7 +205,8 @@ def test_object_query_builder_metadata_query_basic(): ) WHERE rn = 1 ) -WHERE is_latest = 1""" +WHERE ((is_latest = 1) AND (deleted_at IS NULL)) +ORDER BY created_at ASC""" assert query == expected_query assert parameters == {"project_id": "test_project"} @@ -221,7 +222,7 @@ def test_object_query_builder_metadata_query_with_limit_offset_sort(): builder.set_offset(offset) builder.add_order("created_at", "desc") builder.add_object_ids_condition(["object_1"]) - builder.add_digests_condition(["digestttttttttttttttt"]) + builder.add_digests_conditions(["digestttttttttttttttt"]) builder.add_base_object_classes_condition(["Model", "Model2"]) query = builder.make_metadata_query() @@ -232,7 +233,7 @@ def test_object_query_builder_metadata_query_with_limit_offset_sort(): ) WHERE rn = 1 ) -WHERE ((digest = {{version_digest: String}}) AND (base_object_class IN {{base_object_classes: Array(String)}})) +WHERE ((digest = {{version_0: String}}) AND (base_object_class IN {{base_object_classes: Array(String)}}) AND (deleted_at IS NULL)) ORDER BY created_at DESC LIMIT 10 OFFSET 5""" @@ -241,7 +242,7 @@ def test_object_query_builder_metadata_query_with_limit_offset_sort(): assert parameters == { "project_id": "test_project", "object_id": "object_1", - "version_digest": "digestttttttttttttttt", + "version_0": "digestttttttttttttttt", "base_object_classes": ["Model", "Model2"], } @@ -250,7 +251,7 @@ def test_objects_query_metadata_op(): builder = ObjectMetadataQueryBuilder(project_id="test_project") builder.add_is_op_condition(True) builder.add_object_ids_condition(["my_op"]) - builder.add_digests_condition(["v3"]) + builder.add_digests_conditions(["v3"]) query = builder.make_metadata_query() parameters = builder.parameters @@ -260,13 +261,14 @@ def test_objects_query_metadata_op(): ) WHERE rn = 1 ) -WHERE ((is_op = 1) AND (version_index = {{version_index_0: Int64}}))""" +WHERE ((is_op = 1) AND (version_index = {{version_0: Int64}}) AND (deleted_at IS NULL)) +ORDER BY created_at ASC""" assert query == expected_query assert parameters == { "project_id": "test_project", "object_id": "my_op", - "version_index_0": 3, + "version_0": 3, } diff --git a/weave/trace_server/objects_query_builder.py b/weave/trace_server/objects_query_builder.py index d72ba5d5125..8c56c562d29 100644 --- a/weave/trace_server/objects_query_builder.py +++ b/weave/trace_server/objects_query_builder.py @@ -45,7 +45,7 @@ def _make_offset_part(offset: Optional[int]) -> str: def _make_sort_part(sort_by: Optional[list[tsi.SortBy]]) -> str: if not sort_by: - return "ORDER BY created_at ASC" + return "" sort_clauses = [] for sort in sort_by: @@ -119,6 +119,8 @@ def object_id_conditions_part(self) -> str: @property def sort_part(self) -> str: + if not self._sort_by: + return "ORDER BY created_at ASC" return _make_sort_part(self._sort_by) @property @@ -153,7 +155,7 @@ def _make_version_index_condition(self, version_index: int, param_key: str) -> s def add_digests_conditions(self, digests: list[str]) -> None: digest_conditions = [] for i, digest in enumerate(digests): - condition = self._make_digest_condition(digest, f"version_digest_{i}") + condition = self._make_digest_condition(digest, f"version_{i}") digest_conditions.append(condition) digests_condition = combine_conditions(digest_conditions, "OR") From dce16e2fd0a985742f3f524e84f0c8a50e8841ac Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 15:37:29 -0800 Subject: [PATCH 4/5] fixtest --- weave/trace_server/sqlite_trace_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index be1ea26bb2f..ed0a5accb0c 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -785,7 +785,7 @@ def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: digest_conditions = [ self._make_digest_condition(digest) for digest in req.digests ] - digest_conditions_str = " AND ".join(digest_conditions) + digest_conditions_str = " OR ".join(digest_conditions) select_query += f"AND ({digest_conditions_str})" conn, cursor = get_conn_cursor(self.db_path) From dc2f278c6dd8fc312b0e2be2705ef207b1d71835 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Tue, 17 Dec 2024 08:34:56 -0800 Subject: [PATCH 5/5] cosmetic query change --- weave/trace_server/sqlite_trace_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index ed0a5accb0c..e65316febc5 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1260,9 +1260,11 @@ def _select_objs_query( sort_by: Optional[list[tsi.SortBy]] = None, ) -> list[tsi.ObjSchema]: conn, cursor = get_conn_cursor(self.db_path) - conditions = conditions or ["1 = 1"] + conditions = conditions or [] if not include_deleted: conditions.append("deleted_at IS NULL") + if not conditions: + conditions.append("1 = 1") pred = " AND ".join(conditions) val_dump_part = "'{}' as val_dump" if metadata_only else "val_dump" query = f"""