Skip to content

Commit

Permalink
fixtesttttttttt
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning committed Dec 17, 2024
1 parent 3de8721 commit f460527
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 40 deletions.
57 changes: 23 additions & 34 deletions tests/trace/test_permanently_delete_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import weave
from weave.trace.weave_client import WeaveClient
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.ids import generate_id


@pytest.fixture
Expand All @@ -14,7 +15,7 @@ def project_id():

@pytest.fixture
def setup_test_data(project_id, client: WeaveClient):
# client.project = project_id
client.project = project_id

print("Setting up test data", project_id)

Expand All @@ -28,45 +29,37 @@ def create_test_data(pid: str):
create_test_data(project_id)

obj_dataset = weave.Dataset(
name=f"{project_id}/test-obj", rows=[{"id": "test-obj"}]
name=f"{project_id}-test-obj", rows=[{"id": "test-obj"}]
)
weave.publish(obj_dataset)

call1 = create_test_data.calls()[0]
call1.feedback.add_reaction("👍")

# Create test cost
cost_req = tsi.CostCreateReq(
project_id=project_id,
wb_user_id="test-user",
costs={
"gpt-4": tsi.LLMCostSchema(
prompt_token_cost=0.01,
completion_token_cost=0.02,
prompt_token_cost_unit="USD/1K tokens",
completion_token_cost_unit="USD/1K tokens",
)
},
)
client.server.cost_create(cost_req)


def test_permanently_delete_project_deletes_all_data(
project_id, setup_test_data, client: WeaveClient
setup_test_data, client: WeaveClient
):
client.project = project_id
calls = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id))

print("calls", list(calls))
project_id = client._project_id()
calls = client.server.calls_query_stream(
tsi.CallsQueryReq(project_id=project_id, limit=1000, offset=0)
)

# Verify data exists before deletion
assert len(list(calls)) == 4

obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id))
assert len(obj_query.objs) == 2

table_ref = ""
for obj in obj_query.objs:
if obj.base_object_class == "Dataset":
uri = obj.val["rows"]
table_ref = uri.split("/")[-1]
break

table_query = client.server.table_query(
tsi.TableQueryReq(project_id=project_id, digest="latest")
tsi.TableQueryReq(project_id=project_id, digest=table_ref)
)
assert len(table_query.rows) == 1

Expand All @@ -75,9 +68,6 @@ def test_permanently_delete_project_deletes_all_data(
)
assert len(client.server.feedback_query(feedback_query).result) == 1

cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"])
assert len(client.server.cost_query(cost_query).results) == 1

# Execute permanent deletion
client.server.permanently_delete_project(
tsi.PermanentlyDeleteProjectReq(project_id=project_id)
Expand All @@ -100,9 +90,6 @@ def test_permanently_delete_project_deletes_all_data(
)
assert len(client.server.feedback_query(feedback_query).result) == 0

cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"])
assert len(client.server.cost_query(cost_query).results) == 0


def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient):
client.project = "exists"
Expand All @@ -116,18 +103,20 @@ def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient
def test_permanently_delete_project_does_not_affect_other_projects(
project_id, setup_test_data, client: WeaveClient
):
client.project = "other-project"
project_id = client._project_id()
# Create another project with data
other_project_id = "other-project"
other_project_id = "shawn/other-project"
call_id = generate_id()
trace_id = generate_id()
other_call_start = tsi.StartedCallSchemaForInsert(
id="idddddddddddd",
id=call_id,
started_at=datetime.now(),
project_id=other_project_id,
op_name="test_op",
inputs={},
attributes={},
wb_user_id="test-user",
trace_id="test-trace-id",
trace_id=trace_id,
)
client.server.call_start(tsi.CallStartReq(start=other_call_start))

Expand All @@ -146,7 +135,7 @@ def test_permanently_delete_project_does_not_affect_other_projects(
def test_permanently_delete_project_idempotency(
project_id, setup_test_data, client: WeaveClient
):
client.project = project_id
project_id = client._project_id()
# Delete project twice
client.server.permanently_delete_project(
tsi.PermanentlyDeleteProjectReq(project_id=project_id)
Expand Down
11 changes: 5 additions & 6 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,8 +1503,8 @@ def completions_create(
response=res.response, weave_call_id=start_call.id
)

def _make_purge_query(self, project_id: str, table_name: str) -> str:
return f"DELETE FROM {table_name} WHERE project_id = {{{project_id: String}}}"
def _make_purge_query(self, table_name: str) -> str:
return f"DELETE FROM {table_name} WHERE project_id = {{project_id: String}}"

def permanently_delete_project(
self, req: tsi.PermanentlyDeleteProjectReq
Expand All @@ -1516,22 +1516,21 @@ def permanently_delete_project(
3. Delete all table/table_row data
4. Delete all file data
5. Delete all feedback data
6. Delete all cost data (?)
"""
if not req.project_id.strip():
raise InvalidRequest("Project ID is required")

tables_to_purge = [
"call_parts",
"calls_merged",
"object_versions",
"tables",
"table_rows",
"files",
"feedback",
"cost",
]

for table in tables_to_purge:
query = self._make_purge_query(req.project_id, table)
query = self._make_purge_query(table)
parameters = {"project_id": req.project_id}
self.ch_client.query(query, parameters)

Expand Down

0 comments on commit f460527

Please sign in to comment.