diff --git a/tests/trace/test_obj_delete.py b/tests/trace/test_obj_delete.py new file mode 100644 index 00000000000..96704b8cbd9 --- /dev/null +++ b/tests/trace/test_obj_delete.py @@ -0,0 +1,107 @@ +import weave +from weave.trace.weave_client import WeaveClient +from weave.trace_server import trace_server_interface as tsi + + +def _objs_query(client: WeaveClient, object_id: str) -> list[tsi.ObjSchema]: + objs = client.server.objs_query( + tsi.ObjQueryReq( + project_id=client._project_id(), + filter=tsi.ObjectVersionFilter(object_ids=[object_id]), + ) + ) + return objs.objs + + +def _obj_delete(client: WeaveClient, object_id: str, digests: list[str]) -> int: + return client.server.obj_delete( + tsi.ObjDeleteReq( + project_id=client._project_id(), + object_id=object_id, + digests=digests, + ) + ).num_deleted + + +def test_delete_object_versions(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + v1 = weave.publish({"i": 2}, name="obj_1") + v2 = weave.publish({"i": 3}, name="obj_1") + + objs = _objs_query(client, "obj_1") + assert len(objs) == 3 + + num_deleted = _obj_delete(client, "obj_1", [v0.digest]) + assert num_deleted == 1 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + + # test deleting an already deleted digest + num_deleted = _obj_delete(client, "obj_1", [v0.digest]) + assert num_deleted == 0 + + # test deleting a non-existent digest + num_deleted = _obj_delete(client, "obj_1", ["non-existent-digest"]) + assert num_deleted == 0 + + # test deleting multiple digests + digests = [v1.digest, v2.digest] + num_deleted = _obj_delete(client, "obj_1", digests) + assert num_deleted == 2 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 0 + + +def test_delete_all_object_versions(client: WeaveClient): + weave.publish({"i": 1}, name="obj_1") + weave.publish({"i": 2}, name="obj_1") + weave.publish({"i": 3}, name="obj_1") + + num_deleted = _obj_delete(client, "obj_1", None) + assert num_deleted == 3 + + objs = _objs_query(client, "obj_1") + assert len(objs) == 0 + + num_deleted = _obj_delete(client, "obj_1", None) + assert num_deleted == 0 + + +def test_delete_version_correctness(client: WeaveClient): + v0 = weave.publish({"i": 1}, name="obj_1") + v1 = weave.publish({"i": 2}, name="obj_1") + v2 = weave.publish({"i": 3}, name="obj_1") + + _obj_delete(client, "obj_1", [v1.digest]) + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 + + v3 = weave.publish({"i": 4}, name="obj_1") + assert len(objs) == 3 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 + assert objs[2].digest == v3.digest + assert objs[2].val == {"i": 4} + assert objs[2].version_index == 3 + + _obj_delete(client, "obj_1", [v3.digest]) + objs = _objs_query(client, "obj_1") + assert len(objs) == 2 + assert objs[0].digest == v0.digest + assert objs[0].val == {"i": 1} + assert objs[0].version_index == 0 + assert objs[1].digest == v2.digest + assert objs[1].val == {"i": 3} + assert objs[1].version_index == 2 diff --git a/weave/trace_server/clickhouse_schema.py b/weave/trace_server/clickhouse_schema.py index c02b5f7b332..d82f2570b69 100644 --- a/weave/trace_server/clickhouse_schema.py +++ b/weave/trace_server/clickhouse_schema.py @@ -136,6 +136,11 @@ class ObjCHInsertable(BaseModel): _refs = field_validator("refs")(validation.refs_list_validator) +class ObjDeleteCHInsertable(ObjCHInsertable): + deleted_at: datetime.datetime + created_at: datetime.datetime + + class SelectableCHObjSchema(BaseModel): project_id: str object_id: str diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index ce31704caaf..4edc09f67c7 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -60,6 +60,7 @@ CallStartCHInsertable, CallUpdateCHInsertable, ObjCHInsertable, + ObjDeleteCHInsertable, SelectableCHCallSchema, SelectableCHObjSchema, ) @@ -69,6 +70,7 @@ InsertTooLarge, InvalidRequest, MissingLLMApiKeyError, + ObjectDeletedError, RequestTooLarge, ) from weave.trace_server.feedback import ( @@ -533,12 +535,19 @@ def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: object_query_builder.add_is_op_condition(True) object_query_builder.add_digest_condition(req.digest) object_query_builder.add_object_ids_condition([req.name], "op_name") + object_query_builder.set_deleted_at_condition(include_deleted=True) objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.name}:{req.digest} not found") - return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(objs[0])) + op = objs[0] + if op.deleted_at is not None: + raise ObjectDeletedError( + f"Op {req.name}:v{op.version_index} was deleted at {op.deleted_at}" + ) + + return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(op)) def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes: object_query_builder = ObjectQueryBuilder(req.project_id) @@ -584,12 +593,19 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: object_query_builder = ObjectQueryBuilder(req.project_id) object_query_builder.add_digest_condition(req.digest) object_query_builder.add_object_ids_condition([req.object_id]) + object_query_builder.set_deleted_at_condition(include_deleted=True) objs = self._select_objs_query(object_query_builder) if len(objs) == 0: raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found") - return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(objs[0])) + obj = objs[0] + if obj.deleted_at is not None: + raise ObjectDeletedError( + f"Obj {req.object_id}:v{obj.version_index} was deleted at {obj.deleted_at}" + ) + + return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(obj)) def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: object_query_builder = ObjectQueryBuilder(req.project_id) @@ -623,6 +639,66 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: return tsi.ObjQueryRes(objs=[_ch_obj_to_obj_schema(obj) for obj in objs]) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + """ + Delete object versions by digest, belonging to given object_id. + All deletion in this method is "soft". Deletion occurs by inserting + a new row into the object_versions table with the deleted_at field set. + Inserted rows share identical primary keys (order by) with original rows, + and will be combined by the ReplacingMergeTree engine at database merge + time. + If no digests are provided all versions will have their data overwritten with + an empty val_dump. + """ + MAX_OBJECTS_TO_DELETE = 100 + if req.digests and len(req.digests) > MAX_OBJECTS_TO_DELETE: + raise ValueError( + f"Object delete request contains {len(req.digests)} objects. Please delete fewer than {MAX_OBJECTS_TO_DELETE} objects at a time." + ) + + object_query_builder = ObjectQueryBuilder(req.project_id) + object_query_builder.add_object_ids_condition([req.object_id]) + if req.digests: + for i, digest in enumerate(req.digests): + object_query_builder._add_version_digest_condition( + digest, f"digest_{i}" + ) + metadata_only = req.digests is not None and len(req.digests) > 0 + object_query_builder.set_metadata_only(metadata_only) + + object_versions = self._select_objs_query(object_query_builder) + + delete_insertables = [] + now = datetime.datetime.now(datetime.timezone.utc) + for obj in object_versions: + original_created_at = _ensure_datetimes_have_tz_strict(obj.created_at) + delete_insertables.append( + ObjDeleteCHInsertable( + project_id=obj.project_id, + object_id=obj.object_id, + digest=obj.digest, + kind=obj.kind, + val_dump=obj.val_dump, + refs=obj.refs, + base_object_class=obj.base_object_class, + deleted_at=now, + created_at=original_created_at, + ) + ) + + if len(delete_insertables) == 0: + raise NotFoundError( + f"Object {req.object_id} ({req.digests}) not found when deleting." + ) + + data = [list(obj.model_dump().values()) for obj in delete_insertables] + column_names = list(delete_insertables[0].model_fields.keys()) + self._insert("object_versions", data=data, column_names=column_names) + + num_deleted = len(delete_insertables) + + return tsi.ObjDeleteRes(num_deleted=num_deleted) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: insert_rows = [] for r in req.table.rows: @@ -1563,7 +1639,7 @@ def _select_objs_query( metadata_result = format_metadata_objects_from_query_result(query_result) # -- Don't make second query for object values if metadata_only -- - if object_query_builder.metadata_only: + if object_query_builder.metadata_only or len(metadata_result) == 0: return metadata_result value_query, value_parameters = make_objects_val_query_and_parameters( @@ -1776,6 +1852,15 @@ def _ensure_datetimes_have_tz( return dt +def _ensure_datetimes_have_tz_strict( + dt: datetime.datetime, +) -> datetime.datetime: + res = _ensure_datetimes_have_tz(dt) + if res is None: + raise ValueError(f"Datetime is None: {dt}") + return res + + def _nullable_dict_dump_to_dict( val: Optional[str], ) -> Optional[dict[str, Any]]: diff --git a/weave/trace_server/errors.py b/weave/trace_server/errors.py index b2014fc1a48..fd9c34f34a5 100644 --- a/weave/trace_server/errors.py +++ b/weave/trace_server/errors.py @@ -28,6 +28,18 @@ class InvalidFieldError(Error): pass +class NotFoundError(Error): + """Raised when a general not found error occurs.""" + + pass + + +class ObjectDeletedError(Error): + """Raised when an object has been deleted.""" + + pass + + class MissingLLMApiKeyError(Error): """Raised when a LLM API key is missing for completion.""" diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 1df739adbcd..7718295d9d3 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -257,6 +257,10 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: obj.project_id = original_project_id return res + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + return self._ref_apply(self._internal_trace_server.obj_delete, req) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: req.table.project_id = self._idc.ext_to_int_project_id(req.table.project_id) return self._ref_apply(self._internal_trace_server.table_create, req) diff --git a/weave/trace_server/objects_query_builder.py b/weave/trace_server/objects_query_builder.py index df6ed522cc7..6b354625750 100644 --- a/weave/trace_server/objects_query_builder.py +++ b/weave/trace_server/objects_query_builder.py @@ -204,6 +204,11 @@ def set_offset(self, offset: int) -> None: def set_metadata_only(self, metadata_only: bool) -> None: self.metadata_only = metadata_only + def set_deleted_at_condition(self, include_deleted: bool) -> None: + if include_deleted: + return + self._conditions.append("deleted_at IS NULL") + def make_metadata_query(self) -> str: return f""" SELECT @@ -217,7 +222,8 @@ def make_metadata_query(self) -> str: is_op, version_index, version_count, - is_latest + is_latest, + deleted_at FROM ( SELECT project_id, object_id, @@ -227,6 +233,7 @@ def make_metadata_query(self) -> str: refs, digest, is_op, + deleted_at, row_number() OVER ( PARTITION BY project_id, kind, @@ -234,13 +241,19 @@ def make_metadata_query(self) -> str: ORDER BY created_at ASC ) - 1 AS version_index, count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count, - if(version_index + 1 = version_count, 1, 0) AS is_latest + row_number() OVER ( + PARTITION BY project_id, kind, object_id + ORDER BY (deleted_at IS NULL) DESC, created_at DESC + ) AS row_num, + if (row_num = 1, 1, 0) AS is_latest FROM ( SELECT project_id, object_id, - created_at, + MIN(created_at) AS created_at, + MIN(deleted_at) AS deleted_at, kind, - base_object_class, + MIN(base_object_class) AS base_object_class, + MIN(refs) AS refs, refs, digest, if (kind = 'op', 1, 0) AS is_op, @@ -253,9 +266,9 @@ def make_metadata_query(self) -> str: ) AS rn FROM object_versions WHERE project_id = {{project_id: String}} - {self.object_id_conditions_part} + {self.object_id_conditions_part} + GROUP BY project_id, kind, object_id, digest ) - WHERE rn = 1 ) {self.conditions_part} {self.sort_part} diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index e14cef8041e..f09ead96a49 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -16,7 +16,11 @@ from weave.trace_server import trace_server_interface as tsi from weave.trace_server.base_object_class_util import process_incoming_object_val from weave.trace_server.emoji_util import detone_emojis -from weave.trace_server.errors import InvalidRequest +from weave.trace_server.errors import ( + InvalidRequest, + NotFoundError, + ObjectDeletedError, +) from weave.trace_server.feedback import ( TABLE_FEEDBACK, validate_feedback_create_req, @@ -47,10 +51,6 @@ MAX_FLUSH_AGE = 15 -class NotFoundError(Exception): - pass - - _conn_cursor: contextvars.ContextVar[ Optional[tuple[sqlite3.Connection, sqlite3.Cursor]] ] = contextvars.ContextVar("conn_cursor", default=None) @@ -616,25 +616,18 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: ) json_val = json.dumps(val) digest = str_digest(json_val) + project_id, object_id = req.obj.project_id, req.obj.object_id # Validate - object_id_validator(req.obj.object_id) + object_id_validator(object_id) - # TODO: version index isn't right here, what if we delete stuff? with self.lock: - cursor.execute("BEGIN TRANSACTION") - # Mark all existing objects with such id as not latest - cursor.execute( - """UPDATE objects SET is_latest = 0 WHERE project_id = ? AND object_id = ?""", - (req.obj.project_id, req.obj.object_id), - ) - # first get version count - cursor.execute( - """SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ?""", - (req.obj.project_id, req.obj.object_id), - ) - version_index = cursor.fetchone()[0] + if self._obj_exists(cursor, project_id, object_id, digest): + return tsi.ObjCreateRes(digest=digest) + cursor.execute("BEGIN TRANSACTION") + self._mark_existing_objects_as_not_latest(cursor, project_id, object_id) + version_index = self._get_obj_version_index(cursor, project_id, object_id) cursor.execute( """INSERT OR IGNORE INTO objects ( project_id, @@ -646,11 +639,22 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: val_dump, digest, version_index, - is_latest - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + is_latest, + deleted_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(project_id, kind, object_id, digest) DO UPDATE SET + created_at = excluded.created_at, + kind = excluded.kind, + base_object_class = excluded.base_object_class, + refs = excluded.refs, + val_dump = excluded.val_dump, + version_index = excluded.version_index, + is_latest = excluded.is_latest, + deleted_at = excluded.deleted_at + """, ( - req.obj.project_id, - req.obj.object_id, + project_id, + object_id, datetime.datetime.now().isoformat(), get_kind(val), base_object_class, @@ -659,11 +663,48 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: digest, version_index, 1, + None, ), ) conn.commit() return tsi.ObjCreateRes(digest=digest) + def _obj_exists( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str, digest: str + ) -> bool: + cursor.execute( + "SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ? AND digest = ? AND deleted_at IS NULL", + (project_id, object_id, digest), + ) + return_row = cursor.fetchone() + if return_row is None: + return False + return return_row[0] > 0 + + def _mark_existing_objects_as_not_latest( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str + ) -> None: + """Mark all existing objects with such id as not latest. + We are creating a new object with the same id, all existing ones are no longer latest. + """ + cursor.execute( + "UPDATE objects SET is_latest = 0 WHERE project_id = ? AND object_id = ?", + (project_id, object_id), + ) + + def _get_obj_version_index( + self, cursor: sqlite3.Cursor, project_id: str, object_id: str + ) -> int: + """Get the version index for a new object with such id.""" + cursor.execute( + "SELECT COUNT(*) FROM objects WHERE project_id = ? AND object_id = ?", + (project_id, object_id), + ) + return_row = cursor.fetchone() + if return_row is None: + return 0 + return return_row[0] + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: conds = [f"object_id = '{req.object_id}'"] if req.digest == "latest": @@ -677,10 +718,14 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: objs = self._select_objs_query( req.project_id, conditions=conds, + include_deleted=True, ) if len(objs) == 0: raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found") - + if objs[0].deleted_at is not None: + raise ObjectDeletedError( + f"Obj {req.object_id}:v{objs[0].version_index} was deleted at {objs[0].deleted_at}" + ) return tsi.ObjReadRes(obj=objs[0]) def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: @@ -715,6 +760,29 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: return tsi.ObjQueryRes(objs=objs) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + delete_query = """ + UPDATE objects SET deleted_at = CURRENT_TIMESTAMP + WHERE project_id = ? AND + object_id = ? + """ + parameters = [req.project_id, req.object_id] + + if req.digests: + num_digests = len(req.digests) + delete_query += "AND digest IN ({})".format(", ".join("?" * num_digests)) + parameters.extend(req.digests) + + conn, cursor = get_conn_cursor(self.db_path) + with self.lock: + cursor.execute("BEGIN TRANSACTION") + cursor.execute(delete_query, parameters) + # get the number of objects deleted + cursor.execute("SELECT changes()") + num_deleted = cursor.fetchone()[0] + conn.commit() + return tsi.ObjDeleteRes(num_deleted=num_deleted) + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: conn, cursor = get_conn_cursor(self.db_path) insert_rows = [] @@ -1150,11 +1218,15 @@ def _select_objs_query( parameters: Optional[dict[str, Any]] = None, metadata_only: Optional[bool] = False, limit: Optional[int] = None, + include_deleted: bool = False, offset: Optional[int] = None, sort_by: Optional[list[tsi.SortBy]] = None, ) -> list[tsi.ObjSchema]: conn, cursor = get_conn_cursor(self.db_path) - pred = " AND ".join(conditions or ["1 = 1"]) + conditions = conditions or ["1 = 1"] + if not include_deleted: + conditions.append("deleted_at IS NULL") + pred = " AND ".join(conditions) val_dump_part = "'{}' as val_dump" if metadata_only else "val_dump" query = f""" SELECT @@ -1166,10 +1238,10 @@ def _select_objs_query( {val_dump_part}, digest, version_index, - is_latest + is_latest, + deleted_at FROM objects - WHERE deleted_at IS NULL AND - project_id = ? AND {pred} + WHERE project_id = ? AND {pred} """ if sort_by: @@ -1211,6 +1283,7 @@ def _select_objs_query( digest=row[6], version_index=row[7], is_latest=row[8], + deleted_at=row[9], ) ) return result diff --git a/weave/trace_server/trace_server_interface.py b/weave/trace_server/trace_server_interface.py index 77ef6198ecf..61508d3255b 100644 --- a/weave/trace_server/trace_server_interface.py +++ b/weave/trace_server/trace_server_interface.py @@ -467,6 +467,19 @@ class ObjQueryReq(BaseModel): ) +class ObjDeleteReq(BaseModel): + project_id: str + object_id: str + digests: Optional[list[str]] = Field( + default=None, + description="List of digests to delete. If not provided, all digests for the object will be deleted.", + ) + + +class ObjDeleteRes(BaseModel): + num_deleted: int + + class ObjQueryRes(BaseModel): objs: list[ObjSchema] @@ -897,6 +910,7 @@ def cost_purge(self, req: CostPurgeReq) -> CostPurgeRes: ... def obj_create(self, req: ObjCreateReq) -> ObjCreateRes: ... def obj_read(self, req: ObjReadReq) -> ObjReadRes: ... def objs_query(self, req: ObjQueryReq) -> ObjQueryRes: ... + def obj_delete(self, req: ObjDeleteReq) -> ObjDeleteRes: ... def table_create(self, req: TableCreateReq) -> TableCreateRes: ... def table_update(self, req: TableUpdateReq) -> TableUpdateRes: ... def table_query(self, req: TableQueryReq) -> TableQueryRes: ... diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index b749af60d24..966bf92a815 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -352,6 +352,11 @@ def objs_query( "/objs/query", req, tsi.ObjQueryReq, tsi.ObjQueryRes ) + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: + return self._generic_request( + "/obj/delete", req, tsi.ObjDeleteReq, tsi.ObjDeleteRes + ) + def table_create( self, req: Union[tsi.TableCreateReq, dict[str, Any]] ) -> tsi.TableCreateRes: