Skip to content

Commit

Permalink
ensure file fields are cleared after delete
Browse files Browse the repository at this point in the history
  • Loading branch information
snopoke committed Dec 13, 2024
1 parent 8bcbc27 commit 32e4c79
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 8 deletions.
26 changes: 20 additions & 6 deletions apps/assistants/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,15 @@ def convert_to_openai_tool(tool):
@wrap_openai_errors
def delete_file_from_openai(client: OpenAI, file: File):
if not file.external_id or file.external_source != "openai":
return
return False

try:
client.files.delete(file.external_id)
except openai.NotFoundError:
pass
file.external_id = ""
file.external_source = ""
return True


@wrap_openai_errors
Expand Down Expand Up @@ -183,11 +184,23 @@ def delete_openai_assistant(assistant: OpenAiAssistant):
for resource in tool_resources:
if resource.tool_type == "file_search" and "vector_store_id" in resource.extra:
vector_store_id = resource.extra.pop("vector_store_id")
client.beta.vector_stores.delete(vector_store_id=vector_store_id)
try:
client.beta.vector_stores.delete(vector_store_id=vector_store_id)
except openai.NotFoundError:
pass

delete_openai_files_for_resource(client, assistant.team, resource)


def delete_openai_files_for_resource(client, team, resource: ToolResources):
files_to_delete = _get_files_to_delete(team, resource.id)
files_to_update = []
for file in files_to_delete:
if delete_file_from_openai(client, file):
files_to_update.append(file.id)

files_to_delete = _get_files_to_delete(assistant.team, resource.id)
for file in files_to_delete:
delete_file_from_openai(client, file)
if files_to_update:
File.objects.filter(id__in=files_to_update).update(external_id="", external_source="")


def _get_files_to_delete(team, tool_resource_id):
Expand All @@ -200,7 +213,8 @@ def _get_files_to_delete(team, tool_resource_id):
.values("file_id")
)

return File.objects.filter(toolresources=tool_resource_id, id__in=Subquery(files_with_single_reference)).iterator()
subquery = Subquery(files_with_single_reference)
return File.objects.filter(toolresources=tool_resource_id, id__in=subquery).iterator()


def is_tool_configured_remotely_but_missing_locally(assistant_data, local_tool_types, tool_name: str) -> bool:
Expand Down
23 changes: 22 additions & 1 deletion apps/assistants/tests/test_delete.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import uuid
from unittest.mock import Mock

import pytest

from apps.assistants.models import ToolResources
from apps.assistants.sync import _get_files_to_delete
from apps.assistants.sync import _get_files_to_delete, delete_openai_files_for_resource
from apps.utils.factories.assistants import OpenAiAssistantFactory
from apps.utils.factories.files import FileFactory

Expand All @@ -14,6 +17,10 @@ def assistant():
@pytest.fixture()
def code_resource(assistant):
files = FileFactory.create_batch(2, team=assistant.team)
for f in files:
f.external_id = str(uuid.uuid4())
f.external_source = "openai"
f.save()

tool_resource = ToolResources.objects.create(tool_type="code_interpreter", assistant=assistant)
tool_resource.files.set(files)
Expand All @@ -40,3 +47,17 @@ def test_files_not_to_delete_when_referenced_by_multiple_resources(code_resource

files_to_delete = list(_get_files_to_delete(tool_resource.assistant.team, tool_resource.id))
assert len(files_to_delete) == 0


@pytest.mark.django_db()
def test_delete_openai_files_for_resource(code_resource):
all_files = list(code_resource.files.all())
assert all(f.external_id for f in all_files)
assert all(f.external_source for f in all_files)
client = Mock()
delete_openai_files_for_resource(client, code_resource.assistant.team, code_resource)

assert client.files.delete.call_count == 2
all_files = list(code_resource.files.all())
assert not any(f.external_id for f in all_files)
assert not any(f.external_source for f in all_files)
3 changes: 2 additions & 1 deletion apps/assistants/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,6 @@ def get_success_response(self, file):
)

client = resource.assistant.llm_provider.get_llm_service().get_raw_client()
delete_file_from_openai(client, file)
if delete_file_from_openai(client, file):
file.save()
return HttpResponse()

0 comments on commit 32e4c79

Please sign in to comment.