From feb77d00d2353908408222ce05c7a11b93de1442 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:15:24 -0400 Subject: [PATCH 1/4] Support openai 1.14+ --- src/marvin/beta/assistants/assistants.py | 4 ++-- src/marvin/beta/assistants/formatting.py | 17 ++++++++++------ src/marvin/beta/assistants/threads.py | 25 ++++++++++++++---------- src/marvin/beta/chat_ui/chat_ui.py | 4 ++-- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 7cb0b1744..9debc7d8f 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -16,7 +16,7 @@ ) from marvin.utilities.logging import get_logger -from .threads import Thread, ThreadMessage +from .threads import Thread, Message if TYPE_CHECKING: from .runs import Run @@ -81,7 +81,7 @@ async def say_async( thread: Optional[Thread] = None, return_user_message: bool = False, **run_kwargs, - ) -> list[ThreadMessage]: + ) -> list[Message]: """ A convenience method for adding a user message to the assistant's default thread, running the assistant, and returning the assistant's diff --git a/src/marvin/beta/assistants/formatting.py b/src/marvin/beta/assistants/formatting.py index 99655835e..1e7235485 100644 --- a/src/marvin/beta/assistants/formatting.py +++ b/src/marvin/beta/assistants/formatting.py @@ -2,7 +2,12 @@ from datetime import datetime import openai -from openai.types.beta.threads import ThreadMessage +# for openai < 1.14.0 +try: + from openai.types.beta.threads import ThreadMessage as Message +# for openai >= 1.14.0 +except ImportError: + from openai.types.beta.threads import Message from openai.types.beta.threads.runs.run_step import RunStep from rich import box from rich.console import Console @@ -38,7 +43,7 @@ # for obj in combined: # if isinstance(obj, RunStep): # pprint_run_step(obj) -# elif isinstance(obj, ThreadMessage): +# elif isinstance(obj, Message): # pprint_message(obj) @@ -135,14 +140,14 @@ def download_temp_file(file_id: str, suffix: str = None): return temp_file_path -def pprint_message(message: ThreadMessage): +def pprint_message(message: Message): """ Pretty-prints a single message using the rich library, highlighting the speaker's role, the message text, any available images, and the message timestamp in a panel format. Args: - message (ThreadMessage): A message object + message (Message): A message object """ console = Console() role_colors = { @@ -192,7 +197,7 @@ def pprint_message(message: ThreadMessage): console.print(panel) -def pprint_messages(messages: list[ThreadMessage]): +def pprint_messages(messages: list[Message]): """ Iterates over a list of messages and pretty-prints each one. @@ -201,7 +206,7 @@ def pprint_messages(messages: list[ThreadMessage]): timestamp in a panel format. Args: - messages (list[ThreadMessage]): A list of ThreadMessage objects to be + messages (list[Message]): A list of Message objects to be printed. """ for message in messages: diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index 42704edf7..61b79bbeb 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -2,7 +2,12 @@ import time from typing import TYPE_CHECKING, Callable, Optional, Union -from openai.types.beta.threads import ThreadMessage +# for openai < 1.14.0 +try: + from openai.types.beta.threads import ThreadMessage as Message +# for openai >= 1.14.0 +except ImportError: + from openai.types.beta.threads import Message from pydantic import BaseModel, Field, PrivateAttr import marvin.utilities.openai @@ -34,7 +39,7 @@ class Thread(BaseModel, ExposeSyncMethodsMixin): id: Optional[str] = None metadata: dict = {} - messages: list[ThreadMessage] = Field([], repr=False) + messages: list[Message] = Field([], repr=False) def __enter__(self): return run_sync(self.__aenter__) @@ -67,7 +72,7 @@ async def create_async(self, messages: list[str] = None): @expose_sync_method("add") async def add_async( self, message: str, file_paths: Optional[list[str]] = None, role: str = "user" - ) -> ThreadMessage: + ) -> Message: """ Add a user message to the thread. """ @@ -87,7 +92,7 @@ async def add_async( response = await client.beta.threads.messages.create( thread_id=self.id, role=role, content=message, file_ids=file_ids ) - return ThreadMessage.model_validate(response.model_dump()) + return Message.model_validate(response.model_dump()) @expose_sync_method("get_messages") async def get_messages_async( @@ -96,7 +101,7 @@ async def get_messages_async( before_message: Optional[str] = None, after_message: Optional[str] = None, json_compatible: bool = False, - ) -> list[Union[ThreadMessage, dict]]: + ) -> list[Union[Message, dict]]: """ Asynchronously retrieves messages from the thread. @@ -107,12 +112,12 @@ async def get_messages_async( after_message (str, optional): The ID of the message to start the list from, retrieving messages sent after this one. json_compatible (bool, optional): If True, returns messages as dictionaries. - If False, returns messages as ThreadMessage + If False, returns messages as Message objects. Default is False. Returns: - list[Union[ThreadMessage, dict]]: A list of messages from the thread, either - as dictionaries or ThreadMessage objects, + list[Union[Message, dict]]: A list of messages from the thread, either + as dictionaries or Message objects, depending on the value of json_compatible. """ @@ -130,7 +135,7 @@ async def get_messages_async( order="desc", ) - T = dict if json_compatible else ThreadMessage + T = dict if json_compatible else Message return parse_as(list[T], reversed(response.model_dump()["data"])) @@ -238,7 +243,7 @@ async def run_async(self, interval_seconds: int = None): logger.error(f"Error refreshing thread: {exc}") await asyncio.sleep(interval_seconds) - async def get_latest_messages(self) -> list[ThreadMessage]: + async def get_latest_messages(self) -> list[Message]: limit = 20 # Loop to get all new messages in batches of 20 diff --git a/src/marvin/beta/chat_ui/chat_ui.py b/src/marvin/beta/chat_ui/chat_ui.py index e0b6b8adb..d12405d71 100644 --- a/src/marvin/beta/chat_ui/chat_ui.py +++ b/src/marvin/beta/chat_ui/chat_ui.py @@ -12,7 +12,7 @@ from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from marvin.beta.assistants.threads import Thread, ThreadMessage +from marvin.beta.assistants.threads import Thread, Message def find_free_port(): @@ -47,7 +47,7 @@ async def post_message( message_queue.put(dict(thread_id=thread_id, message=content)) @app.get("/api/messages/") - async def get_messages(thread_id: str) -> list[ThreadMessage]: + async def get_messages(thread_id: str) -> list[Message]: thread = Thread(id=thread_id) return await thread.get_messages_async(limit=100) From 1d1deacaee501f531469d018d944f2e8f93e725a Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:30:12 -0400 Subject: [PATCH 2/4] Further updates --- src/marvin/beta/assistants/threads.py | 2 +- src/marvin/types.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index 61b79bbeb..7ff09eaff 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -92,7 +92,7 @@ async def add_async( response = await client.beta.threads.messages.create( thread_id=self.id, role=role, content=message, file_ids=file_ids ) - return Message.model_validate(response.model_dump()) + return response @expose_sync_method("get_messages") async def get_messages_async( diff --git a/src/marvin/types.py b/src/marvin/types.py index f1e7277bb..800d9e6db 100644 --- a/src/marvin/types.py +++ b/src/marvin/types.py @@ -89,14 +89,14 @@ class ImageUrl(MarvinType): detail: str = "auto" -class MessageImageURLContent(MarvinType): +class ImageFileContentBlock(MarvinType): """Schema for messages containing images""" type: Literal["image_url"] = "image_url" image_url: ImageUrl -class MessageTextContent(MarvinType): +class TextContentBlock(MarvinType): """Schema for messages containing text""" type: Literal["text"] = "text" @@ -106,7 +106,7 @@ class MessageTextContent(MarvinType): class BaseMessage(MarvinType): """Base schema for messages""" - content: Union[str, list[Union[MessageImageURLContent, MessageTextContent]]] + content: Union[str, list[Union[ImageFileContentBlock, TextContentBlock]]] role: str @@ -305,15 +305,15 @@ def from_path(cls, path: Union[str, Path]) -> "Image": def from_url(cls, url: str) -> "Image": return cls(url=url) - def to_message_content(self) -> MessageImageURLContent: + def to_message_content(self) -> ImageFileContentBlock: if self.url: - return MessageImageURLContent( + return ImageFileContentBlock( image_url=dict(url=self.url, detail=self.detail) ) elif self.data: b64_image = base64.b64encode(self.data).decode("utf-8") path = f"data:image/{self.format};base64,{b64_image}" - return MessageImageURLContent(image_url=dict(url=path, detail=self.detail)) + return ImageFileContentBlock(image_url=dict(url=path, detail=self.detail)) else: raise ValueError("Image source is not specified") From 9615abf12f9e25583df2922564dc87a839075ddc Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:33:48 -0400 Subject: [PATCH 3/4] Allow disable beta import --- src/marvin/__init__.py | 7 +++++-- src/marvin/settings.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/marvin/__init__.py b/src/marvin/__init__.py index d2b0e941a..b71b47251 100644 --- a/src/marvin/__init__.py +++ b/src/marvin/__init__.py @@ -16,7 +16,9 @@ ) from .ai.images import paint, image from .ai.audio import speak_async, speak, speech, transcribe, transcribe_async -from . import beta + +if settings.auto_import_beta_modules: + from . import beta try: from ._version import version as __version__ @@ -48,9 +50,10 @@ "transcribe", "transcribe_async", # --- beta --- - "beta", ] +if settings.auto_import_beta_modules: + __all__.append("beta") # compatibility with Marvin v1 ai_fn = fn diff --git a/src/marvin/settings.py b/src/marvin/settings.py index e5982e3f2..6ea80eb54 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -253,6 +253,12 @@ class Settings(MarvinSettings): # ai settings ai: AISettings = Field(default_factory=AISettings) + # beta settings + auto_import_beta_modules: bool = Field( + True, + description="If True, the marvin.beta module will be automatically imported when marvin is imported.", + ) + # log settings log_level: str = Field( default="INFO", From e5dcd41aa7f5c291a75f6509ec172b697184d7e1 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:35:52 -0400 Subject: [PATCH 4/4] Linting --- src/marvin/beta/assistants/assistants.py | 2 +- src/marvin/beta/assistants/formatting.py | 1 + src/marvin/beta/assistants/threads.py | 2 +- src/marvin/beta/chat_ui/chat_ui.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py index 9debc7d8f..f0c562948 100644 --- a/src/marvin/beta/assistants/assistants.py +++ b/src/marvin/beta/assistants/assistants.py @@ -16,7 +16,7 @@ ) from marvin.utilities.logging import get_logger -from .threads import Thread, Message +from .threads import Message, Thread if TYPE_CHECKING: from .runs import Run diff --git a/src/marvin/beta/assistants/formatting.py b/src/marvin/beta/assistants/formatting.py index 1e7235485..785da8c61 100644 --- a/src/marvin/beta/assistants/formatting.py +++ b/src/marvin/beta/assistants/formatting.py @@ -2,6 +2,7 @@ from datetime import datetime import openai + # for openai < 1.14.0 try: from openai.types.beta.threads import ThreadMessage as Message diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index 7ff09eaff..14d071cd2 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -7,7 +7,7 @@ from openai.types.beta.threads import ThreadMessage as Message # for openai >= 1.14.0 except ImportError: - from openai.types.beta.threads import Message + from openai.types.beta.threads import Message from pydantic import BaseModel, Field, PrivateAttr import marvin.utilities.openai diff --git a/src/marvin/beta/chat_ui/chat_ui.py b/src/marvin/beta/chat_ui/chat_ui.py index d12405d71..a05ee9e3f 100644 --- a/src/marvin/beta/chat_ui/chat_ui.py +++ b/src/marvin/beta/chat_ui/chat_ui.py @@ -12,7 +12,7 @@ from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from marvin.beta.assistants.threads import Thread, Message +from marvin.beta.assistants.threads import Message, Thread def find_free_port():