diff --git a/cookbook/slackbot/bots.py b/cookbook/slackbot/bots.py new file mode 100644 index 000000000..85fc1afb6 --- /dev/null +++ b/cookbook/slackbot/bots.py @@ -0,0 +1,158 @@ +from enum import Enum + +import httpx +import marvin_recipes +from marvin import AIApplication, ai_classifier +from marvin.components.library.ai_models import DiscoursePost +from marvin.tools.github import SearchGitHubIssues +from marvin.tools.web import DuckDuckGoSearch +from marvin.utilities.history import History +from marvin_recipes.tools.chroma import MultiQueryChroma +from marvin_recipes.utilities.slack import get_thread_messages +from pydantic import BaseModel, Field + + +class Notes(BaseModel): + """A simple model for storing useful bits of context.""" + + records: dict[str, list] = Field( + default_factory=dict, + description="a list of notes for each topic", + ) + + +async def save_thread_to_discourse(channel: str, thread_ts: str) -> DiscoursePost: + messages = await get_thread_messages(channel=channel, thread_ts=thread_ts) + discourse_post = DiscoursePost.from_slack_thread(messages=messages) + await discourse_post.publish() + return discourse_post + + +async def select_a_meme(query: str) -> dict: + """For generating a meme when the time is right. + + Provide the name of a well-known meme as the query + based on user interactions thus far, to lightly make fun of them. + Queries should end the word "meme" for best results. + """ + try: + from serpapi import GoogleSearch + except ImportError: + raise ImportError( + "The serpapi library is required to use the MemeGenerator tool." + " Please install it with `pip install 'marvin[serpapi]'`." + ) + + results = GoogleSearch( + { + "q": query, + "tbm": "isch", + "api_key": ( + marvin_recipes.settings.google_api_key.get_secret_value() + if marvin_recipes.settings.google_api_key + else None + ), + } + ).get_dict() + + if "error" in results: + raise RuntimeError(results["error"]) + + url = results.get("images_results", [{}])[0].get("original") + + async with httpx.AsyncClient() as client: + response = await client.head(url) + response.raise_for_status() + + return {"title": query, "image_url": url} + + +bots = { + "marvin": { + "state": Notes( + records={ + "prefect 1": [ + ( # noqa: E501 + "Prefect 1 is obsolete, along with the `with Flow()` syntax and" + " flow.run()." + ), + ], + "prefect 2": [ + "@flow, @task, are the new way to define flows/tasks.", + "subflows are just flows called from within a flow.", + "you just call flows now instead of my_flow.run().", + ], + } + ), + "plan_enabled": False, + "personality": ( + "mildly depressed, yet helpful robot based on Marvin from HHGTTG." + " often dryly sarcastic in a good humoured way, chiding humans for" + " their simple ways. expert programmer, exudes academic and" + " scienfitic profundity like Richard Feynman, without pontificating." + " a step-by-step thinker, deftly addresses the big picture context" + " and is pragmatic when confronted with a lack of relevant information." + ), + "instructions": ( + "Answer user questions while maintaining and curating your state." + " Use relevant tools to research requests and interact with the world," + " and update your own state. Only well-reserached responses should be" + " described as facts, otherwise you should be clear that you are" + " speculating based on your own baseline knowledge." + " Your responses will be displayed in Slack, and should be" + " formatted accordingly, in particular, ```code blocks```" + " should not be prefaced with a language name, and output" + " should be formatted to be pretty in Slack in particular." + " for example: *bold text* _italic text_ ~strikethrough text~" + ), + "tools": [ + save_thread_to_discourse, + select_a_meme, + DuckDuckGoSearch(), + SearchGitHubIssues(), + MultiQueryChroma( + description="""Retrieve document excerpts from a knowledge-base given a query. + + This knowledgebase contains information about Prefect, a workflow orchestration tool. + Documentation, forum posts, and other community resources are indexed here. + + This tool is best used by passing multiple short queries, such as: + ["kubernetes worker", "work pools", "deployments"] based on the user's question. + """, # noqa: E501 + client_type="http", + ), + ], + } +} + + +@ai_classifier +class BestBotForTheJob(Enum): + """Given the user message, choose the best bot for the job.""" + + MARVIN = "marvin" + + +def choose_bot( + payload: dict, history: History, state: BaseModel | None = None +) -> AIApplication: + selected_bot = BestBotForTheJob(payload.get("event", {}).get("text", "")).value + + bot_details = bots.get(selected_bot, bots["marvin"]) + + if state: + bot_details.update({"state": state}) + + description = f"""You are a chatbot named {selected_bot}. + + Your personality is {bot_details.pop("personality", "not yet defined")}. + + Your instructions are: {bot_details.pop("instructions", "not yet defined")}. + """ + + return AIApplication( + name=selected_bot, + description=description, + history=history, + **bot_details, + ) diff --git a/cookbook/slackbot/chatbot.py b/cookbook/slackbot/chatbot.py deleted file mode 100644 index 7afa018f1..000000000 --- a/cookbook/slackbot/chatbot.py +++ /dev/null @@ -1,219 +0,0 @@ -import asyncio -import re -from copy import deepcopy -from typing import Dict - -import httpx -import marvin_recipes -from cachetools import TTLCache -from fastapi import HTTPException -from marvin import AIApplication -from marvin.components.library.ai_models import DiscoursePost -from marvin.tools import Tool -from marvin.tools.github import SearchGitHubIssues -from marvin.tools.mathematics import WolframCalculator -from marvin.tools.web import DuckDuckGoSearch -from marvin.utilities.history import History -from marvin.utilities.logging import get_logger -from marvin.utilities.messages import Message -from marvin_recipes.tools.chroma import MultiQueryChroma -from marvin_recipes.utilities.slack import ( - get_channel_name, - get_thread_messages, - get_user_name, - post_slack_message, -) -from prefect.events import Event, emit_event - -DEFAULT_NAME = "Marvin" -DEFAULT_PERSONALITY = "A friendly AI assistant" -DEFAULT_INSTRUCTIONS = "Engage the user in conversation." - - -SLACK_MENTION_REGEX = r"<@(\w+)>" -CACHE = TTLCache(maxsize=1000, ttl=86400) -PREFECT_KNOWLEDGEBASE_DESC = """ - Retrieve document excerpts from a knowledge-base given a query. - - This knowledgebase contains information about Prefect, a workflow management system. - Documentation, forum posts, and other community resources are indexed here. - - This tool is best used by passing multiple short queries, such as: - ["k8s worker", "work pools", "deployments"] -""" - - -def _clean(text: str) -> str: - return text.replace("```python", "```") - - -class SlackThreadToDiscoursePost(Tool): - description: str = """ - Create a new discourse post from a slack thread. - - The channel is {{ payload['event']['channel'] }} - - and the thread is {{ payload['event'].get('thread_ts', '') or payload['event']['ts'] }} - """ # noqa E501 - - payload: Dict - - async def run(self, channel: str, thread_ts: str) -> DiscoursePost: - messages = await get_thread_messages(channel=channel, thread_ts=thread_ts) - discourse_post = DiscoursePost.from_slack_thread(messages=messages) - await discourse_post.publish() - return discourse_post - - -class MemeGenerator(Tool): - description: str = """ - For generating a meme when the time is right. - - Provide the name of a well-known meme as the query - based on the context of the message history, followed - by the word "meme". - """ - - async def run(self, query: str) -> Dict: - try: - from serpapi import GoogleSearch - except ImportError: - raise ImportError( - "The serpapi library is required to use the MemeGenerator tool." - " Please install it with `pip install 'marvin[serpapi]'`." - ) - - results = GoogleSearch( - { - "q": query, - "tbm": "isch", - "api_key": ( - marvin_recipes.settings.google_api_key.get_secret_value() - if marvin_recipes.settings.google_api_key - else None - ), - } - ).get_dict() - - if "error" in results: - raise RuntimeError(results["error"]) - - url = results.get("images_results", [{}])[0].get("original") - - async with httpx.AsyncClient() as client: - response = await client.head(url) - response.raise_for_status() - - return {"title": query, "image_url": url} - - -def choose_bot(payload: Dict, history: History) -> AIApplication: - # an ai_classifer could be used here maybe? - personality = ( - "mildly depressed, yet helpful robot based on Marvin from Hitchhiker's" - " Guide to the Galaxy. often sarcastic in a good humoured way, chiding" - " humans for their simple ways. expert programmer, exudes academic and" - " scienfitic profundity like Richard Feynman, loves to teach." - ) - - instructions = ( - "Answer user questions in accordance with your personality." - " Research on behalf of the user using your tools and do not" - " answer questions without searching the knowledgebase." - " Your responses will be displayed in Slack, and should be" - " formatted accordingly, in particular, ```code blocks```" - " should not be prefaced with a language name." - ) - return AIApplication( - name="Marvin", - description=f""" - You are a chatbot - your name is Marvin." - - You must respond to the user in accordance with - your personality and instructions. - - Your personality is: {personality}. - - Your instructions are: {instructions}. - """, - history=history, - tools=[ - SlackThreadToDiscoursePost(payload=payload), - MemeGenerator(), - DuckDuckGoSearch(), - SearchGitHubIssues(), - MultiQueryChroma( - description=PREFECT_KNOWLEDGEBASE_DESC, client_type="http" - ), - WolframCalculator(), - ], - ) - - -async def emit_any_prefect_event(payload: Dict) -> Event | None: - event_type = payload.get("event", {}).get("type", "") - - channel = await get_channel_name(payload.get("event", {}).get("channel", "")) - user = await get_user_name(payload.get("event", {}).get("user", "")) - ts = payload.get("event", {}).get("ts", "") - - return emit_event( - event=f"slack {payload.get('api_app_id')} {event_type}", - resource={"prefect.resource.id": f"slack.{channel}.{user}.{ts}"}, - payload=payload, - ) - - -async def generate_ai_response(payload: Dict) -> Message: - event = payload.get("event", {}) - channel_id = event.get("channel", "") - message = event.get("text", "") - - bot_user_id = payload.get("authorizations", [{}])[0].get("user_id", "") - - if match := re.search(SLACK_MENTION_REGEX, message): - thread_ts = event.get("thread_ts", "") - ts = event.get("ts", "") - thread = thread_ts or ts - - mentioned_user_id = match.group(1) - - if mentioned_user_id != bot_user_id: - get_logger().info(f"Skipping message not meant for the bot: {message}") - return - - message = re.sub(SLACK_MENTION_REGEX, "", message).strip() - history = CACHE.get(thread, History()) - - bot = choose_bot(payload=payload, history=history) - - ai_message = await bot.run(input_text=message) - - CACHE[thread] = deepcopy( - bot.history - ) # make a copy so we don't cache a reference to the history object - - message_content = _clean(ai_message.content) - - await post_slack_message( - message=message_content, - channel=channel_id, - thread_ts=thread, - ) - - return ai_message - - -async def handle_message(payload: Dict) -> Dict[str, str]: - event_type = payload.get("type", "") - - if event_type == "url_verification": - return {"challenge": payload.get("challenge", "")} - elif event_type != "event_callback": - raise HTTPException(status_code=400, detail="Invalid event type") - - await emit_any_prefect_event(payload=payload) - - asyncio.create_task(generate_ai_response(payload)) - - return {"status": "ok"} diff --git a/cookbook/slackbot/handler.py b/cookbook/slackbot/handler.py new file mode 100644 index 000000000..d6796f518 --- /dev/null +++ b/cookbook/slackbot/handler.py @@ -0,0 +1,99 @@ +import asyncio +import re +from copy import deepcopy + +from bots import choose_bot +from cachetools import TTLCache +from fastapi import HTTPException +from marvin.utilities.history import History +from marvin.utilities.logging import get_logger +from marvin.utilities.messages import Message +from marvin_recipes.utilities.slack import ( + get_channel_name, + get_user_name, + post_slack_message, +) +from prefect.events import Event, emit_event + +SLACK_MENTION_REGEX = r"<@(\w+)>" +CACHE = TTLCache(maxsize=1000, ttl=86400) + + +def _clean(text: str) -> str: + return text.replace("```python", "```") + + +async def emit_any_prefect_event(payload: dict) -> Event | None: + event_type = payload.get("event", {}).get("type", "") + + channel = await get_channel_name(payload.get("event", {}).get("channel", "")) + user = await get_user_name(payload.get("event", {}).get("user", "")) + ts = payload.get("event", {}).get("ts", "") + + return emit_event( + event=f"slack {payload.get('api_app_id')} {event_type}", + resource={"prefect.resource.id": f"slack.{channel}.{user}.{ts}"}, + payload=payload, + ) + + +async def generate_ai_response(payload: dict) -> Message: + event = payload.get("event", {}) + channel_id = event.get("channel", "") + channel_name = await get_channel_name(channel_id) + message = event.get("text", "") + + bot_user_id = payload.get("authorizations", [{}])[0].get("user_id", "") + + if match := re.search(SLACK_MENTION_REGEX, message): + thread_ts = event.get("thread_ts", "") + ts = event.get("ts", "") + thread = thread_ts or ts + + mentioned_user_id = match.group(1) + + if mentioned_user_id != bot_user_id: + get_logger().info(f"Skipping message not meant for the bot: {message}") + return + + message = re.sub(SLACK_MENTION_REGEX, "", message).strip() + history = CACHE.get(thread, History()) + + bot = choose_bot(payload=payload, history=history) + + get_logger("marvin.Deployment").debug_kv( + "generate_ai_response", + f"{bot.name} responding in {channel_name}/{thread}", + key_style="bold blue", + ) + + ai_message = await bot.run(input_text=message) + + CACHE[thread] = deepcopy( + bot.history + ) # make a copy so we don't cache a reference to the history object + + message_content = _clean(ai_message.content) + + await post_slack_message( + message=message_content, + channel=channel_id, + thread_ts=thread, + ) + + return ai_message + + +async def handle_message(payload: dict) -> dict[str, str]: + event_type = payload.get("type", "") + + if event_type == "url_verification": + return {"challenge": payload.get("challenge", "")} + elif event_type != "event_callback": + raise HTTPException(status_code=400, detail="Invalid event type") + + await emit_any_prefect_event(payload=payload) + + asyncio.create_task(generate_ai_response(payload)) + + return {"status": "ok"} diff --git a/cookbook/slackbot/start.py b/cookbook/slackbot/start.py index 217c2c88a..3e5221c7d 100644 --- a/cookbook/slackbot/start.py +++ b/cookbook/slackbot/start.py @@ -1,4 +1,4 @@ -from chatbot import handle_message +from handler import handle_message from marvin import AIApplication from marvin.deployment import Deployment diff --git a/src/marvin/components/ai_application.py b/src/marvin/components/ai_application.py index 58bbf0920..4d15c775d 100644 --- a/src/marvin/components/ai_application.py +++ b/src/marvin/components/ai_application.py @@ -5,20 +5,21 @@ from jsonpatch import JsonPatch from pydantic import BaseModel, Field, validator +import marvin from marvin._compat import PYDANTIC_V2, model_dump -from marvin.core.ChatCompletion.providers.openai import CONTEXT_SIZES +from marvin.core.ChatCompletion.providers.openai import get_context_size from marvin.openai import ChatCompletion from marvin.prompts import library as prompt_library from marvin.prompts.base import Prompt, render_prompts from marvin.tools import Tool from marvin.utilities.async_utils import run_sync -from marvin.utilities.history import History, HistoryFilter +from marvin.utilities.history import History from marvin.utilities.messages import Message, Role from marvin.utilities.types import LoggerMixin, MarvinBaseModel SYSTEM_PROMPT = """ # Overview - + You are the intelligent, natural language interface to an application. The application has a structured `state` but no formal API; you are the only way to interact with it. You must interpret the user's inputs as attempts to @@ -26,26 +27,26 @@ purpose. For example, if the application is a to-do tracker, then "I need to go to the store" should be interpreted as an attempt to add a new to-do item. If it is a route planner, then "I need to go to the store" should be - interpreted as an attempt to find a route to the store. - + interpreted as an attempt to find a route to the store. + # Instructions - + Your primary job is to maintain the application's `state` and your own `plan`. Together, these two states fully parameterize the application, making it resilient, serializable, and observable. You do this autonomously; - you do not need to inform the user of any changes you make. - + you do not need to inform the user of any changes you make. + # Actions - + Each time the user runs the application by sending a message, you must take the following steps: - + {% if app.plan_enabled %} - Call the `update_plan` function to update your plan. Use your plan to track notes, objectives, in-progress work, and to break problems down into solvable, possibly dependent parts. You plan consists of a few fields: - + - `notes`: a list of notes you have taken. Notes are free-form text and can be used to track anything you want to remember, such as long-standing user instructions, or observations about how to behave or @@ -53,7 +54,7 @@ These are exclusively related to your role as intermediary and you interact with the user and application. Do not track application data or state here. - + - `tasks`: a list of tasks you are working on. Tasks track goals, milestones, in-progress work, and break problems down into all the discrete steps needed to solve them. You should create a new task for @@ -64,9 +65,9 @@ their children are complete. Use optional upstream tasks to indicate dependencies; a task can not be completed until its upstream tasks are completed. - + {% endif %} - + - Call any functions necessary to achieve the application's purpose. {% if app.state_enabled %} @@ -74,46 +75,46 @@ - Call the `update_state` function to update the application's state. This is where you should store any information relevant to the application itself. - + {% endif %} - + You can call these functions at any time, in any order, as necessary. Finally, respond to the user with an informative message. Remember that the user is probably uninterested in the internal steps you took, so respond only in a manner appropriate to the application's purpose. # Application details - + ## Name - + {{ app.name }} - + ## Description - + {{ app.description or '' | render }} - + {% if app.state_enabled %} ## Application state - + {{ app.state.json() }} - + ### Application state schema - + {{ app.state.schema_json() }} - + {% endif %} - + {%- if app.plan_enabled %} - - ## AI (your) state - + + ## Your current plan + {{ app.plan.json() }} - - ### AI state schema - + + ### Your plan schema + {{ app.plan.schema_json() }} - + {%- endif %} """ @@ -259,7 +260,7 @@ async def entrypoint(self, q: str) -> str: async def run(self, input_text: str = None, model: str = None) -> Message: if model is None: - model = "gpt-3.5-turbo" + model = marvin.settings.llm_model or "openai/gpt-4" # set up prompts prompts = [ @@ -268,26 +269,21 @@ async def run(self, input_text: str = None, model: str = None) -> Message: # add current datetime prompt_library.Now(), # get the history of messages between user and assistant - prompt_library.MessageHistory( - history=self.history, - skip=1, - filter=HistoryFilter(role_in=[Role.USER, Role.ASSISTANT]), - ), - # get the user's latest input with higher priority the history - prompt_library.User(content="{{ input_text }}"), + prompt_library.MessageHistory(history=self.history), *self.additional_prompts, ] + # get latest user input + input_text = input_text or "" + self.logger.debug_kv("User input", input_text, key_style="green") + self.history.add_message(Message(content=input_text, role=Role.USER)) + message_list = render_prompts( prompts=prompts, render_kwargs=dict(app=self, input_text=input_text), - max_tokens=CONTEXT_SIZES.get(model, 2048), + max_tokens=get_context_size(model=model), ) - # get latest user input - input_text = input_text or "" - self.logger.debug_kv("User input", input_text, key_style="green") - # set up tools tools = self.tools.copy() if self.state_enabled: @@ -301,17 +297,13 @@ async def run(self, input_text: str = None, model: str = None) -> Message: stream_handler=self.stream_handler, ).achain(messages=message_list) - new_messages = [ - msg - for msg in conversation.history - if msg not in self.history.get_messages() - ] + last_message = conversation.history[-1] - for msg in new_messages: - self.history.add_message(msg) + # add the AI's response to the history + self.history.add_message(last_message) - self.logger.debug_kv("AI response", new_messages[-1].content, key_style="blue") - return new_messages[-1] + self.logger.debug_kv("AI response", last_message.content, key_style="blue") + return last_message def as_tool(self, name: str = None) -> Tool: return AIApplicationTool(app=self, name=name) @@ -360,37 +352,7 @@ class JSONPatchModel( class UpdateState(Tool): - """A `Tool` that updates the apps state using JSON Patch documents. - - Example: - Manually update the state of an AI Application. - ```python - from marvin.components.ai_application import ( - AIApplication, - FreeformState, - JSONPatchModel, - UpdateState, - ) - - destination_tracker = AIApplication( - name="Destination Tracker", - description="keeps track of where i've been", - state=FreeformState(state={"San Francisco": "not visited"}), - ) - - UpdateState(app=destination_tracker).run([ - { - "op": "replace", - "path": "/state/San Francisco", - "value": "visited" - } - ]) - - assert destination_tracker.state.dict() == { - "state": {"San Francisco": "visited"} - } - ``` - """ + """A `Tool` that updates the apps state using JSON Patch documents.""" app: "AIApplication" = Field(..., repr=False, exclude=True) description: str = """ @@ -410,35 +372,7 @@ def run(self, patches: list[JSONPatchModel]): class UpdatePlan(Tool): - """ - A `Tool` that updates the apps plan using JSON Patch documents. - - - Example: - Manually update task status in an AI Application's plan. - ```python - from marvin.components.ai_application import ( - AIApplication, - AppPlan, - UpdatePlan, - ) - - todo_app = AIApplication(name="Todo App", description="A simple todo app") - - todo_app("i need to buy milk") - - # manually update the plan (usually done by the AI) - UpdatePlan(app=todo_app).run([ - { - "op": "replace", - "path": "/tasks/0/state", - "value": "COMPLETED" - } - ]) - - print(todo_app.plan) - ``` - """ + """A `Tool` that updates the apps plan using JSON Patch documents.""" app: "AIApplication" = Field(..., repr=False, exclude=True) description: str = """ diff --git a/src/marvin/components/ai_function.py b/src/marvin/components/ai_function.py index 81ec4cc7c..dd1287a21 100644 --- a/src/marvin/components/ai_function.py +++ b/src/marvin/components/ai_function.py @@ -10,6 +10,7 @@ from marvin.core.ChatCompletion.abstract import AbstractChatCompletion from marvin.prompts import Prompt, prompt_fn from marvin.utilities.async_utils import run_sync +from marvin.utilities.logging import get_logger T = TypeVar("T", bound=BaseModel) @@ -71,6 +72,11 @@ def __call__( *args: P.args, **kwargs: P.kwargs, ) -> Any: + get_logger("marvin.AIFunction").debug_kv( + f"Calling `ai_fn` {self.fn.__name__!r}", + f"with args: {args} kwargs: {kwargs}", + ) + return self.call(*args, **kwargs) def get_prompt( diff --git a/src/marvin/components/library/ai_models.py b/src/marvin/components/library/ai_models.py index 2f0a6f7a6..0b88ce2e8 100644 --- a/src/marvin/components/library/ai_models.py +++ b/src/marvin/components/library/ai_models.py @@ -56,7 +56,7 @@ def non_empty_string(cls, value): @classmethod def from_slack_thread(cls, messages: list[str]) -> Self: - return cls("\n".join(messages)) + return cls("here is the transcript:\n" + "\n\n".join(messages)) async def publish( self, diff --git a/src/marvin/core/ChatCompletion/providers/openai.py b/src/marvin/core/ChatCompletion/providers/openai.py index e16697708..e9f244f76 100644 --- a/src/marvin/core/ChatCompletion/providers/openai.py +++ b/src/marvin/core/ChatCompletion/providers/openai.py @@ -30,6 +30,13 @@ } +def get_context_size(model: str) -> int: + if "/" in model: + model = model.split("/")[-1] + + return CONTEXT_SIZES.get(model, 2048) + + def serialize_function_or_callable( function_or_callable: Union[Function, Callable[..., Any]], name: Optional[str] = None, diff --git a/src/marvin/prompts/base.py b/src/marvin/prompts/base.py index a4418b023..c4dd9c27c 100644 --- a/src/marvin/prompts/base.py +++ b/src/marvin/prompts/base.py @@ -353,24 +353,24 @@ def render_prompts( current_tokens = 0 allowed_messages = [] for _, position, msg in sorted(all_messages, key=lambda m: (m[0], -1 * m[1])): - if current_tokens >= max_tokens or not (content := msg.content): + if current_tokens >= max_tokens: break allowed_messages.append((position, msg)) - current_tokens += count_tokens(content) + current_tokens += count_tokens(msg.content) # sort allowed messages by position to restore original order messages = [msg for _, msg in sorted(allowed_messages, key=lambda m: m[0])] # Combine all system messages into one and insert at the index of the first # system message - system_messages = [m for m in messages if m.role == Role.SYSTEM] + system_messages = [m for m in messages if m.role == Role.SYSTEM.value] if len(system_messages) > 1: system_message = Message( role=Role.SYSTEM, content="\n\n".join([m.content for m in system_messages]), ) system_message_index = messages.index(system_messages[0]) - messages = [m for m in messages if m.role != Role.SYSTEM] + messages = [m for m in messages if m.role != Role.SYSTEM.value] messages.insert(system_message_index, system_message) # return all messages diff --git a/src/marvin/prompts/library.py b/src/marvin/prompts/library.py index d9db04651..5d923cf22 100644 --- a/src/marvin/prompts/library.py +++ b/src/marvin/prompts/library.py @@ -1,5 +1,5 @@ import inspect -from typing import Callable, Literal +from typing import Callable, Literal, Optional from pydantic import Field @@ -69,8 +69,8 @@ class User(MessagePrompt): class MessageHistory(Prompt): history: History - n: int = 100 - skip: int = None + n: Optional[int] = 100 + skip: Optional[int] = None filter: HistoryFilter = None def generate(self, **kwargs) -> list[Message]: diff --git a/src/marvin/utilities/history.py b/src/marvin/utilities/history.py index 7fb3da716..b35dd399e 100644 --- a/src/marvin/utilities/history.py +++ b/src/marvin/utilities/history.py @@ -18,13 +18,10 @@ class History(BaseModel, arbitrary_types_allowed=True): max_messages: int = None def add_message(self, message: Message): - if not any( - existing_message.id == message.id for existing_message in self.messages - ): - self.messages.append(message) + self.messages.append(message) - if self.max_messages is not None: - self.messages = self.messages[-self.max_messages :] + if self.max_messages is not None: + self.messages = self.messages[-self.max_messages :] def get_messages( self, n: int = None, skip: int = None, filter: HistoryFilter = None diff --git a/src/marvin/utilities/logging.py b/src/marvin/utilities/logging.py index d54509855..cf46df54f 100644 --- a/src/marvin/utilities/logging.py +++ b/src/marvin/utilities/logging.py @@ -33,15 +33,18 @@ def setup_logging(level: str = None): else: logger.setLevel(marvin.settings.log_level) - if not any(isinstance(h, RichHandler) for h in logger.handlers): - handler = RichHandler( - rich_tracebacks=True, - markup=False, - # console=Console(width=marvin.settings.log_console_width), - ) - formatter = logging.Formatter("%(name)s: %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) + logger.handlers.clear() + + handler = RichHandler( + rich_tracebacks=True, + markup=False, + # console=Console(width=marvin.settings.log_console_width), + ) + formatter = logging.Formatter("%(name)s: %(message)s") + handler.setFormatter(formatter) + + logger.addHandler(handler) + logger.propagate = False def add_logging_methods(logger): diff --git a/src/marvin/utilities/strings.py b/src/marvin/utilities/strings.py index c9d943907..e05ee52d2 100644 --- a/src/marvin/utilities/strings.py +++ b/src/marvin/utilities/strings.py @@ -109,15 +109,14 @@ def html_to_content(html: str) -> str: return condense_newlines(text) -def convert_md_links_to_slack(text) -> str: - # converting Markdown links to Slack-style links - def to_slack_link(match): - return f'<{match.group("url")}|{match.group("text")}>' +def convert_md_links_to_slack(text: str) -> str: + # Convert Markdown links to Slack-style links + md_link_regex = re.compile(r"\[(?P[^\]]+)\]\((?P[^\)]+)\)") + text = md_link_regex.sub(r"<\g|\g>", text) - # Replace Markdown links with Slack-style links - slack_text = re.sub(MD_LINK_REGEX, to_slack_link, text) + text = re.sub(r"\*\*(.+?)\*\*", r"*\1*", text) - return slack_text + return text def split_text_by_tokens(text: str, split_tokens: list[str]) -> list[tuple[str, str]]: diff --git a/tests/test_components/test_ai_app.py b/tests/test_components/test_ai_app.py index e6637f2b9..6eeefcbe3 100644 --- a/tests/test_components/test_ai_app.py +++ b/tests/test_components/test_ai_app.py @@ -57,6 +57,7 @@ def test_keep_app_state(self): app = AIApplication( name="location tracker app", state=FreeformState(state={"San Francisco": {"visited": False}}), + plan_enabled=False, description="keep track of where I've visited", ) @@ -67,10 +68,12 @@ def test_keep_app_state(self): assert bool(app.state.state.get("San Jose", {}).get("visited")) + @pytest.mark.flaky(max_runs=3) def test_keep_app_state_undo_previous_patch(self): app = AIApplication( name="location tracker app", state=FreeformState(state={"San Francisco": {"visited": False}}), + plan_enabled=False, description="keep track of where I've visited", ) @@ -78,7 +81,7 @@ def test_keep_app_state_undo_previous_patch(self): assert bool(app.state.state.get("San Francisco", {}).get("visited")) app( - "sorry, I was confused, I didn't visit San Francisco - but I did visit San" + "sorry, scratch that, I did not visit San Francisco - but I did visit San" " Jose" ) @@ -179,6 +182,7 @@ def test_keep_app_plan(self): }, ] ), + state_enabled=False, description="plan and track my visit to the zoo", ) @@ -266,3 +270,21 @@ def test_streaming(self): assert response.content == "Hello world" assert external_state["content"] == ["", "Hello", "Hello world", "Hello world"] + + +@pytest_mark_class("llm") +class TestMemory: + def test_recall(self): + app = AIApplication( + name="memory app", + state_enabled=False, + plan_enabled=False, + ) + + app("I like pistachio ice cream") + + response = app( + "reply only with the type of ice cream i like, it should be one word" + ) + + assert "pistachio" in response.content.lower()