diff --git a/docs/assets/images/docs/assistants/code_interpreter.png b/docs/assets/images/docs/assistants/code_interpreter.png
index cbde5d934..57fd0750b 100644
Binary files a/docs/assets/images/docs/assistants/code_interpreter.png and b/docs/assets/images/docs/assistants/code_interpreter.png differ
diff --git a/docs/assets/images/docs/assistants/instructions.png b/docs/assets/images/docs/assistants/instructions.png
new file mode 100644
index 000000000..356c08c8a
Binary files /dev/null and b/docs/assets/images/docs/assistants/instructions.png differ
diff --git a/docs/assets/images/docs/assistants/quickstart.png b/docs/assets/images/docs/assistants/quickstart.png
index cac62dfdf..203d923b6 100644
Binary files a/docs/assets/images/docs/assistants/quickstart.png and b/docs/assets/images/docs/assistants/quickstart.png differ
diff --git a/docs/assets/images/docs/assistants/talking.png b/docs/assets/images/docs/assistants/talking.png
new file mode 100644
index 000000000..3966f4c6d
Binary files /dev/null and b/docs/assets/images/docs/assistants/talking.png differ
diff --git a/docs/docs/interactive/assistants.md b/docs/docs/interactive/assistants.md
index ff1b64e98..c926f9ee4 100644
--- a/docs/docs/interactive/assistants.md
+++ b/docs/docs/interactive/assistants.md
@@ -17,16 +17,13 @@ The need to manage all this state makes the assistants API very different from t
 Get started with the Assistants API by creating an `Assistant` and talking directly to it.
 
     ```python
-    from marvin.beta.assistants import Assistant, pprint_messages
+    from marvin.beta.assistants import Assistant
 
     # create an assistant
     ai = Assistant(name="Marvin", instructions="You are the Paranoid Android.")
 
     # send a message to the assistant and have it respond
-    response = ai.say('Hello, Marvin!')
-
-    # pretty-print the response
-    pprint_messages(response)
+    ai.say('Hello, Marvin!')
     ```
 
 !!! success "Result"
@@ -40,7 +37,7 @@
 
 !!! tip "Beta"
 
-    Please note that assistants support in Marvin is still in beta, as OpenAI has not finalized the assistants API yet. While it works as expected, it is subject to change.
+    Please note that assistants support in Marvin is still in beta, as OpenAI has not finalized the assistants API yet. Breaking changes may occur.
 
 
@@ -52,59 +49,166 @@ To learn more about the OpenAI assistants API, see the [OpenAI documentation](ht
 
 ### Creating an assistant
 
-To instantiate an assistant, use the `Assistant` class and provide a name and, optionally, details like instructions or tools:
+To create an assistant, use the `Assistant` class and provide an optional name and any additional details like instructions or tools:
+
+```python
+ai = Assistant(
+    name='Marvin',
+    # any specific instructions for how this assistant should behave
+    instructions="You are the Paranoid Android.",
+    # any tools or additional abilities the assistant should have
+    tools=[cry, sob]
+)
+```
+
+### Talking to an assistant
+
+The simplest way to talk to an assistant is to use its `say` method:
+
+!!! example "Talking to an assistant"
+
+    ```python
+    from marvin.beta.assistants import Assistant
+
+    ai = Assistant()
+
+    ai.say('Hi!')
+    ai.say('Bye!')
+    ```
+    !!! success "Result"
+        ![](/assets/images/docs/assistants/talking.png)
+
+You can repeatedly call `say` to have a conversation with the assistant. Each time you call `say`, the result is a `Run` object that contains information about what the assistant did.
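+
+For example, here is a minimal sketch that captures the returned run and inspects its `messages` and `steps` attributes (both described under pretty-printing below):
+
+```python
+run = ai.say('How are you?')
+
+# everything the assistant did during this exchange
+run.messages   # messages posted during the run
+run.steps      # steps (e.g. tool calls) the run performed
+```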
+You can use this object to inspect all actions the assistant took, including tool use, messages posted, and more.
+
+#### Chat history
+
+The OpenAI Assistants API automatically maintains a history of all messages and actions that the assistant has taken. This history is organized into threads, which are distinct conversations that the assistant has had. Each thread contains a series of messages, and each message is associated with a specific user or the assistant.
+
+When you talk to an assistant, you are implicitly talking on a specific thread. By default, the `say` method posts a single message to the assistant's `default_thread`, which is automatically created for your convenience whenever you instantiate an assistant. You can talk to the assistant on a different thread by providing it as the `thread` parameter:
+
+```python
+from marvin.beta.assistants import Assistant, Thread
+
+ai = Assistant()
+
+# load a thread from an existing ID (or pass id=None to start a new thread)
+thread = Thread(id=thread_id)
+
+# post a message to the thread
+ai.say('hi', thread=thread)
+```
+
+Using `say` is convenient, but enforces a strict request/response pattern: the user posts a single message to the thread, then the AI responds. Note that AI responses can include multiple messages or tool calls.
+
+For more control over the conversation, including posting multiple user messages to the thread before the assistant responds, use thread objects directly instead of calling `say` (see [Threads](#threads) for more information).
+
+
+#### Event handlers
+
+Marvin uses the OpenAI streaming API to provide real-time updates on the assistant's actions. To customize how these updates are handled, you can provide a custom event handler class to the `event_handler_class` parameter of `Assistant.say`, `Thread.run`, or `Run.run`. This class must inherit from `openai.AssistantEventHandler` or `openai.AsyncAssistantEventHandler`. For more control, you can also provide `event_handler_kwargs` that will be passed to the event handler when it is instantiated.
+
+#### Pretty-printing
+
+By default, Marvin streams all of the messages and actions that the assistant takes and prints them to your terminal. In production or headless environments, you may want to suppress this output.
+
+The simplest way to do this is to pass `event_handler_class=None` to the `say` method. This will prevent any messages from being printed to the terminal. You can still access the messages and actions from the run object that is returned.
+
+```python
+ai = Assistant()
+
+# run the assistant without printing any messages
+run = ai.say("Hello!", event_handler_class=None)
+
+# access the messages
+run.messages
+
+# access the assistant actions
+run.steps
+```
+
+For finer control, you can pass `event_handler_kwargs=dict(print_messages=False)` or `event_handler_kwargs=dict(print_steps=False)` to the `say` method. This will allow you to suppress only the messages or only the assistant's actions, respectively.
+
+```python
+# print only messages
+run = ai.say("Hello!", event_handler_kwargs=dict(print_steps=False))
+
+# print only actions
+run = ai.say("Hello!", event_handler_kwargs=dict(print_messages=False))
+```
+
+Note that pretty-printing is only the default behavior when using the assistant's convenient `say` method. If you use lower-level APIs like a thread's `run` method or invoke a run object directly, printing is not automatically enabled.
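+
+For example, this sketch (using the `Thread` API described under [Threads](#threads)) produces no terminal output:
+
+```python
+from marvin.beta.assistants import Assistant, Thread
+
+ai = Assistant()
+thread = Thread()
+
+# post a message and run the assistant; nothing is printed
+thread.add('Hello!')
+run = thread.run(ai)
+```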
+You can re-enable it for those objects by setting `event_handler_class=marvin.beta.assistants.PrintHandler`.
 
 ```python
-ai = Assistant(name='Marvin', instructions=..., tools=[...])
+from marvin.beta.assistants import Thread, Assistant, PrintHandler
+
+ai = Assistant()
+thread = Thread()
+run = thread.run(ai, event_handler_class=PrintHandler)
 ```
 
+Lastly, you can print messages and actions manually using the `pprint_run`, `pprint_messages`, and `pprint_steps` functions from the `marvin.beta.assistants.formatting` module. These functions are used internally by the default event handler, and they provide a human-readable representation of the run, its messages, and its steps, respectively.
+
+```python
+from marvin.beta.assistants import Assistant, pprint_run
+
+ai = Assistant()
+run = ai.say("Hello!", event_handler_class=None)
+pprint_run(run)
+```
 
 ### Instructions
 
-Each assistant can be given `instructions` that describe its purpose, personality, or other details. The instructions are a natural language string and one of the only ways to globally steer the assistant's behavior.
+Each assistant can be given `instructions` that describe its purpose, personality, or other details. Instructions are provided as natural language and allow you to globally steer the assistant's behavior, similar to a system message for a chat completion. They can be lengthy explanations of how to handle complex workflows, or they can be brief instructions on how to act.
 
-Instructions can be lengthy explanations of how to handle complex workflows, or they can be short descriptions of the assistant's personality. For example, the instructions for the `Marvin` assistant above are "You are Marvin, the Paranoid Android." This will marginally affect the way the assistant responds to messages.
+!!! example "Using instructions to control behavior"
+
+    ```python
+    from marvin.beta.assistants import Assistant
+
+    ai = Assistant(instructions="Mention the word 'banana' as often as possible")
+    ai.say("Hello!")
+    ```
+    !!! success "Result"
+        ![](/assets/images/docs/assistants/instructions.png)
 
 ### Tools
 
 Each assistant can be given a list of `tools` that it can use when responding to a message. Tools are a way to extend the assistant's capabilities beyond its default behavior, including giving it access to external systems like the internet, a database, your computer, or any API.
 
-#### OpenAI tools
-
-OpenAI provides a small number of built-in tools for assistants. The most useful is the "code interpreter", which lets the assistant write and execute Python code. To use the code interpreter, add it to your assistant's list of tools.
+#### Code interpreter
 
+The code interpreter tool is a built-in tool provided by OpenAI that lets the assistant write and execute Python code. To use the code interpreter, add it to your assistant's list of tools.
 
-    This assistant uses the code interpreter to generate a plot of sin(x). Note that Marvin's utility for pretty-printing messages to the terminal can't show the plot inline, but will download it and show a link to the file instead.
+!!! example "Using the code interpreter"
 
example "Using assistants with the code interpreter" ```python - from marvin.beta import Assistant - from marvin.beta.assistants import pprint_messages, CodeInterpreter + from marvin.beta.assistants import Assistant, CodeInterpreter - ai = Assistant(name='Marvin', tools=[CodeInterpreter]) - response = ai.say("Generate a plot of sin(x)") - - # pretty-print the response - pprint_messages(response) + ai = Assistant(tools=[CodeInterpreter]) + ai.say("Generate a plot of sin(x)") ``` !!! success "Result" + Since images can't be rendered in the terminal, Marvin will automatically download them and provide links to view the output. + ![](/assets/images/docs/assistants/code_interpreter.png) + + Here is the image: + ![](/assets/images/docs/assistants/sin_x.png) #### Custom tools -A major advantage of using Marvin's assistants API is that you can add your own custom tools. To do so, simply pass one or more functions to the assistant's `tools` argument. For best performance, give your tool function a descriptive name, docstring, and type hint for every argument. +Marvin makes it easy to give your assistants custom tools. To do so, pass one or more Python functions to the assistant's `tools` argument. For best performance, give your tool function a descriptive name, docstring, and type hint for every argument. Note that you can provide custom tools and the code interpreter at the same time. !!! example "Using custom tools" - Assistants can not browse the web by default. We can add this capability by giving them a tool that takes a URL and returns the HTML of that page. This assistant uses that tool to count how many titles on Hacker News mention AI: + Assistants don't have web access by default. We can add this capability by giving them a tool that takes a URL and returns the HTML of that page. This assistant uses that tool to count how many titles on Hacker News mention AI: ```python - from marvin.beta.assistants import Assistant, pprint_messages + from marvin.beta.assistants import Assistant import requests @@ -115,42 +219,12 @@ A major advantage of using Marvin's assistants API is that you can add your own # Integrate custom tools with the assistant - ai = Assistant(name="Marvin", tools=[visit_url]) - response = ai.say("What's the top story on Hacker News?") - - # pretty-print the response - pprint_messages(response) + ai = Assistant(tools=[visit_url]) + ai.say("What's the top story on Hacker News?") ``` !!! success "Result" ![](/assets/images/docs/assistants/custom_tools.png) -### Talking to an assistant - -The simplest way to talk to an assistant is to use its `say` method: - -```python -ai = Assistant(name='Marvin') - -response = ai.say('hi') - -pprint_messages(response) -``` - -By default, the `say` method posts a single message to the assistant's `default_thread`, a thread that is automatically created for your convenience. You can supply a different thread by providing it as the `thread` parameter: - -```python -# create a thread from an existing ID (or pass None for a new thread) -thread = Thread(id=thread_id) - -# post a message to the thread -ai.say('hi', thread=thread) -``` - -Using `say` is convenient, but enforces a strict request/response pattern: the user posts a single message to the thread, then the AI responds. Note that AI responses can span multiple messages. Therefore, the `say` method returns a list of `Message` objects. 
-
-For more control over the conversation, including posting multiple user messages to the thread or accessing the lower-level `Run` object that contains information about all actions the assistant took, use `Thread` objects directly instead of calling `say` (see [Threads](#threads) for more information).
-
-
 ### Lifecycle management
 
 Assistants are Marvin objects that correspond to remote objects in the OpenAI API. You can not communicate with an assistant unless it has been registered with the API.
 
@@ -169,7 +243,7 @@ All of these options are *functionally* equivalent e.g. they produce identical r
 
 The simplest way to manage assistant lifecycles is to let Marvin handle it for you. If you do not provide an `id` when instantiating an assistant, Marvin will lazily create a new API assistant for you whenever you need it and delete it immediately after. This is the default behavior, and it is the easiest way to get started with assistants.
 
 ```python
-ai = Assistant(name='Marvin')
+ai = Assistant()
 # creation and deletion happens automatically
 ai.say('hello!')
 ```
@@ -179,7 +253,7 @@
 
 Lazy lifecycle management adds two API calls to every LLM call (one to create the assistant and one to delete it). If you want to avoid this overhead, you can use context managers to create and delete assistants:
 
 ```python
-ai = Assistant(name='Marvin')
+ai = Assistant()
 
 # creation / deletion happens when the context is opened / closed
 with ai:
     ai.say('hi')
@@ -194,7 +268,7 @@ Note there is also an equivalent `async with` context manager for the async API.
 
 To fully control the lifecycle of an assistant, you can create and delete it manually:
 
 ```python
-ai = Assistant(name='Marvin')
+ai = Assistant()
 ai.create()
 ai.say('hi')
 ai.delete()
@@ -234,7 +308,7 @@ ai = Assistant.load(id=)
 
 Every `Assistant` method has a corresponding async version. To use the async API, append `_async` to the method name, or enter an async context manager:
 
 ```python
-async with Assistant(name='Marvin') as ai:
+async with Assistant() as ai:
     await ai.say_async('hi')
 ```
 
@@ -345,7 +419,7 @@ Messages are not strings, but structured message objects. Marvin has a few utili
     def roll_dice(n_dice: int) -> list[int]:
         return [random.randint(1, 6) for _ in range(n_dice)]
 
-    ai = Assistant(name="Marvin", tools=[roll_dice])
+    ai = Assistant(tools=[roll_dice])
 
     # create a thread - you could pass an ID to resume a conversation
     thread = Thread()
@@ -363,8 +437,9 @@
     # run the thread again to generate a new response
     thread.run(ai)
 
-    # see all the messages
-    pprint_messages(thread.get_messages())
+    # see all the messages in the thread
+    messages = thread.get_messages()
+    pprint_messages(messages)
     ```
 
 !!! success "Result"
@@ -373,21 +448,3 @@
 ### Async support
 
 Every `Thread` method has a corresponding async version. To use the async API, append `_async` to the method name.
-
-## Monitors
-
-The assistants API is complex and stateful, with automatic memory management and the potential for assistants to respond to threads multiple times before giving control back to users. Therefore, monitoring the status of a conversation is considerably more difficult than with other LLM API's such as chat completions, which have much more simple request-response patterns.
-
-Marvin has utilites for monitoring the status of a thread and taking action whenever a new message is added to it. This can be a useful way to debug activity or create notifications. Please note that monitors are not intended to be used for real-time chat applications or production use.
-
-```python
-from marvin.beta.assistants import ThreadMonitor
-
-monitor = ThreadMonitor(thread_id=thread.id)
-
-monitor.run()
-```
-
-You can customize the `ThreadMonitor` by providing a callback function to the `on_new_message` parameter. This function will be called whenever a new message is added to the thread. The function will be passed the new message as a parameter. By default, the monitor will pretty-print every new message to the console.
-
-`monitor.run()` is a blocking call that will run forever, polling for messages every second (to customize the interval, pass `interval_seconds` to the method). It has an async equivalent `monitor.run_async()`. Because it's blocking, you can run a thread monitor in a separate session from the one that is running the thread itself.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 67c23d33b..dbeb9b9a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
     "httpx>=0.24.1",
     "jinja2>=3.1.2",
     "jsonpatch>=1.33",
-    "openai>=1.1.0",
+    "openai>=1.14.0",
     "pydantic>=2.4.2",
     "pydantic_settings",
     "rich>=12",
@@ -26,6 +26,7 @@ dependencies = [
     # need for windows
     "tzdata>=2023.3",
     "uvicorn>=0.22.0",
+    "partialjson>=0.0.5",
 ]
 
 [project.optional-dependencies]
diff --git a/src/marvin/beta/assistants/__init__.py b/src/marvin/beta/assistants/__init__.py
index 2cbfc76dd..7f2791983 100644
--- a/src/marvin/beta/assistants/__init__.py
+++ b/src/marvin/beta/assistants/__init__.py
@@ -1,5 +1,6 @@
 from .runs import Run
-from .threads import Thread, ThreadMonitor
+from .threads import Thread
 from .assistants import Assistant
-from .formatting import pprint_message, pprint_messages
+from .handlers import PrintHandler
+from .formatting import pprint_messages, pprint_steps, pprint_run
 from marvin.tools.assistants import Retrieval, CodeInterpreter
diff --git a/src/marvin/beta/assistants/assistants.py b/src/marvin/beta/assistants/assistants.py
index f0c562948..fc0b6931e 100644
--- a/src/marvin/beta/assistants/assistants.py
+++ b/src/marvin/beta/assistants/assistants.py
@@ -1,12 +1,11 @@
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
-from openai.types.beta.threads.required_action_function_tool_call import (
-    RequiredActionFunctionToolCall,
-)
-from pydantic import BaseModel, Field, PrivateAttr
+from openai import AssistantEventHandler, AsyncAssistantEventHandler
+from pydantic import BaseModel, Field, PrivateAttr, field_validator
 
 import marvin.utilities.openai
 import marvin.utilities.tools
+from marvin.beta.assistants.handlers import PrintHandler
 from marvin.tools.assistants import AssistantTool
 from marvin.types import Tool
 from marvin.utilities.asyncio import (
@@ -16,7 +15,7 @@ )
 from marvin.utilities.logging import get_logger
 
-from .threads import Message, Thread
+from .threads import Thread
 
 if TYPE_CHECKING:
     from .runs import Run
@@ -24,6 +23,8 @@
 
 logger = get_logger("Assistants")
 
+NOT_PROVIDED = "__NOT_PROVIDED__"
+
 
 class Assistant(BaseModel, ExposeSyncMethodsMixin):
     """
@@ -41,9 +42,10 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin):
         instructions (list): List of instructions for the assistant.
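+        model (str): The model used by the assistant; if not provided, it
+            falls back to `marvin.settings.openai.assistants.model`.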
""" + model_config = dict(extra="forbid") id: Optional[str] = None name: str = "Assistant" - model: str = "gpt-4-1106-preview" + model: str = Field(None, validate_default=True) instructions: Optional[str] = Field(None, repr=False) tools: list[Union[AssistantTool, Callable]] = [] file_ids: list[str] = [] @@ -57,6 +59,12 @@ class Assistant(BaseModel, ExposeSyncMethodsMixin): description="A default thread for the assistant.", ) + @field_validator("model", mode="before") + def default_model(cls, model): + if model is None: + model = marvin.settings.openai.assistants.model + return model + def clear_default_thread(self): self.default_thread = Thread() @@ -79,31 +87,32 @@ async def say_async( message: str, file_paths: Optional[list[str]] = None, thread: Optional[Thread] = None, - return_user_message: bool = False, + event_handler_class: type[ + Union[AssistantEventHandler, AsyncAssistantEventHandler] + ] = NOT_PROVIDED, **run_kwargs, - ) -> list[Message]: - """ - A convenience method for adding a user message to the assistant's - default thread, running the assistant, and returning the assistant's - messages. - """ + ) -> "Run": thread = thread or self.default_thread + if event_handler_class is NOT_PROVIDED: + event_handler_class = PrintHandler + # post the message user_message = await thread.add_async(message, file_paths=file_paths) - # run the thread - async with self: - await thread.run_async(assistant=self, **run_kwargs) + from marvin.beta.assistants.runs import Run - # load all messages, including the user message - response_messages = await thread.get_messages_async( - after_message=user_message.id + run = Run( + # provide the user message as part of the run to print + messages=[user_message], + assistant=self, + thread=thread, + event_handler_class=event_handler_class, + **run_kwargs, ) + result = await run.run_async() - if return_user_message: - response_messages = [user_message] + response_messages - return response_messages + return result def __enter__(self): return run_sync(self.__aenter__()) @@ -176,13 +185,8 @@ def chat(self, thread: Thread = None): thread = self.default_thread return thread.chat(assistant=self) - def pre_run_hook(self, run: "Run"): + def pre_run_hook(self): pass - def post_run_hook( - self, - run: "Run", - tool_calls: Optional[list[RequiredActionFunctionToolCall]] = None, - tool_outputs: Optional[list[dict[str, str]]] = None, - ): + def post_run_hook(self, run: "Run"): pass diff --git a/src/marvin/beta/assistants/formatting.py b/src/marvin/beta/assistants/formatting.py index 785da8c61..6507b9dc1 100644 --- a/src/marvin/beta/assistants/formatting.py +++ b/src/marvin/beta/assistants/formatting.py @@ -1,119 +1,197 @@ +import functools +import inspect +import json import tempfile from datetime import datetime -import openai - -# for openai < 1.14.0 -try: - from openai.types.beta.threads import ThreadMessage as Message -# for openai >= 1.14.0 -except ImportError: - from openai.types.beta.threads import Message +from openai.types.beta.threads import Message from openai.types.beta.threads.runs.run_step import RunStep +from partialjson import JSONParser from rich import box -from rich.console import Console +from rich.console import Console, Group +from rich.markdown import Markdown from rich.panel import Panel -# def pprint_run(run: Run): -# """ -# Runs are comprised of steps and messages, which are each in a sorted list -# BUT the created_at timestamps only have second-level resolution, so we can't -# easily sort the lists. 
-#     easily sort the lists. Instead we walk them in order and combine them giving
-#     ties to run steps.
-#     """
-#     index_steps = 0
-#     index_messages = 0
-#     combined = []
-
-#     while index_steps < len(run.steps) and index_messages < len(run.messages):
-#         if (run.steps[index_steps].created_at
-#             <= run.messages[index_messages].created_at):
-#             combined.append(run.steps[index_steps])
-#             index_steps += 1
-#         elif (
-#             run.steps[index_steps].created_at
-#             > run.messages[index_messages].created_at
-#         ):
-#             combined.append(run.messages[index_messages])
-#             index_messages += 1
-
-#     # Add any remaining items from either list
-#     combined.extend(run.steps[index_steps:])
-#     combined.extend(run.messages[index_messages:])
-
-#     for obj in combined:
-#         if isinstance(obj, RunStep):
-#             pprint_run_step(obj)
-#         elif isinstance(obj, Message):
-#             pprint_message(obj)
-
-
-def pprint_run_step(run_step: RunStep):
+from marvin.utilities.openai import get_openai_client
+
+json_parser = JSONParser()
+
+
+def format_step(step: RunStep) -> list[Panel]:
+    @functools.lru_cache(maxsize=1000)
+    def _cached_format_step(_step):
+        """
+        Closure that allows for caching of the formatted step. "_step" is a
+        hashable identifier for the cache; the actual function reads the full
+        "step" from the parent scope.
+        """
+        # Timestamp formatting
+        timestamp = datetime.fromtimestamp(step.created_at).strftime("%l:%M:%S %p")
+
+        # default content
+        content = (
+            f"Assistant is performing an action: {step.type} - Status:"
+            f" {step.status}"
+        )
+
+        panels = []
+
+        # attempt to customize content
+        if step.type == "tool_calls":
+            for tool_call in step.step_details.tool_calls:
+                if tool_call.type == "code_interpreter":
+                    panel_title = "Code Interpreter"
+                    footer = []
+                    for output in tool_call.code_interpreter.outputs:
+                        if output.type == "logs":
+                            content = inspect.cleandoc(
+                                """
+                                The code interpreter produced this result:
+
+                                ```python
+                                {result}
+                                ```
+
+                                {note}
+                                """
+                            )
+
+                            if len(output.logs) > 500:
+                                result = output.logs[:500] + " ..."
+                                note = "*(First 500 characters shown)*"
+                            else:
+                                result = output.logs
+                                note = ""
+                            footer.append(content.format(result=result, note=note))
+                        elif output.type == "image":
+                            # Use the download_temp_file function to download the file and get
+                            # the local path
+                            local_file_path = download_temp_file(
+                                output.image.file_id, suffix=".png"
+                            )
+                            footer.append(
+                                f"The code interpreter produced this image: [{local_file_path}]({local_file_path})"
+                            )
+
+                    content = inspect.cleandoc(
+                        """
+                        Running the code interpreter...
+
+                        ```python
+                        {input}
+                        ```
+
+                        {footer}
+                        """
+                    ).format(
+                        input=tool_call.code_interpreter.input, footer="\n".join(footer)
+                    )
+                elif tool_call.type == "function":
+                    panel_title = "Tool Call"
+                    if step.status == "in_progress":
+                        if tool_call.function.arguments:
+                            try:
+                                args = json.loads(tool_call.function.arguments)
+                            except json.JSONDecodeError:
+                                try:
+                                    args = json_parser.parse(
+                                        tool_call.function.arguments
+                                    )
+                                except Exception:
+                                    args = tool_call.function.arguments
+
+                            content = inspect.cleandoc(
+                                """
+                                Using the `{function}` tool with these arguments:
+
+                                ```python
+                                {args}
+                                ```
+                                """
+                            ).format(function=tool_call.function.name, args=args)
+
+                        else:
+                            content = f"Assistant wants to use the `{tool_call.function.name}` tool."
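+
+                    # once the tool call completes, format its (truncated) output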
+
+                    elif step.status == "completed":
+                        if tool_call.function.output:
+                            content = inspect.cleandoc(
+                                """
+                                The `{tool_name}` tool produced this result:
+
+                                ```python
+                                {result}
+                                ```
+
+                                {note}
+                                """
+                            )
+                            if len(tool_call.function.output) > 500:
+                                result = tool_call.function.output[:500] + " ..."
+                                note = "*(First 500 characters shown)*"
+                            else:
+                                result = tool_call.function.output
+                                note = ""
+                            content = content.format(
+                                tool_name=tool_call.function.name,
+                                result=result,
+                                note=note,
+                            )
+                        else:
+                            content = f"The `{tool_call.function.name}` tool has completed with no result."
+
+            # Create the panel for the run step status
+            panels.append(
+                Panel(
+                    Markdown(inspect.cleandoc(content)),
+                    title=panel_title,
+                    subtitle=f"[italic]{timestamp}[/]",
+                    title_align="left",
+                    subtitle_align="right",
+                    border_style="gray74",
+                    box=box.ROUNDED,
+                    width=100,
+                    expand=True,
+                    padding=(1, 2),
+                )
+            )
+
+        elif step.type == "message_creation":
+            pass
+
+        return Group(*panels)
+
+    return _cached_format_step(step.model_dump_json())
+
+
+def pprint_step(step: RunStep):
     """
     Formats and prints a run step with status information.
 
     Args:
-        run_step: A RunStep object containing the details of the run step.
+        step: A RunStep object containing the details of the run step.
     """
-    # Timestamp formatting
-    timestamp = datetime.fromtimestamp(run_step.created_at).strftime("%l:%M:%S %p")
+    panel = format_step(step)
 
-    # default content
-    content = (
-        f"Assistant is performing an action: {run_step.type} - Status:"
-        f" {run_step.status}"
-    )
-
-    # attempt to customize content
-    if run_step.type == "tool_calls":
-        for tool_call in run_step.step_details.tool_calls:
-            if tool_call.type == "code_interpreter":
-                if run_step.status == "in_progress":
-                    content = "Assistant is running the code interpreter..."
-                elif run_step.status == "completed":
-                    content = "Assistant ran the code interpreter."
-                else:
-                    content = f"Assistant code interpreter status: {run_step.status}"
-            elif tool_call.type == "function":
-                if run_step.status == "in_progress":
-                    content = (
-                        "Assistant used the tool"
-                        f" `{tool_call.function.name}` with arguments"
-                        f" {tool_call.function.arguments}..."
-                    )
-                elif run_step.status == "completed":
-                    content = (
-                        "Assistant used the tool"
-                        f" `{tool_call.function.name}` with arguments"
-                        f" {tool_call.function.arguments}."
-                    )
-                else:
-                    content = (
-                        f"Assistant tool `{tool_call.function.name}` status:"
-                        f" `{run_step.status}`"
-                    )
-    elif run_step.type == "message_creation":
+    if not panel:
         return
 
     console = Console()
-
-    # Create the panel for the run step status
-    panel = Panel(
-        content.strip(),
-        title="Assistant Run Step",
-        subtitle=f"[italic]{timestamp}[/]",
-        title_align="left",
-        subtitle_align="right",
-        border_style="gray74",
-        box=box.ROUNDED,
-        width=100,
-        expand=True,
-        padding=(0, 1),
-    )
-    # Printing the panel
     console.print(panel)
 
 
+def pprint_steps(steps: list[RunStep]):
+    """
+    Iterates over a list of run steps and pretty-prints each one.
+
+    Args:
+        steps (list[RunStep]): A list of RunStep objects to be printed.
+    """
+    for step in sorted(steps, key=lambda s: s.created_at):
+        pprint_step(step)
+
+
+@functools.lru_cache(maxsize=1000)
 def download_temp_file(file_id: str, suffix: str = None):
     """
     Downloads a file from OpenAI's servers and saves it to a temporary file.
@@ -126,31 +204,18 @@
 
     Returns:
         The file path of the downloaded temporary file.
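+
+    Note that results are cached via `functools.lru_cache`, so repeated calls
+    with the same file ID return the same local path.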
""" - client = openai.Client() - # file_info = client.files.retrieve(file_id) - file_content_response = client.files.with_raw_response.retrieve_content(file_id) + client = get_openai_client(is_async=False) + response = client.files.content(file_id) # Create a temporary file with a context manager to ensure it's cleaned up # properly - with tempfile.NamedTemporaryFile( - delete=False, mode="wb", suffix=f"{suffix}" - ) as temp_file: - temp_file.write(file_content_response.content) - temp_file_path = temp_file.name # Save the path of the temp file + temp_file = tempfile.NamedTemporaryFile(delete=False, mode="wb", suffix=suffix) + temp_file.write(response.content) - return temp_file_path + return temp_file.name -def pprint_message(message: Message): - """ - Pretty-prints a single message using the rich library, highlighting the - speaker's role, the message text, any available images, and the message - timestamp in a panel format. - - Args: - message (Message): A message object - """ - console = Console() +def format_message(message: Message) -> Panel: role_colors = { "user": "green", "assistant": "blue", @@ -161,27 +226,24 @@ def pprint_message(message: Message): datetime.fromtimestamp(message.created_at).strftime("%I:%M:%S %p").lstrip("0") ) - content = "" + content = [] for item in message.content: if item.type == "text": - content += item.text.value + "\n\n" + content.append(item.text.value + "\n\n") elif item.type == "image_file": # Use the download_temp_file function to download the file and get # the local path local_file_path = download_temp_file(item.image_file.file_id, suffix=".png") - # Add a clickable hyperlink to the content - file_url = f"file://{local_file_path}" - content += ( - "[bold]Attachment[/bold]:" - f" [blue][link={file_url}]{local_file_path}[/link][/blue]\n\n" + content.append( + f"*View attached image: [{local_file_path}]({local_file_path})*" ) for file_id in message.file_ids: - content += f"Attached file: {file_id}\n" + content.append(f"Attached file: {file_id}\n") # Create the panel for the message panel = Panel( - content.strip(), + Markdown(inspect.cleandoc("\n\n".join(content))), title=f"[bold]{message.role.capitalize()}[/]", subtitle=f"[italic]{timestamp}[/]", title_align="left", @@ -193,8 +255,20 @@ def pprint_message(message: Message): expand=True, # Panels always expand to the width of the console padding=(1, 2), ) + return panel + - # Printing the panel +def pprint_message(message: Message): + """ + Pretty-prints a single message using the rich library, highlighting the + speaker's role, the message text, any available images, and the message + timestamp in a panel format. + + Args: + message (Message): A message object + """ + console = Console() + panel = format_message(message) console.print(panel) @@ -210,5 +284,40 @@ def pprint_messages(messages: list[Message]): messages (list[Message]): A list of Message objects to be printed. """ - for message in messages: + for message in sorted(messages, key=lambda m: m.created_at): pprint_message(message) + + +def format_run( + run, include_messages: bool = True, include_steps: bool = True +) -> list[Panel]: + """ + Formats a run, which is an object that has both `.messages` and `.steps` + attributes, each of which is a list of Messages and RunSteps. 
+
+    Args:
+        run: A Run object
+        include_messages: Whether to include messages in the formatted output
+        include_steps: Whether to include steps in the formatted output
+    """
+
+    objects = []
+    if include_messages:
+        objects.extend([(format_message(m), m.created_at) for m in run.messages])
+    if include_steps:
+        objects.extend([(format_step(s), s.created_at) for s in run.steps])
+    sorted_objects = sorted(objects, key=lambda x: x[1])
+    return [x[0] for x in sorted_objects if x[0] is not None]
+
+
+def pprint_run(run):
+    """
+    Pretty-prints a run, which is an object that has both `.messages` and
+    `.steps` attributes, each of which is a list of Messages and RunSteps.
+
+    Args:
+        run: A Run object
+    """
+    console = Console()
+    panels = format_run(run)
+    console.print(Group(*panels))
diff --git a/src/marvin/beta/assistants/handlers.py b/src/marvin/beta/assistants/handlers.py
new file mode 100644
index 000000000..39b8a817e
--- /dev/null
+++ b/src/marvin/beta/assistants/handlers.py
@@ -0,0 +1,59 @@
+from openai import AsyncAssistantEventHandler
+from openai.types.beta.threads import Message, MessageDelta
+from openai.types.beta.threads.runs import RunStep, RunStepDelta
+from rich.console import Group
+from rich.live import Live
+from typing_extensions import override
+
+from marvin.beta.assistants.formatting import format_run
+
+
+class PrintHandler(AsyncAssistantEventHandler):
+    def __init__(self, print_messages: bool = True, print_steps: bool = True):
+        self.print_messages = print_messages
+        self.print_steps = print_steps
+        self.live = Live(refresh_per_second=12)
+        self.live.start()
+        self.messages = {}
+        self.steps = {}
+        super().__init__()
+
+    def print_run(self):
+        class Run:
+            messages = self.messages.values()
+            steps = self.steps.values()
+
+        panels = format_run(
+            Run,
+            include_messages=self.print_messages,
+            include_steps=self.print_steps,
+        )
+        self.live.update(Group(*panels))
+
+    @override
+    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
+        self.messages[snapshot.id] = snapshot
+        self.print_run()
+
+    @override
+    async def on_message_done(self, message: Message) -> None:
+        self.messages[message.id] = message
+        self.print_run()
+
+    @override
+    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
+        self.steps[snapshot.id] = snapshot
+        self.print_run()
+
+    @override
+    async def on_run_step_done(self, run_step: RunStep) -> None:
+        self.steps[run_step.id] = run_step
+        self.print_run()
+
+    @override
+    async def on_exception(self, exc):
+        self.live.stop()
+
+    @override
+    async def on_end(self):
+        self.live.stop()
diff --git a/src/marvin/beta/assistants/runs.py b/src/marvin/beta/assistants/runs.py
index cbf946ca1..ee4ff3bf9 100644
--- a/src/marvin/beta/assistants/runs.py
+++ b/src/marvin/beta/assistants/runs.py
@@ -1,12 +1,11 @@
-import asyncio
+import inspect
 from typing import Any, Callable, Optional, Union
 
-from openai.types.beta.threads.required_action_function_tool_call import (
-    RequiredActionFunctionToolCall,
-)
+from openai import AssistantEventHandler, AsyncAssistantEventHandler
+from openai.types.beta.threads import Message
 from openai.types.beta.threads.run import Run as OpenAIRun
 from openai.types.beta.threads.runs import RunStep as OpenAIRunStep
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, PrivateAttr, field_validator
 
 import marvin.utilities.openai
 import marvin.utilities.tools
@@ -39,9 +38,19 @@ class Run(BaseModel, ExposeSyncMethodsMixin):
         data (Any): Any additional data associated with the run.
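+        event_handler_class (AssistantEventHandler): An optional event handler
+            class for processing streaming events; if None, a plain
+            `AsyncAssistantEventHandler` is used.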
     """
 
-    id: Optional[str] = None
+    model_config: dict = dict(extra="forbid")
+
     thread: Thread
     assistant: Assistant
+    event_handler_class: type[
+        Union[AssistantEventHandler, AsyncAssistantEventHandler]
+    ] = Field(default=None)
+    event_handler_kwargs: dict[str, Any] = Field(default={})
+    _messages: dict[str, Message] = PrivateAttr({})
+    _steps: dict[str, OpenAIRunStep] = PrivateAttr({})
+    model: Optional[str] = Field(
+        None, description="Replace the model used by the assistant."
+    )
     instructions: Optional[str] = Field(
         None, description="Replacement instructions to use for the run."
     )
@@ -58,9 +67,14 @@
         None,
         description="Additional tools to append to the assistant's tools. ",
     )
-    run: OpenAIRun = None
+    run: OpenAIRun = Field(None, repr=False)
     data: Any = None
 
+    def __init__(self, *, messages: list[Message] = None, **data):
+        super().__init__(**data)
+        if messages is not None:
+            self._messages.update({m.id: m for m in messages})
+
     @field_validator("tools", "additional_tools", mode="before")
     def format_tools(cls, tools: Union[None, list[Union[Tool, Callable]]]):
         if tools is not None:
@@ -73,34 +87,86 @@
                 for tool in tools
             ]
 
+    @property
+    def messages(self) -> list[Message]:
+        return sorted(self._messages.values(), key=lambda m: m.created_at)
+
+    @property
+    def steps(self) -> list[OpenAIRunStep]:
+        return sorted(self._steps.values(), key=lambda s: s.created_at)
+
     @expose_sync_method("refresh")
     async def refresh_async(self):
         """Refreshes the run."""
+        if not self.run:
+            raise ValueError("Run has not been created yet.")
         client = marvin.utilities.openai.get_openai_client()
         self.run = await client.beta.threads.runs.retrieve(
-            run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
+            run_id=self.run.id, thread_id=self.thread.id
        )
 
     @expose_sync_method("cancel")
     async def cancel_async(self):
         """Cancels the run."""
+        if not self.run:
+            raise ValueError("Run has not been created yet.")
         client = marvin.utilities.openai.get_openai_client()
         await client.beta.threads.runs.cancel(
-            run_id=self.run.id if self.run else self.id, thread_id=self.thread.id
+            run_id=self.run.id, thread_id=self.thread.id
         )
+        await self.refresh_async()
 
-    async def _handle_step_requires_action(
-        self,
-    ) -> tuple[list[RequiredActionFunctionToolCall], list[dict[str, str]]]:
-        client = marvin.utilities.openai.get_openai_client()
-        if self.run.status != "requires_action":
+    def _get_instructions(self) -> str:
+        if self.instructions is None:
+            instructions = self.assistant.get_instructions() or ""
+        else:
+            instructions = self.instructions
+
+        if self.additional_instructions is not None:
+            instructions = "\n\n".join([instructions, self.additional_instructions])
+
+        return instructions
+
+    def _get_model(self) -> str:
+        if self.model is None:
+            model = self.assistant.model
+        else:
+            model = self.model
+        return model
+
+    def _get_tools(self) -> list[AssistantTool]:
+        tools = []
+        if self.tools is None:
+            tools.extend(self.assistant.get_tools())
+        else:
+            tools.extend(self.tools)
+        if self.additional_tools is not None:
+            tools.extend(self.additional_tools)
+        return tools
+
+    def _get_run_kwargs(self, **run_kwargs) -> dict:
+        if "instructions" not in run_kwargs and (
+            self.instructions is not None or self.additional_instructions is not None
+        ):
+            run_kwargs["instructions"] = self._get_instructions()
+
+        if "tools" not in run_kwargs and (
+            self.tools is not None or self.additional_tools is not None
+        ):
+            run_kwargs["tools"] = self._get_tools()
+        if "model" not in run_kwargs and self.model is not None:
+            run_kwargs["model"] = self._get_model()
+        return run_kwargs
+
+    async def get_tool_outputs(self, run: OpenAIRun) -> list[dict[str, str]]:
+        if run.status != "requires_action":
-            return None, None
+            return None
-        if self.run.required_action.type == "submit_tool_outputs":
+        if run.required_action.type == "submit_tool_outputs":
             tool_calls = []
             tool_outputs = []
-            tools = self.get_tools()
+            tools = self._get_tools()
 
-            for tool_call in self.run.required_action.submit_tool_outputs.tool_calls:
+            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                 try:
                     output = marvin.utilities.tools.call_function_tool(
                         tools=tools,
@@ -119,162 +185,83 @@
                     )
                     tool_calls.append(tool_call)
 
-            await client.beta.threads.runs.submit_tool_outputs(
-                thread_id=self.thread.id, run_id=self.run.id, tool_outputs=tool_outputs
-            )
-            return tool_calls, tool_outputs
-
-    def get_instructions(self) -> str:
-        if self.instructions is None:
-            instructions = self.assistant.get_instructions() or ""
-        else:
-            instructions = self.instructions
-
-        if self.additional_instructions is not None:
-            instructions = "\n\n".join([instructions, self.additional_instructions])
-
-        return instructions
-
-    def get_tools(self) -> list[AssistantTool]:
-        tools = []
-        if self.tools is None:
-            tools.extend(self.assistant.get_tools())
-        else:
-            tools.extend(self.tools)
-        if self.additional_tools is not None:
-            tools.extend(self.additional_tools)
-        return tools
+            return tool_outputs
 
     async def run_async(self) -> "Run":
-        """Excutes a run asynchronously."""
-        client = marvin.utilities.openai.get_openai_client()
-
-        create_kwargs = {}
-
-        if self.instructions is not None or self.additional_instructions is not None:
-            create_kwargs["instructions"] = self.get_instructions()
-
-        if self.tools is not None or self.additional_tools is not None:
-            create_kwargs["tools"] = self.get_tools()
-
-        if self.id is not None:
+        if self.run is not None:
             raise ValueError(
-                "This run object was provided an ID; can not create a new run."
+                "This run has already been started; it cannot be run again."
             )
 
-        async with self.assistant:
-            self.run = await client.beta.threads.runs.create(
-                thread_id=self.thread.id,
-                assistant_id=self.assistant.id,
-                **create_kwargs,
-            )
-
-            self.assistant.pre_run_hook(run=self)
+        client = marvin.utilities.openai.get_openai_client()
+        run_kwargs = self._get_run_kwargs()
+        event_handler_class = self.event_handler_class or AsyncAssistantEventHandler
 
-            tool_calls = None
-            tool_outputs = None
+        with self.assistant:
+            handler = event_handler_class(**self.event_handler_kwargs)
 
             try:
-                while self.run.status in ("queued", "in_progress", "requires_action"):
-                    if self.run.status == "requires_action":
-                        (
-                            tool_calls,
-                            tool_outputs,
-                        ) = await self._handle_step_requires_action()
-                    await asyncio.sleep(0.1)
-                    await self.refresh_async()
+                self.assistant.pre_run_hook()
+
+                for msg in self.messages:
+                    await handler.on_message_done(msg)
+
+                async with client.beta.threads.runs.create_and_stream(
+                    thread_id=self.thread.id,
+                    assistant_id=self.assistant.id,
+                    event_handler=handler,
+                    **run_kwargs,
+                ) as stream:
+                    await stream.until_done()
+                await self._update_run_from_handler(handler)
+
+                while handler.current_run.status in ["requires_action"]:
+                    tool_outputs = await self.get_tool_outputs(run=handler.current_run)
+
+                    handler = event_handler_class(**self.event_handler_kwargs)
+
+                    async with client.beta.threads.runs.submit_tool_outputs_stream(
+                        thread_id=self.thread.id,
+                        run_id=self.run.id,
+                        tool_outputs=tool_outputs,
+                        event_handler=handler,
+                    ) as stream:
+                        await stream.until_done()
+                    await self._update_run_from_handler(handler)
+
             except CancelRun as exc:
                 logger.debug(f"`CancelRun` raised; ending run with data: {exc.data}")
-                await client.beta.threads.runs.cancel(
-                    run_id=self.run.id, thread_id=self.thread.id
-                )
+                await self.cancel_async()
                 self.data = exc.data
-                await self.refresh_async()
-
-        if self.run.status == "failed":
-            logger.debug(f"Run failed. Last error was: {self.run.last_error}")
-            self.assistant.post_run_hook(
-                run=self, tool_calls=tool_calls, tool_outputs=tool_outputs
-            )
-        return self
+            except Exception as exc:
+                await handler.on_exception(exc)
+                raise
 
+        if self.run.status == "failed":
+            logger.debug(
+                f"Run failed. Last error was: {handler.current_run.last_error}"
+            )
 
-class RunMonitor(BaseModel):
-    run: Run
-    thread: Thread
-    steps: list[OpenAIRunStep] = []
-
-    async def refresh_run_steps_async(self):
-        """
-        Asynchronously refreshes and updates the run steps list.
-
-        This function fetches the latest run steps up to a specified limit and
-        checks if the latest run step in the current run steps list
-        (`self.steps`) is included in the new batch. If the latest run step is
-        missing, it continues to fetch additional run steps in batches, up to a
-        maximum count, using pagination. The function then updates
-        `self.steps` with these new run steps, ensuring any existing run steps
-        are updated with their latest versions and new run steps are appended in
-        their original order.
- """ - # fetch up to 100 run steps - max_fetched = 100 - limit = 50 - max_attempts = max_fetched / limit + 2 - - # Fetch the latest run steps - client = marvin.utilities.openai.get_openai_client() - - response = await client.beta.threads.runs.steps.list( - run_id=self.run.id, - thread_id=self.thread.id, - limit=limit, - ) - run_steps = list(reversed(response.data)) - - if not run_steps: - return + self.assistant.post_run_hook(run=self) - # Check if the latest run step in self.steps is in the new run steps - latest_step_id = self.steps[-1].id if self.steps else None - missing_latest = ( - latest_step_id not in {rs.id for rs in run_steps} - if latest_step_id - else True - ) + return self - # If the latest run step is missing, fetch additional run steps - total_fetched = len(run_steps) - attempts = 0 - while ( - run_steps - and missing_latest - and total_fetched < max_fetched - and attempts < max_attempts - ): - attempts += 1 - response = await client.beta.threads.runs.steps.list( - run_id=self.run.id, - thread_id=self.thread.id, - limit=limit, - # because this is a raw API call, "after" refers to pagination - # in descnding chronological order - after=run_steps[0].id, - ) - paginated_steps = list(reversed(response.data)) - - total_fetched += len(paginated_steps) - # prepend run steps - run_steps = paginated_steps + run_steps - if any(rs.id == latest_step_id for rs in paginated_steps): - missing_latest = False - - # Update self.steps with the latest data - new_steps_dict = {rs.id: rs for rs in run_steps} - for i in range(len(self.steps) - 1, -1, -1): - if self.steps[i].id in new_steps_dict: - self.steps[i] = new_steps_dict.pop(self.steps[i].id) - else: - break - # Append remaining new run steps at the end in their original order - self.steps.extend(new_steps_dict.values()) + async def _update_run_from_handler( + self, handler: Union[AsyncAssistantEventHandler, AssistantEventHandler] + ): + self.run = handler.current_run + try: + messages = handler.get_final_messages() + if inspect.iscoroutine(messages): + messages = await messages + self._messages.update({m.id: m for m in messages}) + except RuntimeError: + pass + + try: + steps = handler.get_final_run_steps() + if inspect.iscoroutine(steps): + steps = await steps + self._steps.update({s.id: s for s in steps}) + except RuntimeError: + pass diff --git a/src/marvin/beta/assistants/threads.py b/src/marvin/beta/assistants/threads.py index bb508fcca..8da4de31a 100644 --- a/src/marvin/beta/assistants/threads.py +++ b/src/marvin/beta/assistants/threads.py @@ -1,17 +1,10 @@ -import asyncio import time -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional -# for openai < 1.14.0 -try: - from openai.types.beta.threads import ThreadMessage as Message -# for openai >= 1.14.0 -except ImportError: - from openai.types.beta.threads import Message -from pydantic import BaseModel, Field, PrivateAttr +from openai.types.beta.threads import Message +from pydantic import BaseModel, Field import marvin.utilities.openai -from marvin.beta.assistants.formatting import pprint_message from marvin.utilities.asyncio import ( ExposeSyncMethodsMixin, expose_sync_method, @@ -124,9 +117,10 @@ async def get_messages_async( before=after_message, after=before_message, limit=limit, + # order desc to get the most recent messages first order="desc", ) - return response.data + return list(reversed(response.data)) @expose_sync_method("delete") async def delete_async(self): @@ -152,7 +146,11 @@ async def run_async( from 
         from marvin.beta.assistants.runs import Run
 
-        run = Run(assistant=assistant, thread=self, **run_kwargs)
+        run = Run(
+            assistant=assistant,
+            thread=self,
+            **run_kwargs,
+        )
         return await run.run_async()
 
     def chat(self, assistant: "Assistant"):
@@ -175,87 +173,3 @@ def callback(thread_id: str, message: str):
                 time.sleep(0.2)
             except KeyboardInterrupt:
                 break
-
-
-class ThreadMonitor(BaseModel, ExposeSyncMethodsMixin):
-    """
-    The ThreadMonitor class represents a monitor for a specific thread.
-
-    Attributes:
-        thread_id (str): The unique identifier of the thread being monitored.
-        last_message_id (Optional[str]): The ID of the last message received in the thread.
-        on_new_message (Callable): A callback function that is called when a new message
-            is received in the thread.
-    """
-
-    thread_id: str
-    last_message_id: Optional[str] = None
-    on_new_message: Callable = Field(default=pprint_message)
-    _thread: Thread = PrivateAttr()
-
-    @property
-    def thread(self):
-        return self._thread
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._thread = Thread(id=kwargs["thread_id"])
-
-    @expose_sync_method("run_once")
-    async def run_once_async(self):
-        messages = await self.get_latest_messages()
-        for msg in messages:
-            if self.on_new_message:
-                self.on_new_message(msg)
-
-    @expose_sync_method("run")
-    async def run_async(self, interval_seconds: int = None):
-        """
-        Run the thread monitor in a loop, checking for new messages every `interval_seconds`.
-
-        Args:
-            interval_seconds (int, optional): The number of seconds to wait between
-                checking for new messages. Default is 1.
-        """
-        if interval_seconds is None:
-            interval_seconds = 1
-        if interval_seconds < 1:
-            raise ValueError("Interval must be at least 1 second.")
-
-        while True:
-            try:
-                await self.run_once_async()
-            except KeyboardInterrupt:
-                logger.debug("Keyboard interrupt received; exiting thread monitor.")
-                break
-            except Exception as exc:
-                logger.error(f"Error refreshing thread: {exc}")
-            await asyncio.sleep(interval_seconds)
-
-    async def get_latest_messages(self) -> list[Message]:
-        limit = 20
-
-        # Loop to get all new messages in batches of 20
-        while True:
-            messages = await self.thread.get_messages_async(
-                after_message=self.last_message_id, limit=limit
-            )
-
-            # often the API will retrieve messages that have been created but
-            # not populated with text. We filter out these empty messages.
-            filtered_messages = []
-            for i, msg in enumerate(messages):
-                skip_message = False
-                for c in msg.content:
-                    if getattr(getattr(c, "text", None), "value", None) == "":
-                        skip_message = True
-                if not skip_message:
-                    filtered_messages.append(msg)
-
-            if filtered_messages:
-                self.last_message_id = filtered_messages[-1].id
-
-            if len(messages) < limit:
-                break
-
-        return filtered_messages