Skip to content

Commit

Permalink
basic server impl + server tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Dec 13, 2024
1 parent 60133fd commit 32baaa0
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 37 deletions.
107 changes: 107 additions & 0 deletions tests/trace/test_obj_delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Optional

import weave
from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi


def _objs_query(client: WeaveClient, object_id: str) -> list[tsi.ObjSchema]:
    """Return the versions of ``object_id`` reported by the client's server.

    Sends an ``objs_query`` for the client's project, filtered to the single
    object id, and returns the ``objs`` list of the response.
    """
    project_id = client._project_id()
    version_filter = tsi.ObjectVersionFilter(object_ids=[object_id])
    request = tsi.ObjQueryReq(project_id=project_id, filter=version_filter)
    response = client.server.objs_query(request)
    return response.objs


def _obj_delete(
    client: WeaveClient, object_id: str, digests: Optional[list[str]]
) -> int:
    """Delete versions of ``object_id`` through the server and return the count.

    Args:
        client: Client whose server and project the request is sent to.
        object_id: Id of the object whose versions are deleted.
        digests: Digests of the versions to delete. The tests in this file
            also pass ``None`` to request deletion of every version, so the
            parameter is optional.

    Returns:
        The ``num_deleted`` value of the server's response.
    """
    return client.server.obj_delete(
        tsi.ObjDeleteReq(
            project_id=client._project_id(),
            object_id=object_id,
            digests=digests,
        )
    ).num_deleted


def test_delete_object_versions(client: WeaveClient):
    """Deleting by digest removes exactly the requested versions.

    NOTE(review): the ClickHouse ``obj_delete`` added in this commit raises
    ``NotFoundError`` when no version matches, so the zero-count expectations
    below presumably hold only for a backend that returns 0 instead --
    confirm which server the ``client`` fixture uses.
    """
    name = "obj_1"
    refs = [weave.publish({"i": i}, name=name) for i in (1, 2, 3)]
    oldest = refs[0]

    assert len(_objs_query(client, name)) == 3

    # Remove only the oldest version.
    assert _obj_delete(client, name, [oldest.digest]) == 1
    assert len(_objs_query(client, name)) == 2

    # Repeating the delete for an already deleted digest removes nothing.
    assert _obj_delete(client, name, [oldest.digest]) == 0

    # A digest that never existed removes nothing either.
    assert _obj_delete(client, name, ["non-existent-digest"]) == 0

    # Several digests can be removed with a single request.
    remaining_digests = [ref.digest for ref in refs[1:]]
    assert _obj_delete(client, name, remaining_digests) == 2

    assert len(_objs_query(client, name)) == 0


def test_delete_all_object_versions(client: WeaveClient):
    """Passing no digests deletes every version of the object."""
    for i in (1, 2, 3):
        weave.publish({"i": i}, name="obj_1")

    deleted = _obj_delete(client, "obj_1", None)
    assert deleted == 3

    leftover = _objs_query(client, "obj_1")
    assert len(leftover) == 0

    # Nothing is left, so a second delete-all reports zero deletions.
    deleted_again = _obj_delete(client, "obj_1", None)
    assert deleted_again == 0


def test_delete_version_correctness(client: WeaveClient):
    """Deleting a version must leave the surviving versions untouched.

    Verifies digest, value and ``version_index`` of the remaining versions
    after each delete, and that a version published after a delete gets
    index 3 instead of reusing the index freed by the deleted version.
    """
    v0 = weave.publish({"i": 1}, name="obj_1")
    v1 = weave.publish({"i": 2}, name="obj_1")
    v2 = weave.publish({"i": 3}, name="obj_1")

    def assert_versions(expected):
        # Always query fresh state: the original test reused a stale result
        # list after publishing v3, so its `len(objs) == 3` check could
        # never pass.
        objs = _objs_query(client, "obj_1")
        assert len(objs) == len(expected)
        for obj, (ref, val, version_index) in zip(objs, expected):
            assert obj.digest == ref.digest
            assert obj.val == val
            assert obj.version_index == version_index

    survivors = [(v0, {"i": 1}, 0), (v2, {"i": 3}, 2)]

    # Deleting the middle version keeps the indexes of its neighbours.
    _obj_delete(client, "obj_1", [v1.digest])
    assert_versions(survivors)

    # A newly published version is appended after the existing indexes.
    v3 = weave.publish({"i": 4}, name="obj_1")
    assert_versions(survivors + [(v3, {"i": 4}, 3)])

    # Deleting the newest version restores the previous view.
    _obj_delete(client, "obj_1", [v3.digest])
    assert_versions(survivors)
5 changes: 5 additions & 0 deletions weave/trace_server/clickhouse_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class ObjCHInsertable(BaseModel):
_refs = field_validator("refs")(validation.refs_list_validator)


class ObjDeleteCHInsertable(ObjCHInsertable):
    """Object-version row that marks a version as deleted.

    Extends ``ObjCHInsertable`` with two timestamps that must be supplied
    explicitly. The server's ``obj_delete`` builds one of these per deleted
    version and inserts it into ``object_versions``.
    """

    # Time of the deletion; ``obj_delete`` passes the current UTC time.
    deleted_at: datetime.datetime
    # ``obj_delete`` passes the ``created_at`` of the version being deleted.
    created_at: datetime.datetime

class SelectableCHObjSchema(BaseModel):
project_id: str
object_id: str
Expand Down
91 changes: 88 additions & 3 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
CallStartCHInsertable,
CallUpdateCHInsertable,
ObjCHInsertable,
ObjDeleteCHInsertable,
SelectableCHCallSchema,
SelectableCHObjSchema,
)
Expand All @@ -69,6 +70,7 @@
InsertTooLarge,
InvalidRequest,
MissingLLMApiKeyError,
ObjectDeletedError,
RequestTooLarge,
)
from weave.trace_server.feedback import (
Expand Down Expand Up @@ -533,12 +535,19 @@ def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes:
object_query_builder.add_is_op_condition(True)
object_query_builder.add_digest_condition(req.digest)
object_query_builder.add_object_ids_condition([req.name], "op_name")
object_query_builder.set_deleted_at_condition(include_deleted=True)

objs = self._select_objs_query(object_query_builder)
if len(objs) == 0:
raise NotFoundError(f"Obj {req.name}:{req.digest} not found")

return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(objs[0]))
op = objs[0]
if op.deleted_at is not None:
raise ObjectDeletedError(
f"Op {req.name}:v{op.version_index} was deleted at {op.deleted_at}"
)

return tsi.OpReadRes(op_obj=_ch_obj_to_obj_schema(op))

def ops_query(self, req: tsi.OpQueryReq) -> tsi.OpQueryRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
Expand Down Expand Up @@ -584,12 +593,19 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
object_query_builder.add_digest_condition(req.digest)
object_query_builder.add_object_ids_condition([req.object_id])
object_query_builder.set_deleted_at_condition(include_deleted=True)

objs = self._select_objs_query(object_query_builder)
if len(objs) == 0:
raise NotFoundError(f"Obj {req.object_id}:{req.digest} not found")

return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(objs[0]))
obj = objs[0]
if obj.deleted_at is not None:
raise ObjectDeletedError(
f"Obj {req.object_id}:v{obj.version_index} was deleted at {obj.deleted_at}"
)

return tsi.ObjReadRes(obj=_ch_obj_to_obj_schema(obj))

def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
object_query_builder = ObjectQueryBuilder(req.project_id)
Expand Down Expand Up @@ -623,6 +639,66 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:

return tsi.ObjQueryRes(objs=[_ch_obj_to_obj_schema(obj) for obj in objs])

def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes:
    """
    Delete object versions by digest, belonging to given object_id.

    All deletion in this method is "soft". Deletion occurs by inserting
    a new row into the object_versions table with the deleted_at field set.
    Inserted rows share identical primary keys (order by) with original rows,
    and will be combined by the ReplacingMergeTree engine at database merge
    time.

    If no digests are provided, no digest condition is added to the query,
    so every version returned for ``req.object_id`` is deleted.

    NOTE(review): the inserted rows copy ``val_dump`` from the selected rows.
    When digests are given the selection is metadata-only, so ``val_dump``
    is presumably whatever placeholder the metadata query yields -- confirm
    this is the intended payload of the deletion row.

    Args:
        req: Project id, object id and optional list of digests (at most
            100 per request).

    Returns:
        ``ObjDeleteRes`` whose ``num_deleted`` is the number of rows inserted.

    Raises:
        ValueError: If more than 100 digests are supplied.
        NotFoundError: If the query matches no object version.
    """
    MAX_OBJECTS_TO_DELETE = 100
    if req.digests and len(req.digests) > MAX_OBJECTS_TO_DELETE:
        raise ValueError(
            f"Object delete request contains {len(req.digests)} objects. Please delete fewer than {MAX_OBJECTS_TO_DELETE} objects at a time."
        )

    object_query_builder = ObjectQueryBuilder(req.project_id)
    object_query_builder.add_object_ids_condition([req.object_id])
    for i, digest in enumerate(req.digests or []):
        object_query_builder._add_version_digest_condition(digest, f"digest_{i}")
    # Select only metadata when specific digests were requested.
    object_query_builder.set_metadata_only(bool(req.digests))

    object_versions = self._select_objs_query(object_query_builder)
    if not object_versions:
        raise NotFoundError(
            f"Object {req.object_id} ({req.digests}) not found when deleting."
        )

    # One shared timestamp so every row of this request agrees on deleted_at.
    now = datetime.datetime.now(datetime.timezone.utc)
    delete_insertables = [
        ObjDeleteCHInsertable(
            project_id=obj.project_id,
            object_id=obj.object_id,
            digest=obj.digest,
            kind=obj.kind,
            val_dump=obj.val_dump,
            refs=obj.refs,
            base_object_class=obj.base_object_class,
            deleted_at=now,
            # Keep the original creation time on the deletion row.
            created_at=_ensure_datetimes_have_tz_strict(obj.created_at),
        )
        for obj in object_versions
    ]

    # Read the field names from the class: accessing `model_fields` on an
    # instance is deprecated in newer Pydantic releases.
    column_names = list(ObjDeleteCHInsertable.model_fields.keys())
    data = [list(row.model_dump().values()) for row in delete_insertables]
    self._insert("object_versions", data=data, column_names=column_names)

    return tsi.ObjDeleteRes(num_deleted=len(delete_insertables))

def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes:
insert_rows = []
for r in req.table.rows:
Expand Down Expand Up @@ -1563,7 +1639,7 @@ def _select_objs_query(
metadata_result = format_metadata_objects_from_query_result(query_result)

# -- Don't make second query for object values if metadata_only --
if object_query_builder.metadata_only:
if object_query_builder.metadata_only or len(metadata_result) == 0:
return metadata_result

value_query, value_parameters = make_objects_val_query_and_parameters(
Expand Down Expand Up @@ -1776,6 +1852,15 @@ def _ensure_datetimes_have_tz(
return dt


def _ensure_datetimes_have_tz_strict(
    dt: datetime.datetime,
) -> datetime.datetime:
    """Run ``_ensure_datetimes_have_tz`` on ``dt`` and refuse a ``None`` result.

    Returns:
        The value returned by ``_ensure_datetimes_have_tz``.

    Raises:
        ValueError: If ``_ensure_datetimes_have_tz`` returns ``None``.
    """
    with_tz = _ensure_datetimes_have_tz(dt)
    if with_tz is not None:
        return with_tz
    raise ValueError(f"Datetime is None: {dt}")


def _nullable_dict_dump_to_dict(
val: Optional[str],
) -> Optional[dict[str, Any]]:
Expand Down
12 changes: 12 additions & 0 deletions weave/trace_server/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ class InvalidFieldError(Error):
pass


class NotFoundError(Error):
    """Raised when a requested entity cannot be found.

    In this commit the ClickHouse server raises it when an object/op read
    returns no rows and when ``obj_delete`` matches no object version.
    """

    pass


class ObjectDeletedError(Error):
    """Raised when a requested object exists but has been deleted.

    In this commit the ClickHouse server raises it from ``op_read`` and
    ``obj_read`` when the selected row has a non-null ``deleted_at``.
    """

    pass


class MissingLLMApiKeyError(Error):
"""Raised when a LLM API key is missing for completion."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,10 @@ def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes:
obj.project_id = original_project_id
return res

def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes:
    """Convert the project id to its internal form and forward the delete.

    The request is modified in place: ``req.project_id`` is replaced with the
    result of ``self._idc.ext_to_int_project_id`` before the call is routed
    through ``self._ref_apply`` to the internal server's ``obj_delete``.
    """
    internal_project_id = self._idc.ext_to_int_project_id(req.project_id)
    req.project_id = internal_project_id
    internal_obj_delete = self._internal_trace_server.obj_delete
    return self._ref_apply(internal_obj_delete, req)

def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes:
req.table.project_id = self._idc.ext_to_int_project_id(req.table.project_id)
return self._ref_apply(self._internal_trace_server.table_create, req)
Expand Down
25 changes: 19 additions & 6 deletions weave/trace_server/objects_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def set_offset(self, offset: int) -> None:
def set_metadata_only(self, metadata_only: bool) -> None:
    """Store the metadata-only flag on ``self.metadata_only``.

    NOTE(review): the ClickHouse server's ``_select_objs_query`` appears to
    skip its second (object value) query when this flag is true -- confirm
    against that method.
    """
    self.metadata_only = metadata_only

def set_deleted_at_condition(self, include_deleted: bool) -> None:
    """Restrict the query to non-deleted rows unless deleted ones are wanted.

    Appends ``deleted_at IS NULL`` to ``self._conditions`` when
    ``include_deleted`` is false; otherwise leaves the conditions unchanged.
    """
    if not include_deleted:
        self._conditions.append("deleted_at IS NULL")

def make_metadata_query(self) -> str:
return f"""
SELECT
Expand All @@ -217,7 +222,8 @@ def make_metadata_query(self) -> str:
is_op,
version_index,
version_count,
is_latest
is_latest,
deleted_at
FROM (
SELECT project_id,
object_id,
Expand All @@ -227,20 +233,27 @@ def make_metadata_query(self) -> str:
refs,
digest,
is_op,
deleted_at,
row_number() OVER (
PARTITION BY project_id,
kind,
object_id
ORDER BY created_at ASC
) - 1 AS version_index,
count(*) OVER (PARTITION BY project_id, kind, object_id) as version_count,
if(version_index + 1 = version_count, 1, 0) AS is_latest
row_number() OVER (
PARTITION BY project_id, kind, object_id
ORDER BY (deleted_at IS NULL) DESC, created_at DESC
) AS row_num,
if (row_num = 1, 1, 0) AS is_latest
FROM (
SELECT project_id,
object_id,
created_at,
MIN(created_at) AS created_at,
MIN(deleted_at) AS deleted_at,
kind,
base_object_class,
MIN(base_object_class) AS base_object_class,
MIN(refs) AS refs,
refs,
digest,
if (kind = 'op', 1, 0) AS is_op,
Expand All @@ -253,9 +266,9 @@ def make_metadata_query(self) -> str:
) AS rn
FROM object_versions
WHERE project_id = {{project_id: String}}
{self.object_id_conditions_part}
{self.object_id_conditions_part}
GROUP BY project_id, kind, object_id, digest
)
WHERE rn = 1
)
{self.conditions_part}
{self.sort_part}
Expand Down
Loading

0 comments on commit 32baaa0

Please sign in to comment.