diff --git a/apps/assistants/forms.py b/apps/assistants/forms.py index 57e35781d..4c2cde281 100644 --- a/apps/assistants/forms.py +++ b/apps/assistants/forms.py @@ -74,7 +74,7 @@ def clean(self): validate_prompt_variables( form_data=cleaned_data, prompt_key="instructions", - known_vars={"participant_data", "current_datetime"}, + known_vars=OpenAiAssistant.ALLOWED_INSTRUCTIONS_VARIABLES, ) return cleaned_data diff --git a/apps/assistants/models.py b/apps/assistants/models.py index f167ed66c..391b61bc1 100644 --- a/apps/assistants/models.py +++ b/apps/assistants/models.py @@ -31,6 +31,8 @@ class OpenAiAssistantManager(VersionsObjectManagerMixin, AuditingManager): audit_special_queryset_writes=True, ) class OpenAiAssistant(BaseTeamModel, VersionsMixin): + ALLOWED_INSTRUCTIONS_VARIABLES = {"participant_data", "current_datetime"} + assistant_id = models.CharField(max_length=255) name = models.CharField(max_length=255) instructions = models.TextField() diff --git a/apps/assistants/sync.py b/apps/assistants/sync.py index f7d2d9f51..978bc9c74 100644 --- a/apps/assistants/sync.py +++ b/apps/assistants/sync.py @@ -62,6 +62,7 @@ import openai from django.db.models import Count, Subquery +from django.forms import ValidationError from langchain_core.utils.function_calling import convert_to_openai_tool as lc_convert_to_openai_tool from openai import OpenAI from openai.types.beta import Assistant @@ -72,6 +73,7 @@ from apps.files.models import File from apps.service_providers.models import LlmProvider, LlmProviderModel, LlmProviderTypes from apps.teams.models import Team +from apps.utils.prompt import validate_prompt_variables logger = logging.getLogger("openai_sync") @@ -94,6 +96,8 @@ def _inner(*args, **kwargs): pass raise OpenAiSyncError(message) from e + except ValidationError as e: + raise OpenAiSyncError(str(e)) from e return _inner @@ -167,11 +171,20 @@ def import_openai_assistant(assistant_id: str, llm_provider: LlmProvider, team: client = llm_provider.get_llm_service().get_raw_client() openai_assistant = client.beta.assistants.retrieve(assistant_id) kwargs = _openai_assistant_to_ocs_kwargs(openai_assistant, team=team, llm_provider=llm_provider) + validate_instructions(kwargs["instructions"]) assistant = OpenAiAssistant.objects.create(**kwargs) _sync_tool_resources_from_openai(openai_assistant, assistant) return assistant +def validate_instructions(instructions: str): + validate_prompt_variables( + form_data={"instructions": instructions}, + prompt_key="instructions", + known_vars=OpenAiAssistant.ALLOWED_INSTRUCTIONS_VARIABLES, + ) + + @wrap_openai_errors def delete_openai_assistant(assistant: OpenAiAssistant): """Deletes the assistant from OpenAI and removes all associated files. diff --git a/apps/assistants/tests/test_sync.py b/apps/assistants/tests/test_sync.py index 62c7f1729..d40612933 100644 --- a/apps/assistants/tests/test_sync.py +++ b/apps/assistants/tests/test_sync.py @@ -1,4 +1,5 @@ import dataclasses +import re from io import BytesIO from unittest.mock import call, patch @@ -7,6 +8,7 @@ from apps.assistants.models import ToolResources from apps.assistants.sync import ( + OpenAiSyncError, _update_or_create_vector_store, delete_openai_assistant, get_out_of_sync_files, @@ -218,6 +220,18 @@ def test_import_openai_assistant(_, mock_file_retrieve, mock_vector_store_files, ] +@pytest.mark.django_db() +@patch("openai.resources.beta.Assistants.retrieve") +def test_import_openai_assistant_raises_for_invalid_instructions(mock_retrieve): + remote_assistant = AssistantFactory(instructions="This is a test with a {invalid_variable}") + mock_retrieve.return_value = remote_assistant + llm_provider = LlmProviderFactory() + + expected_error_msg = "{'instructions': ['Prompt contains unknown variables: invalid_variable']}" + with pytest.raises(OpenAiSyncError, match=re.escape(expected_error_msg)): + import_openai_assistant("123", llm_provider, llm_provider.team) + + @pytest.mark.django_db() @patch("openai.resources.beta.Assistants.delete") @patch("openai.resources.beta.vector_stores.VectorStores.delete")