From 4978687ec055ccd0c3369b4de43ddcfbbd4b95de Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sat, 8 Apr 2023 11:47:18 -0400 Subject: [PATCH] Improve streaming response when using plugin --- src/marvin/bot/base.py | 6 +++- src/marvin/cli/tui.py | 66 ++++++++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/marvin/bot/base.py b/src/marvin/bot/base.py index 698f8f143..f378085de 100644 --- a/src/marvin/bot/base.py +++ b/src/marvin/bot/base.py @@ -390,7 +390,11 @@ async def say( else: raise ValueError(f"Unknown on_error value: {on_error}") else: - raise RuntimeError("Failed to validate response after 3 attempts") + response = ( + "Error: could not validate response after" + f" {MAX_VALIDATION_ATTEMPTS} attempts." + ) + parsed_response = response if validated: parsed_response = self.response_format.parse_response(response) diff --git a/src/marvin/cli/tui.py b/src/marvin/cli/tui.py index f97f6b70f..f18c2387c 100644 --- a/src/marvin/cli/tui.py +++ b/src/marvin/cli/tui.py @@ -1,7 +1,8 @@ import asyncio -import json import logging +import re import warnings +from functools import partial from typing import Optional import dotenv @@ -36,6 +37,8 @@ handlers=[TextualHandler()], ) +USING_PLUGIN_REGEX = re.compile(r'{\s*"action":\s*"run-plugin",\s*"name":\s*"(.*?)"') + @marvin.ai_fn(llm_model_name="gpt-3.5-turbo", llm_model_temperature=1) async def name_thread(history: str, personality: str, current_name: str = None) -> str: @@ -164,7 +167,11 @@ class ResponseHover(Message): class ResponseBody(Markdown): - pass + text: str = "" + + def update(self, markdown: str): + self.text = markdown + super().update(markdown) def on_enter(self): self.post_message(ResponseHover()) @@ -172,6 +179,7 @@ def on_enter(self): class Response(Container): body = None + stream_finished: bool = False def __init__(self, message: marvin.models.threads.Message, **kwargs) -> None: classes = kwargs.setdefault("classes", "") @@ -290,7 +298,6 @@ def clear_responses(self) -> None: for response in responses: response.remove() self.bot_name = getattr(self.app.bot, "name") - print(self.bot_name) empty = self.query_one("Conversation #empty-thread-container") empty.remove_class("hidden") @@ -606,34 +613,21 @@ async def on_button_pressed(self, event: Button.Pressed) -> None: elif event.button.id == "quit": self.app.exit() - async def update_last_bot_response(self, token_buffer: list[str]): + async def stream_bot_response(self, token_buffer: list[str], response: BotResponse): streaming_response = "".join(token_buffer) - responses = self.query("Response") - if responses: - response = responses.last() - if not isinstance(response, BotResponse): - conversation = self.query_one("Conversation", Conversation) - await conversation.add_response( - BotResponse( - marvin.models.threads.Message( - role="bot", - name=self.app.bot.name, - bot_id=self.app.bot.id, - content=streaming_response, - ) - ) - ) - else: - # the bot is going to use a plugin - if match := marvin.bot.base.PLUGIN_REGEX.search(streaming_response): - try: - plugin_name = json.loads(match.group(1))["name"] - response.body.update(f'Using plugin "{plugin_name}"...') - except Exception: - response.body.update("Using plugin...") - else: - response.message.content = streaming_response - response.body.update(streaming_response) + + if not self.app.is_mounted(response): + conversation = self.query_one("Conversation", Conversation) + await conversation.add_response(response) + + # the bot is going to use a plugin + if match := USING_PLUGIN_REGEX.search(streaming_response): + plugin_name = match.group(1) + if not response.body.text == f'Using plugin "{plugin_name}"...': + response.body.update(f'Using plugin "{plugin_name}"...') + else: + response.message.content = streaming_response + response.body.update(streaming_response) # scroll to bottom messages = self.query_one("Conversation #messages", VerticalScroll) @@ -644,9 +638,19 @@ async def get_bot_response(self, event: Input.Submitted) -> str: bot = self.app.bot self.app.bot_responding = True try: + bot_response = BotResponse( + marvin.models.threads.Message( + role="bot", + name=self.app.bot.name, + bot_id=self.app.bot.id, + content="", + ) + ) response = await bot.say( event.value, - on_token_callback=self.update_last_bot_response, + on_token_callback=partial( + self.stream_bot_response, response=bot_response + ), ) self.query_one("Conversation", Conversation)