From 8059183a5ec9fc5c5025803ab654faf55d640176 Mon Sep 17 00:00:00 2001
From: Caren Thomas
Date: Wed, 18 Dec 2024 10:18:50 -0800
Subject: [PATCH] refactor error base class and surface error in streaming logic

---
 letta/agent.py                 | 26 +++++++------
 letta/errors.py                | 69 +++++++++++++++++++---------------
 letta/server/rest_api/utils.py | 43 ++++++++++-----------
 3 files changed, 74 insertions(+), 64 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index 3e144987f4..b61dce7641 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -20,7 +20,7 @@
     REQ_HEARTBEAT_MESSAGE,
     STRUCTURED_OUTPUT_MODELS,
 )
-from letta.errors import LLMError, SummarizationError
+from letta.errors import ContextWindowExceededError
 from letta.helpers import ToolRulesSolver
 from letta.interface import AgentInterface
 from letta.llm_api.helpers import is_context_overflow_error
@@ -1170,11 +1170,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
 
         # If at this point there's nothing to summarize, throw an error
         if len(candidate_messages_to_summarize) == 0:
-            raise SummarizationError(
-                f"Not enough messages to compress for summarization.",
-                num_candidate_messages=len(candidate_messages_to_summarize),
-                num_total_messages=len(self.messages),
-                preserve_N=MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
+            raise ContextWindowExceededError(
+                "Not enough messages to compress for summarization",
+                details={
+                    "num_candidate_messages": len(candidate_messages_to_summarize),
+                    "num_total_messages": len(self.messages),
+                    "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
+                },
             )
 
         # Walk down the message buffer (front-to-back) until we hit the target token count
@@ -1208,11 +1210,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
         message_sequence_to_summarize = self._messages[1:cutoff]  # do NOT get rid of the system message
         if len(message_sequence_to_summarize) <= 1:
             # This prevents a potential infinite loop of summarizing the same message over and over
-            raise SummarizationError(
-                f"Not enough messages to compress for summarization after determining cutoff.",
-                num_candidate_messages=len(message_sequence_to_summarize),
-                num_total_messages=len(self.messages),
-                preserve_N=MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
+            raise ContextWindowExceededError(
+                "Not enough messages to compress for summarization after determining cutoff",
+                details={
+                    "num_candidate_messages": len(message_sequence_to_summarize),
+                    "num_total_messages": len(self.messages),
+                    "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
+                },
             )
         else:
             printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
diff --git a/letta/errors.py b/letta/errors.py
index 9b42c4f5b4..1108b53954 100644
--- a/letta/errors.py
+++ b/letta/errors.py
@@ -1,4 +1,5 @@
 import json
+from enum import Enum
 from typing import TYPE_CHECKING, List, Optional, Union
 
 # Avoid circular imports
@@ -6,9 +7,31 @@
     from letta.schemas.message import Message
 
 
+class ErrorCode(Enum):
+    """Enum for error codes used by client."""
+
+    INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
+    CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
+    RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"
+
+
 class LettaError(Exception):
     """Base class for all Letta related errors."""
 
+    def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}):
+        self.message = message
+        self.code = code
+        self.details = details
+        super().__init__(message)
+
+    def __str__(self) -> str:
+        if self.code:
+            return f"{self.code.value}: {self.message}"
+        return self.message
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(message='{self.message}', code='{self.code}', details={self.details})"
+
 
 class LettaToolCreateError(LettaError):
     """Error raised when a tool cannot be created."""
@@ -16,10 +39,7 @@ class LettaToolCreateError(LettaError):
     default_error_message = "Error creating tool."
 
     def __init__(self, message=None):
-        if message is None:
-            message = self.default_error_message
-        self.message = message
-        super().__init__(self.message)
+        super().__init__(message=message or self.default_error_message)
 
 
 class LettaConfigurationError(LettaError):
@@ -27,23 +47,17 @@ class LettaConfigurationError(LettaError):
 
     def __init__(self, message: str, missing_fields: Optional[List[str]] = None):
         self.missing_fields = missing_fields or []
-        super().__init__(message)
+        super().__init__(message=message, details={"missing_fields": self.missing_fields})
 
 
 class LettaAgentNotFoundError(LettaError):
     """Error raised when an agent is not found."""
-
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
+    pass
 
 
 class LettaUserNotFoundError(LettaError):
     """Error raised when a user is not found."""
-
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
+    pass
 
 
 class LLMError(LettaError):
@@ -54,38 +68,33 @@ class LLMJSONParsingError(LettaError):
     """Exception raised for errors in the JSON parsing process."""
 
     def __init__(self, message="Error parsing JSON generated by LLM"):
-        self.message = message
-        super().__init__(self.message)
+        super().__init__(message=message)
 
 
 class LocalLLMError(LettaError):
     """Generic catch-all error for local LLM problems"""
 
     def __init__(self, message="Encountered an error while running local LLM"):
-        self.message = message
-        super().__init__(self.message)
+        super().__init__(message=message)
 
 
 class LocalLLMConnectionError(LettaError):
     """Error for when local LLM cannot be reached with provided IP/port"""
 
     def __init__(self, message="Could not connect to local LLM"):
-        self.message = message
-        super().__init__(self.message)
+        super().__init__(message=message)
 
 
-class SummarizationError(LettaError):
-    """Error raised when the summarization process fails."""
-
-    def __init__(self, message: str, num_candidate_messages: int, num_total_messages: int, preserve_N: int):
-        self.message = message
-        self.num_candidate_messages = num_candidate_messages
-        self.num_total_messages = num_total_messages
-        self.preserve_N = preserve_N
-        super().__init__(self.message)
+class ContextWindowExceededError(LettaError):
+    """Error raised when the context window is exceeded but further summarization fails."""
 
-    def __str__(self):
-        return f"{self.message} (num_candidate_messages={self.num_candidate_messages}, num_total_messages={self.num_total_messages}, preserve_N={self.preserve_N})"
+    def __init__(self, message: str, details: dict = {}):
+        error_message = f"{message} ({details})"
+        super().__init__(
+            message=error_message,
+            code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
+            details=details,
+        )
 
 
 class LettaMessageError(LettaError):
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index da8d472cd7..60e247954a 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -8,6 +8,7 @@
 from fastapi import Header
 from pydantic import BaseModel
 
+from letta.errors import ContextWindowExceededError
 from letta.schemas.usage import LettaUsageStatistics
 from letta.server.rest_api.interface import StreamingServerInterface
 from letta.server.server import SyncServer
@@ -61,34 +62,17 @@ async def sse_async_generator(
             if not isinstance(usage, LettaUsageStatistics):
                 raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
             yield sse_formatter({"usage": usage.model_dump()})
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        warnings.warn(f"SSE stream generator failed: {e}")
-
-        # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
-        # Print the stack trace
-        if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
-            import sentry_sdk
 
-            sentry_sdk.capture_exception(e)
+    except ContextWindowExceededError as e:
+        log_error_to_sentry(e)
+        yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
 
+    except Exception as e:
+        log_error_to_sentry(e)
         yield sse_formatter({"error": f"Stream failed (internal error occured)"})
 
     except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-        warnings.warn(f"SSE stream generator failed: {e}")
-
-        # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
-        # Print the stack trace
-        if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
-            import sentry_sdk
-
-            sentry_sdk.capture_exception(e)
-
+        log_error_to_sentry(e)
         yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})
 
     finally:
@@ -113,3 +97,16 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio
 
 def get_current_interface() -> StreamingServerInterface:
     return StreamingServerInterface
+
+
+def log_error_to_sentry(e):
+    import traceback
+
+    traceback.print_exc()
+    warnings.warn(f"SSE stream generator failed: {e}")
+
+    # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
+    # Print the stack trace
+    if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
+        import sentry_sdk
+
+        sentry_sdk.capture_exception(e)
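
Illustrative usage, not part of the patch: a minimal sketch of how a client consuming the SSE stream might branch on the new "code" field that the streaming handler now emits for ContextWindowExceededError. The handle_sse_line helper below is hypothetical, and it assumes the stream delivers standard SSE "data: <json>" lines; only the {"error": ..., "code": ...} payload shape comes from the diff above.

import json

# Hypothetical helper; assumes SSE lines of the form "data: <json payload>".
def handle_sse_line(line: str) -> None:
    if not line.startswith("data: "):
        return  # ignore SSE comments / keep-alives
    payload = line[len("data: "):].strip()
    try:
        event = json.loads(payload)
    except json.JSONDecodeError:
        return  # e.g. a non-JSON stream terminator
    if isinstance(event, dict) and "error" in event:
        if event.get("code") == "CONTEXT_WINDOW_EXCEEDED":
            # Matches ErrorCode.CONTEXT_WINDOW_EXCEEDED.value from the diff above;
            # a client could trim or summarize its history before retrying.
            print(f"context window exceeded: {event['error']}")
        else:
            print(f"stream error: {event['error']}")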