diff --git a/letta/agent.py b/letta/agent.py index 485f2112b9..b61dce7641 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -20,7 +20,7 @@ REQ_HEARTBEAT_MESSAGE, STRUCTURED_OUTPUT_MODELS, ) -from letta.errors import LLMError +from letta.errors import ContextWindowExceededError from letta.helpers import ToolRulesSolver from letta.interface import AgentInterface from letta.llm_api.helpers import is_context_overflow_error @@ -1094,6 +1094,7 @@ def inner_step( # If we got a context alert, try trimming the messages length, then try again if is_context_overflow_error(e): + printd(f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages") # A separate API call to run a summarizer self.summarize_messages_inplace() @@ -1169,8 +1170,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, # If at this point there's nothing to summarize, throw an error if len(candidate_messages_to_summarize) == 0: - raise LLMError( - f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(self.messages)}, preserve_N={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}]" + raise ContextWindowExceededError( + "Not enough messages to compress for summarization", + details={ + "num_candidate_messages": len(candidate_messages_to_summarize), + "num_total_messages": len(self.messages), + "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, + }, ) # Walk down the message buffer (front-to-back) until we hit the target token count @@ -1204,8 +1210,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message if len(message_sequence_to_summarize) <= 1: # This prevents a potential infinite loop of summarizing the same message over and over - raise LLMError( - f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]" + raise ContextWindowExceededError( + "Not enough messages to compress for summarization after determining cutoff", + details={ + "num_candidate_messages": len(message_sequence_to_summarize), + "num_total_messages": len(self.messages), + "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, + }, ) else: printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}") @@ -1218,6 +1229,7 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, self.agent_state.llm_config.context_window = ( LLM_MAX_TOKENS[self.model] if (self.model is not None and self.model in LLM_MAX_TOKENS) else LLM_MAX_TOKENS["DEFAULT"] ) + summary = summarize_messages(agent_state=self.agent_state, message_sequence_to_summarize=message_sequence_to_summarize) printd(f"Got summary: {summary}") diff --git a/letta/errors.py b/letta/errors.py index c478ef42fa..0dc7cc9ec0 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -1,4 +1,5 @@ import json +from enum import Enum from typing import TYPE_CHECKING, List, Optional, Union # Avoid circular imports @@ -6,9 +7,31 @@ from letta.schemas.message import Message +class ErrorCode(Enum): + """Enum for error codes used by client.""" + + INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR" + CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED" + RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED" + + class LettaError(Exception): """Base class for all Letta related errors.""" + def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}): + self.message = message + self.code = code + self.details = details + super().__init__(message) + + def __str__(self) -> str: + if self.code: + return f"{self.code.value}: {self.message}" + return self.message + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(message='{self.message}', code='{self.code}', details={self.details})" + class LettaToolCreateError(LettaError): """Error raised when a tool cannot be created.""" @@ -16,10 +39,7 @@ class LettaToolCreateError(LettaError): default_error_message = "Error creating tool." def __init__(self, message=None): - if message is None: - message = self.default_error_message - self.message = message - super().__init__(self.message) + super().__init__(message=message or self.default_error_message) class LettaConfigurationError(LettaError): @@ -27,23 +47,17 @@ class LettaConfigurationError(LettaError): def __init__(self, message: str, missing_fields: Optional[List[str]] = None): self.missing_fields = missing_fields or [] - super().__init__(message) + super().__init__(message=message, details={"missing_fields": self.missing_fields}) class LettaAgentNotFoundError(LettaError): """Error raised when an agent is not found.""" - - def __init__(self, message: str): - self.message = message - super().__init__(self.message) + pass class LettaUserNotFoundError(LettaError): """Error raised when a user is not found.""" - - def __init__(self, message: str): - self.message = message - super().__init__(self.message) + pass class LLMError(LettaError): @@ -54,24 +68,45 @@ class LLMJSONParsingError(LettaError): """Exception raised for errors in the JSON parsing process.""" def __init__(self, message="Error parsing JSON generated by LLM"): - self.message = message - super().__init__(self.message) + super().__init__(message=message) class LocalLLMError(LettaError): """Generic catch-all error for local LLM problems""" def __init__(self, message="Encountered an error while running local LLM"): - self.message = message - super().__init__(self.message) + super().__init__(message=message) class LocalLLMConnectionError(LettaError): """Error for when local LLM cannot be reached with provided IP/port""" def __init__(self, message="Could not connect to local LLM"): - self.message = message - super().__init__(self.message) + super().__init__(message=message) + + +class ContextWindowExceededError(LettaError): + """Error raised when the context window is exceeded but further summarization fails.""" + + def __init__(self, message: str, details: dict = {}): + error_message = f"{message} ({details})" + super().__init__( + message=error_message, + code=ErrorCode.CONTEXT_WINDOW_EXCEEDED, + details=details, + ) + + +class RateLimitExceededError(LettaError): + """Error raised when the llm rate limiter throttles api requests.""" + + def __init__(self, message: str, max_retries: int): + error_message = f"{message} ({max_retries})" + super().__init__( + message=error_message, + code=ErrorCode.RATE_LIMIT_EXCEEDED, + details={"max_retries": max_retries}, + ) class LettaMessageError(LettaError): diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index dadd128aa9..578779d72b 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -5,7 +5,7 @@ import requests from letta.constants import CLI_WARNING_PREFIX -from letta.errors import LettaConfigurationError +from letta.errors import LettaConfigurationError, RateLimitExceededError from letta.llm_api.anthropic import anthropic_chat_completions_request from letta.llm_api.azure_openai import azure_openai_chat_completions_request from letta.llm_api.google_ai import ( @@ -80,7 +80,7 @@ def wrapper(*args, **kwargs): # Check if max retries has been reached if num_retries > max_retries: - raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") + raise RateLimitExceededError("Maximum number of retries exceeded", max_retries=max_retries) # Increment the delay delay *= exponential_base * (1 + jitter * random.random()) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index da8d472cd7..64d46a5d3f 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -8,6 +8,7 @@ from fastapi import Header from pydantic import BaseModel +from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.schemas.usage import LettaUsageStatistics from letta.server.rest_api.interface import StreamingServerInterface from letta.server.server import SyncServer @@ -61,34 +62,21 @@ async def sse_async_generator( if not isinstance(usage, LettaUsageStatistics): raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}") yield sse_formatter({"usage": usage.model_dump()}) - except Exception as e: - import traceback - - traceback.print_exc() - warnings.warn(f"SSE stream generator failed: {e}") - # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response - # Print the stack trace - if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): - import sentry_sdk + except ContextWindowExceededError as e: + log_error_to_sentry(e) + yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None}) - sentry_sdk.capture_exception(e) + except RateLimitExceededError as e: + log_error_to_sentry(e) + yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None}) + except Exception as e: + log_error_to_sentry(e) yield sse_formatter({"error": f"Stream failed (internal error occured)"}) except Exception as e: - import traceback - - traceback.print_exc() - warnings.warn(f"SSE stream generator failed: {e}") - - # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response - # Print the stack trace - if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): - import sentry_sdk - - sentry_sdk.capture_exception(e) - + log_error_to_sentry(e) yield sse_formatter({"error": "Stream failed (decoder encountered an error)"}) finally: @@ -113,3 +101,16 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio def get_current_interface() -> StreamingServerInterface: return StreamingServerInterface + +def log_error_to_sentry(e): + import traceback + + traceback.print_exc() + warnings.warn(f"SSE stream generator failed: {e}") + + # Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response + # Print the stack trace + if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""): + import sentry_sdk + + sentry_sdk.capture_exception(e)