From 3de87215e9340ff39b8d75bc0983f9ae3787e139 Mon Sep 17 00:00:00 2001 From: gtarpenning Date: Mon, 16 Dec 2024 09:27:36 -0800 Subject: [PATCH] better, stil not passing --- .../trace/test_permanently_delete_project.py | 115 +++++++----------- .../clickhouse_trace_server_batched.py | 7 +- ...ternal_to_internal_trace_server_adapter.py | 8 ++ weave/trace_server/sqlite_trace_server.py | 25 ++-- 4 files changed, 72 insertions(+), 83 deletions(-) diff --git a/tests/trace/test_permanently_delete_project.py b/tests/trace/test_permanently_delete_project.py index a216eda4069..067c5cd3492 100644 --- a/tests/trace/test_permanently_delete_project.py +++ b/tests/trace/test_permanently_delete_project.py @@ -1,5 +1,9 @@ +from datetime import datetime + import pytest + import weave +from weave.trace.weave_client import WeaveClient from weave.trace_server import trace_server_interface as tsi @@ -9,8 +13,10 @@ def project_id(): @pytest.fixture -def setup_test_data(project_id): - client = weave.init(project_name=project_id) +def setup_test_data(project_id, client: WeaveClient): + # client.project = project_id + + print("Setting up test data", project_id) @weave.op() def create_test_data(pid: str): @@ -34,7 +40,7 @@ def create_test_data(pid: str): project_id=project_id, wb_user_id="test-user", costs={ - "gpt-4": tsi.CostSchema( + "gpt-4": tsi.LLMCostSchema( prompt_token_cost=0.01, completion_token_cost=0.02, prompt_token_cost_unit="USD/1K tokens", @@ -45,32 +51,24 @@ def create_test_data(pid: str): client.server.cost_create(cost_req) -def test_permanently_delete_project_deletes_all_data(project_id, setup_test_data): - client = weave.init(project_name=project_id) +def test_permanently_delete_project_deletes_all_data( + project_id, setup_test_data, client: WeaveClient +): + client.project = project_id + calls = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id)) + + print("calls", list(calls)) + # Verify data exists before deletion - assert ( - len( - list( - client.server.calls_query_stream( - tsi.CallsQueryReq(project_id=project_id) - ) - ) - ) - == 4 - ) + assert len(list(calls)) == 4 - assert ( - len(client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)).objs) == 2 - ) + obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)) + assert len(obj_query.objs) == 2 - assert ( - len( - client.server.table_query( - tsi.TableQueryReq(project_id=project_id, digest="latest") - ).rows - ) - == 1 + table_query = client.server.table_query( + tsi.TableQueryReq(project_id=project_id, digest="latest") ) + assert len(table_query.rows) == 1 feedback_query = tsi.FeedbackQueryReq( project_id=project_id, fields=["id", "feedback_type"] @@ -86,35 +84,28 @@ def test_permanently_delete_project_deletes_all_data(project_id, setup_test_data ) # Verify all data is deleted - assert ( - len( - list( - client.server.calls_query_stream( - tsi.CallsQueryReq(project_id=project_id) - ) - ) - ) - == 0 - ) + calls1 = client.server.calls_query_stream(tsi.CallsQueryReq(project_id=project_id)) + assert len(list(calls1)) == 0 - assert ( - len(client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)).objs) == 0 - ) + obj_query = client.server.objs_query(tsi.ObjQueryReq(project_id=project_id)) + assert len(obj_query.objs) == 0 - assert ( - len( - client.server.table_query( - tsi.TableQueryReq(project_id=project_id, digest="latest") - ).rows - ) - == 0 + table_query = client.server.table_query( + tsi.TableQueryReq(project_id=project_id, digest="latest") ) + assert len(table_query.rows) == 0 + feedback_query = tsi.FeedbackQueryReq( + project_id=project_id, fields=["id", "feedback_type"] + ) assert len(client.server.feedback_query(feedback_query).result) == 0 + cost_query = tsi.CostQueryReq(project_id=project_id, fields=["id", "llm_id"]) + assert len(client.server.cost_query(cost_query).results) == 0 -def test_permanently_delete_project_with_nonexistent_project(): - client = weave.init("exists") + +def test_permanently_delete_project_with_nonexistent_project(client: WeaveClient): + client.project = "exists" # Should not raise an error when deleting non-existent project nonexistent_project_id = "nonexistent" client.server.permanently_delete_project( @@ -123,18 +114,20 @@ def test_permanently_delete_project_with_nonexistent_project(): def test_permanently_delete_project_does_not_affect_other_projects( - project_id, setup_test_data + project_id, setup_test_data, client: WeaveClient ): - client = weave.init("other-project") + client.project = "other-project" # Create another project with data other_project_id = "other-project" other_call_start = tsi.StartedCallSchemaForInsert( + id="idddddddddddd", + started_at=datetime.now(), project_id=other_project_id, op_name="test_op", inputs={}, attributes={}, - started_at=None, wb_user_id="test-user", + trace_id="test-trace-id", ) client.server.call_start(tsi.CallStartReq(start=other_call_start)) @@ -150,8 +143,10 @@ def test_permanently_delete_project_does_not_affect_other_projects( assert len(other_project_calls) > 0 -def test_permanently_delete_project_idempotency(project_id, setup_test_data): - client = weave.init(project_name=project_id) +def test_permanently_delete_project_idempotency( + project_id, setup_test_data, client: WeaveClient +): + client.project = project_id # Delete project twice client.server.permanently_delete_project( tsi.PermanentlyDeleteProjectReq(project_id=project_id) @@ -171,19 +166,3 @@ def test_permanently_delete_project_idempotency(project_id, setup_test_data): ) == 0 ) - - -@pytest.mark.parametrize( - "invalid_project_id", - [ - "", # Empty string - None, # None - " ", # Whitespace - ], -) -def test_permanently_delete_project_with_invalid_project_id(invalid_project_id): - client = weave.init(project_name="test-project-123") - with pytest.raises(Exception): - client.server.permanently_delete_project( - tsi.PermanentlyDeleteProjectReq(project_id=invalid_project_id) - ) diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index a285aead364..0c1acc7fe43 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -1504,7 +1504,7 @@ def completions_create( ) def _make_purge_query(self, project_id: str, table_name: str) -> str: - return f"DELETE FROM {table_name} WHERE project_id = '{project_id}'" + return f"DELETE FROM {table_name} WHERE project_id = {{{project_id: String}}}" def permanently_delete_project( self, req: tsi.PermanentlyDeleteProjectReq @@ -1518,6 +1518,8 @@ def permanently_delete_project( 5. Delete all feedback data 6. Delete all cost data (?) """ + if not req.project_id.strip(): + raise InvalidRequest("Project ID is required") tables_to_purge = [ "call_parts", "object_versions", @@ -1530,7 +1532,8 @@ def permanently_delete_project( for table in tables_to_purge: query = self._make_purge_query(req.project_id, table) - self.ch_client.query(query) + parameters = {"project_id": req.project_id} + self.ch_client.query(query, parameters) return tsi.PermanentlyDeleteProjectRes() diff --git a/weave/trace_server/external_to_internal_trace_server_adapter.py b/weave/trace_server/external_to_internal_trace_server_adapter.py index 1df739adbcd..e7734ec1307 100644 --- a/weave/trace_server/external_to_internal_trace_server_adapter.py +++ b/weave/trace_server/external_to_internal_trace_server_adapter.py @@ -376,3 +376,11 @@ def completions_create( req.project_id = self._idc.ext_to_int_project_id(req.project_id) res = self._ref_apply(self._internal_trace_server.completions_create, req) return res + + def permanently_delete_project( + self, req: tsi.PermanentlyDeleteProjectReq + ) -> tsi.PermanentlyDeleteProjectRes: + req.project_id = self._idc.ext_to_int_project_id(req.project_id) + return self._ref_apply( + self._internal_trace_server.permanently_delete_project, req + ) diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index a7bdbac52ef..6d4186f7428 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -1110,36 +1110,35 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: print("COST CREATE is not implemented for local sqlite", req) - return tsi.CostCreateRes() + return tsi.CostCreateRes(ids=[]) def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: print("COST QUERY is not implemented for local sqlite", req) - return tsi.CostQueryRes() + return tsi.CostQueryRes(result=[]) def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: print("COST PURGE is not implemented for local sqlite", req) - return tsi.CostPurgeRes() + return tsi.CostPurgeRes(ids=[]) def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: print("COMPLETIONS CREATE is not implemented for local sqlite", req) - return tsi.CompletionsCreateRes() + return tsi.CompletionsCreateRes(response={}) def permanently_delete_project( self, req: tsi.PermanentlyDeleteProjectReq ) -> tsi.PermanentlyDeleteProjectRes: conn, cursor = get_conn_cursor(self.db_path) + tables_to_purge = [ + "calls", + "objects", + "tables", + "table_rows", + "files", + "feedback", + ] with self.lock: - tables_to_purge = [ - "call_parts", - "object_versions", - "tables", - "table_rows", - "files", - "feedback", - "cost", - ] conn.execute("BEGIN TRANSACTION") for table in tables_to_purge: cursor.execute(