From f460527376f183a6dcb768a90d25f432c2022969 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 16:58:09 -0800 Subject: [PATCH] fixtesttttttttt --- .../trace/test_permanently_delete_project.py | 57 ++++++++----------- .../clickhouse_trace_server_batched.py | 11 ++-- 2 files changed, 28 insertions(+), 40 deletions(-) diff --git a/tests/trace/test_permanently_delete_project.py b/tests/trace/test_permanently_delete_project.py index 067c5cd3492..9624c8dfe3a 100644 --- a/tests/trace/test_permanently_delete_project.py +++ b/tests/trace/test_permanently_delete_project.py @@ -5,6 +5,7 @@ import weave from weave.trace.weave_client import WeaveClient from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.ids import generate_id @pytest.fixture @@ -14,7 +15,7 @@ def project_id(): @pytest.fixture def setup_test_data(project_id, client: WeaveClient): - # client.project = project_id + client.project = project_id print("Setting up test data", project_id) @@ -28,36 +29,21 @@ def create_test_data(pid: str): create_test_data(project_id) obj_dataset = weave.Dataset( - name=f"{project_id}/test-obj", rows=[{"id": "test-obj"}] + name=f"{project_id}-test-obj", rows=[{"id": "test-obj"}] ) weave.publish(obj_dataset) call1 = create_test_data.calls()[0] call1.feedback.add_reaction("👍") - # Create test cost - cost_req = tsi.CostCreateReq( - project_id=project_id, - wb_user_id="test-user", - costs={ - "gpt-4": tsi.LLMCostSchema( - prompt_token_cost=0.01, - completion_token_cost=0.02, - prompt_token_cost_unit="USD/1K tokens", - completion_token_cost_unit="USD/1K tokens", - ) - }, - ) - client.server.cost_create(cost_req) - def test_permanently_delete_project_deletes_all_data( - project_id, setup_test_data, client: WeaveClient + setup_test_data, client: WeaveClient ): - client.project = project_id - calls = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id)) - - print("calls", list(calls)) + project_id = client._project_id() + calls = client.server.calls_query_stream( + tsi.CallsQueryReq(project_id=project_id, limit=1000, offset=0) + ) # Verify data exists before deletion assert len(list(calls)) == 4 @@ -65,8 +51,15 @@ def test_permanently_delete_project_deletes_all_data( obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)) assert len(obj_query.objs) == 2 + table_ref = "" + for obj in obj_query.objs: + if obj.base_object_class == "Dataset": + uri = obj.val["rows"] + table_ref = uri.split("/")[-1] + break + table_query = client.server.table_query( - tsi.TableQueryReq(project_id=project_id, digest="latest") + tsi.TableQueryReq(project_id=project_id, digest=table_ref) ) assert len(table_query.rows) == 1 @@ -75,9 +68,6 @@ def test_permanently_delete_project_deletes_all_data( ) assert len(client.server.feedback_query(feedback_query).result) == 1 - cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"]) - assert len(client.server.cost_query(cost_query).results) == 1 - # Execute permanent deletion client.server.permanently_delete_project( tsi.PermanentlyDeleteProjectReq(project_id=project_id) @@ -100,9 +90,6 @@ def test_permanently_delete_project_deletes_all_data( ) assert len(client.server.feedback_query(feedback_query).result) == 0 - cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"]) - assert len(client.server.cost_query(cost_query).results) == 0 - def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient): client.project = "exists" @@ -116,18 +103,20 @@ def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient def test_permanently_delete_project_does_not_affect_other_projects( project_id, setup_test_data, client: WeaveClient ): - client.project = "other-project" + project_id = client._project_id() # Create another project with data - other_project_id = "other-project" + other_project_id = "shawn/other-project" + call_id = generate_id() + trace_id = generate_id() other_call_start = tsi.StartedCallSchemaForInsert( - id="idddddddddddd", + id=call_id, started_at=datetime.now(), project_id=other_project_id, op_name="test_op", inputs={}, attributes={}, wb_user_id="test-user", - trace_id="test-trace-id", + trace_id=trace_id, ) client.server.call_start(tsi.CallStartReq(start=other_call_start)) @@ -146,7 +135,7 @@ def test_permanently_delete_project_does_not_affect_other_projects( def test_permanently_delete_project_idempotency( project_id, setup_test_data, client: WeaveClient ): - client.project = project_id + project_id = client._project_id() # Delete project twice client.server.permanently_delete_project( tsi.PermanentlyDeleteProjectReq(project_id=project_id) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 0c1acc7fe43..cb0ae457c1c 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1503,8 +1503,8 @@ def completions_create( response=res.response, weave_call_id=start_call.id ) - def _make_purge_query(self, project_id: str, table_name: str) -> str: - return f"DELETE FROM {table_name} WHERE project_id = {{{project_id: String}}}" + def _make_purge_query(self, table_name: str) -> str: + return f"DELETE FROM {table_name} WHERE project_id = {{project_id: String}}" def permanently_delete_project( self, req: tsi.PermanentlyDeleteProjectReq @@ -1516,22 +1516,21 @@ def permanently_delete_project( 3. Delete all table/table_row data 4. Delete all file data 5. Delete all feedback data - 6. Delete all cost data (?) """ if not req.project_id.strip(): raise InvalidRequest("Project ID is required") + tables_to_purge = [ "call_parts", + "calls_merged", "object_versions", "tables", "table_rows", "files", "feedback", - "cost", ] - for table in tables_to_purge: - query = self._make_purge_query(req.project_id, table) + query = self._make_purge_query(table) parameters = {"project_id": req.project_id} self.ch_client.query(query, parameters)