From 16a6c3fac8becf8155876118465d7521676475d8 Mon Sep 17 00:00:00 2001 From: Andrew Robertson Date: Tue, 23 Jul 2024 15:46:03 +0100 Subject: [PATCH] allow passing of response_format to Assistant constructor, default to 'auto' (the open ai default) --- src/marvin/beta/assistants/assistants.py | 6 ++++-- src/marvin/types.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 367b3687f..343aebb33 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union from openai import AsyncAssistantEventHandler from prompt_toolkit import PromptSession @@ -13,7 +13,7 @@ import marvin.utilities.tools from marvin.beta.assistants.handlers import PrintHandler from marvin.tools.assistants import AssistantTool -from marvin.types import Tool +from marvin.types import AssistantResponseFormat, Tool from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, expose_sync_method, @@ -64,6 +64,7 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin): tools: list[Union[AssistantTool, Callable]] = [] tool_resources: dict[str, Any] = {} metadata: dict[str, str] = {} + response_format: Optional[Union[Literal["auto"], AssistantResponseFormat]] = "auto" # context level tracks nested assistant contexts _context_level: int = PrivateAttr(0) @@ -173,6 +174,7 @@ async def create_async(self, _auto_delete: bool = False): "metadata", "tool_resources", "metadata", + "response_format", } ), tools=[tool.model_dump() for tool in self.get_tools()], diff --git a/src/marvin/types.py b/src/marvin/types.py index 98d0616bf..f663ee81a 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -104,6 +104,18 @@ class FunctionCall(MarvinType): name: str +class AssistantResponseFormat(MarvinType): + type: Literal["json_object", "text"] + + +class JsonObjectAssistantResponseFormat(AssistantResponseFormat): + type: Literal["json_object"] = "json_object" + + +class TextAssistantResponseFormat(AssistantResponseFormat): + type: Literal["text"] = "text" + + class ImageUrl(MarvinType): url: str = Field( description="URL of the image to be sent or a base64 encoded image."