Skip to content

Commit

Permalink
better, still not passing
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Dec 16, 2024
1 parent db8b2a3 commit 3de8721
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 83 deletions.
115 changes: 47 additions & 68 deletions tests/trace/test_permanently_delete_project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from datetime import datetime

import pytest

import weave
from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi


Expand All @@ -9,8 +13,10 @@ def project_id():


@pytest.fixture
def setup_test_data(project_id):
client = weave.init(project_name=project_id)
def setup_test_data(project_id, client: WeaveClient):
# client.project = project_id

print("Setting up test data", project_id)

@weave.op()
def create_test_data(pid: str):
Expand All @@ -34,7 +40,7 @@ def create_test_data(pid: str):
project_id=project_id,
wb_user_id="test-user",
costs={
"gpt-4": tsi.CostSchema(
"gpt-4": tsi.LLMCostSchema(
prompt_token_cost=0.01,
completion_token_cost=0.02,
prompt_token_cost_unit="USD/1K tokens",
Expand All @@ -45,32 +51,24 @@ def create_test_data(pid: str):
client.server.cost_create(cost_req)


def test_permanently_delete_project_deletes_all_data(project_id, setup_test_data):
client = weave.init(project_name=project_id)
def test_permanently_delete_project_deletes_all_data(
project_id, setup_test_data, client: WeaveClient
):
client.project = project_id
calls = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id))

print("calls", list(calls))

# Verify data exists before deletion
assert (
len(
list(
client.server.calls_query_stream(
tsi.CallsQueryReq(project_id=project_id)
)
)
)
== 4
)
assert len(list(calls)) == 4

assert (
len(client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)).objs) == 2
)
obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id))
assert len(obj_query.objs) == 2

assert (
len(
client.server.table_query(
tsi.TableQueryReq(project_id=project_id, digest="latest")
).rows
)
== 1
table_query = client.server.table_query(
tsi.TableQueryReq(project_id=project_id, digest="latest")
)
assert len(table_query.rows) == 1

feedback_query = tsi.FeedbackQueryReq(
project_id=project_id, fields=["id", "feedback_type"]
Expand All @@ -86,35 +84,28 @@ def test_permanently_delete_project_deletes_all_data(project_id, setup_test_data
)

# Verify all data is deleted
assert (
len(
list(
client.server.calls_query_stream(
tsi.CallsQueryReq(project_id=project_id)
)
)
)
== 0
)
calls1 = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id))
assert len(list(calls1)) == 0

assert (
len(client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)).objs) == 0
)
obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id))
assert len(obj_query.objs) == 0

assert (
len(
client.server.table_query(
tsi.TableQueryReq(project_id=project_id, digest="latest")
).rows
)
== 0
table_query = client.server.table_query(
tsi.TableQueryReq(project_id=project_id, digest="latest")
)
assert len(table_query.rows) == 0

feedback_query = tsi.FeedbackQueryReq(
project_id=project_id, fields=["id", "feedback_type"]
)
assert len(client.server.feedback_query(feedback_query).result) == 0

cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"])
assert len(client.server.cost_query(cost_query).results) == 0

def test_permanently_delete_project_with_nonexistent_project():
client = weave.init("exists")

def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient):
client.project = "exists"
# Should not raise an error when deleting non-existent project
nonexistent_project_id = "nonexistent"
client.server.permanently_delete_project(
Expand All @@ -123,18 +114,20 @@ def test_permanently_delete_project_with_nonexistent_project():


def test_permanently_delete_project_does_not_affect_other_projects(
project_id, setup_test_data
project_id, setup_test_data, client: WeaveClient
):
client = weave.init("other-project")
client.project = "other-project"
# Create another project with data
other_project_id = "other-project"
other_call_start = tsi.StartedCallSchemaForInsert(
id="idddddddddddd",
started_at=datetime.now(),
project_id=other_project_id,
op_name="test_op",
inputs={},
attributes={},
started_at=None,
wb_user_id="test-user",
trace_id="test-trace-id",
)
client.server.call_start(tsi.CallStartReq(start=other_call_start))

Expand All @@ -150,8 +143,10 @@ def test_permanently_delete_project_does_not_affect_other_projects(
assert len(other_project_calls) > 0


def test_permanently_delete_project_idempotency(project_id, setup_test_data):
client = weave.init(project_name=project_id)
def test_permanently_delete_project_idempotency(
project_id, setup_test_data, client: WeaveClient
):
client.project = project_id
# Delete project twice
client.server.permanently_delete_project(
tsi.PermanentlyDeleteProjectReq(project_id=project_id)
Expand All @@ -171,19 +166,3 @@ def test_permanently_delete_project_idempotency(project_id, setup_test_data):
)
== 0
)


@pytest.mark.parametrize(
"invalid_project_id",
[
"", # Empty string
None, # None
" ", # Whitespace
],
)
def test_permanently_delete_project_with_invalid_project_id(invalid_project_id):
client = weave.init(project_name="test-project-123")
with pytest.raises(Exception):
client.server.permanently_delete_project(
tsi.PermanentlyDeleteProjectReq(project_id=invalid_project_id)
)
7 changes: 5 additions & 2 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ def completions_create(
)

def _make_purge_query(self, project_id: str, table_name: str) -> str:
    """Build the parameterized DELETE statement that purges one table.

    The project id is deliberately NOT interpolated into the SQL text.
    The returned statement carries a literal ``{project_id: String}``
    placeholder, and the caller is expected to supply the actual value
    through the query parameters (key ``"project_id"``).

    Args:
        project_id: Not used to build the text; kept in the signature so
            existing callers continue to work unchanged.
        table_name: Table to delete from. This IS interpolated directly
            into the statement, so it must come from a trusted,
            hard-coded list and never from request input.

    Returns:
        The DELETE statement with a ``{project_id: String}`` placeholder.
    """
    # Doubled braces produce literal braces in an f-string. A triple-brace
    # form (`{{{project_id: String}}}`) would instead be parsed as a
    # replacement field with format spec " String" and raise ValueError.
    return f"DELETE FROM {table_name} WHERE project_id = {{project_id: String}}"

def permanently_delete_project(
self, req: tsi.PermanentlyDeleteProjectReq
Expand All @@ -1518,6 +1518,8 @@ def permanently_delete_project(
5. Delete all feedback data
6. Delete all cost data (?)
"""
if not req.project_id.strip():
raise InvalidRequest("Project ID is required")
tables_to_purge = [
"call_parts",
"object_versions",
Expand All @@ -1530,7 +1532,8 @@ def permanently_delete_project(

for table in tables_to_purge:
query = self._make_purge_query(req.project_id, table)
self.ch_client.query(query)
parameters = {"project_id": req.project_id}
self.ch_client.query(query, parameters)

return tsi.PermanentlyDeleteProjectRes()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,11 @@ def completions_create(
req.project_id = self._idc.ext_to_int_project_id(req.project_id)
res = self._ref_apply(self._internal_trace_server.completions_create, req)
return res

def permanently_delete_project(
    self, req: tsi.PermanentlyDeleteProjectReq
) -> tsi.PermanentlyDeleteProjectRes:
    """Forward a permanent project deletion to the internal trace server.

    ``req.project_id`` is overwritten in place with the value returned by
    ``self._idc.ext_to_int_project_id``; the request is then handed,
    together with the internal server's ``permanently_delete_project``
    method, to ``self._ref_apply``, whose result is returned.
    """
    internal_project_id = self._idc.ext_to_int_project_id(req.project_id)
    req.project_id = internal_project_id

    apply_with_refs = self._ref_apply
    delete_on_internal = self._internal_trace_server.permanently_delete_project
    return apply_with_refs(delete_on_internal, req)
25 changes: 12 additions & 13 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,36 +1110,35 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR

def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes:
print("COST CREATE is not implemented for local sqlite", req)
return tsi.CostCreateRes()
return tsi.CostCreateRes(ids=[])

def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes:
print("COST QUERY is not implemented for local sqlite", req)
return tsi.CostQueryRes()
return tsi.CostQueryRes(result=[])

def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes:
print("COST PURGE is not implemented for local sqlite", req)
return tsi.CostPurgeRes()
return tsi.CostPurgeRes(ids=[])

def completions_create(
self, req: tsi.CompletionsCreateReq
) -> tsi.CompletionsCreateRes:
print("COMPLETIONS CREATE is not implemented for local sqlite", req)
return tsi.CompletionsCreateRes()
return tsi.CompletionsCreateRes(response={})

def permanently_delete_project(
self, req: tsi.PermanentlyDeleteProjectReq
) -> tsi.PermanentlyDeleteProjectRes:
conn, cursor = get_conn_cursor(self.db_path)
tables_to_purge = [
"calls",
"objects",
"tables",
"table_rows",
"files",
"feedback",
]
with self.lock:
tables_to_purge = [
"call_parts",
"object_versions",
"tables",
"table_rows",
"files",
"feedback",
"cost",
]
conn.execute("BEGIN TRANSACTION")
for table in tables_to_purge:
cursor.execute(
Expand Down

0 comments on commit 3de8721

Please sign in to comment.