Skip to content

Commit

Permalink
Merge pull request #1000 from dimagi/cs/validate_imported_instructions
Browse files Browse the repository at this point in the history
Check imported assistant instructions for illegal variables
  • Loading branch information
SmittieC authored Dec 18, 2024
2 parents 9e601ac + 92dc5d4 commit f4a633e
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion apps/assistants/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def clean(self):
validate_prompt_variables(
form_data=cleaned_data,
prompt_key="instructions",
known_vars={"participant_data", "current_datetime"},
known_vars=OpenAiAssistant.ALLOWED_INSTRUCTIONS_VARIABLES,
)
return cleaned_data

Expand Down
2 changes: 2 additions & 0 deletions apps/assistants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class OpenAiAssistantManager(VersionsObjectManagerMixin, AuditingManager):
audit_special_queryset_writes=True,
)
class OpenAiAssistant(BaseTeamModel, VersionsMixin):
ALLOWED_INSTRUCTIONS_VARIABLES = {"participant_data", "current_datetime"}

assistant_id = models.CharField(max_length=255)
name = models.CharField(max_length=255)
instructions = models.TextField()
Expand Down
13 changes: 13 additions & 0 deletions apps/assistants/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@

import openai
from django.db.models import Count, Subquery
from django.forms import ValidationError
from langchain_core.utils.function_calling import convert_to_openai_tool as lc_convert_to_openai_tool
from openai import OpenAI
from openai.types.beta import Assistant
Expand All @@ -72,6 +73,7 @@
from apps.files.models import File
from apps.service_providers.models import LlmProvider, LlmProviderModel, LlmProviderTypes
from apps.teams.models import Team
from apps.utils.prompt import validate_prompt_variables

logger = logging.getLogger("openai_sync")

Expand All @@ -94,6 +96,8 @@ def _inner(*args, **kwargs):
pass

raise OpenAiSyncError(message) from e
except ValidationError as e:
raise OpenAiSyncError(str(e)) from e

return _inner

Expand Down Expand Up @@ -167,11 +171,20 @@ def import_openai_assistant(assistant_id: str, llm_provider: LlmProvider, team:
client = llm_provider.get_llm_service().get_raw_client()
openai_assistant = client.beta.assistants.retrieve(assistant_id)
kwargs = _openai_assistant_to_ocs_kwargs(openai_assistant, team=team, llm_provider=llm_provider)
validate_instructions(kwargs["instructions"])
assistant = OpenAiAssistant.objects.create(**kwargs)
_sync_tool_resources_from_openai(openai_assistant, assistant)
return assistant


def validate_instructions(instructions: str):
validate_prompt_variables(
form_data={"instructions": instructions},
prompt_key="instructions",
known_vars=OpenAiAssistant.ALLOWED_INSTRUCTIONS_VARIABLES,
)


@wrap_openai_errors
def delete_openai_assistant(assistant: OpenAiAssistant):
"""Deletes the assistant from OpenAI and removes all associated files.
Expand Down
14 changes: 14 additions & 0 deletions apps/assistants/tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dataclasses
import re
from io import BytesIO
from unittest.mock import call, patch

Expand All @@ -7,6 +8,7 @@

from apps.assistants.models import ToolResources
from apps.assistants.sync import (
OpenAiSyncError,
_update_or_create_vector_store,
delete_openai_assistant,
get_out_of_sync_files,
Expand Down Expand Up @@ -218,6 +220,18 @@ def test_import_openai_assistant(_, mock_file_retrieve, mock_vector_store_files,
]


@pytest.mark.django_db()
@patch("openai.resources.beta.Assistants.retrieve")
def test_import_openai_assistant_raises_for_invalid_instructions(mock_retrieve):
remote_assistant = AssistantFactory(instructions="This is a test with a {invalid_variable}")
mock_retrieve.return_value = remote_assistant
llm_provider = LlmProviderFactory()

expected_error_msg = "{'instructions': ['Prompt contains unknown variables: invalid_variable']}"
with pytest.raises(OpenAiSyncError, match=re.escape(expected_error_msg)):
import_openai_assistant("123", llm_provider, llm_provider.team)


@pytest.mark.django_db()
@patch("openai.resources.beta.Assistants.delete")
@patch("openai.resources.beta.vector_stores.VectorStores.delete")
Expand Down

0 comments on commit f4a633e

Please sign in to comment.