Skip to content

Commit

Permalink
refactor error base class and surface error in streaming logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Caren Thomas committed Dec 18, 2024
1 parent 8d57b7a commit 8059183
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 64 deletions.
26 changes: 15 additions & 11 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS,
)
from letta.errors import LLMError, SummarizationError
from letta.errors import ContextWindowExceededError
from letta.helpers import ToolRulesSolver
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
Expand Down Expand Up @@ -1170,11 +1170,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,

# If at this point there's nothing to summarize, throw an error
if len(candidate_messages_to_summarize) == 0:
raise SummarizationError(
f"Not enough messages to compress for summarization.",
num_candidate_messages=len(candidate_messages_to_summarize),
num_total_messages=len(self.messages),
preserve_N=MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
raise ContextWindowExceededError(
"Not enough messages to compress for summarization",
details={
"num_candidate_messages": len(candidate_messages_to_summarize),
"num_total_messages": len(self.messages),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)

# Walk down the message buffer (front-to-back) until we hit the target token count
Expand Down Expand Up @@ -1208,11 +1210,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise SummarizationError(
f"Not enough messages to compress for summarization after determining cutoff.",
num_candidate_messages=len(message_sequence_to_summarize),
num_total_messages=len(self.messages),
preserve_N=MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
raise ContextWindowExceededError(
"Not enough messages to compress for summarization after determining cutoff",
details={
"num_candidate_messages": len(message_sequence_to_summarize),
"num_total_messages": len(self.messages),
"preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
},
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
Expand Down
69 changes: 39 additions & 30 deletions letta/errors.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,63 @@
import json
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union

# Avoid circular imports
if TYPE_CHECKING:
from letta.schemas.message import Message


class ErrorCode(Enum):
"""Enum for error codes used by client."""

INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR"
CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED"
RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED"


class LettaError(Exception):
"""Base class for all Letta related errors."""

def __init__(self, message: str, code: Optional[ErrorCode] = None, details: dict = {}):
self.message = message
self.code = code
self.details = details
super().__init__(message)

def __str__(self) -> str:
if self.code:
return f"{self.code.value}: {self.message}"
return self.message

def __repr__(self) -> str:
return f"{self.__class__.__name__}(message='{self.message}', code='{self.code}', details={self.details})"


class LettaToolCreateError(LettaError):
"""Error raised when a tool cannot be created."""

default_error_message = "Error creating tool."

def __init__(self, message=None):
if message is None:
message = self.default_error_message
self.message = message
super().__init__(self.message)
super().__init__(message=message or self.default_error_message)


class LettaConfigurationError(LettaError):
"""Error raised when there are configuration-related issues."""

def __init__(self, message: str, missing_fields: Optional[List[str]] = None):
self.missing_fields = missing_fields or []
super().__init__(message)
super().__init__(message=message, details={"missing_fields": self.missing_fields})


class LettaAgentNotFoundError(LettaError):
"""Error raised when an agent is not found."""

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
pass


class LettaUserNotFoundError(LettaError):
"""Error raised when a user is not found."""

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
pass


class LLMError(LettaError):
Expand All @@ -54,38 +68,33 @@ class LLMJSONParsingError(LettaError):
"""Exception raised for errors in the JSON parsing process."""

def __init__(self, message="Error parsing JSON generated by LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class LocalLLMError(LettaError):
"""Generic catch-all error for local LLM problems"""

def __init__(self, message="Encountered an error while running local LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class LocalLLMConnectionError(LettaError):
"""Error for when local LLM cannot be reached with provided IP/port"""

def __init__(self, message="Could not connect to local LLM"):
self.message = message
super().__init__(self.message)
super().__init__(message=message)


class SummarizationError(LettaError):
"""Error raised when the summarization process fails."""

def __init__(self, message: str, num_candidate_messages: int, num_total_messages: int, preserve_N: int):
self.message = message
self.num_candidate_messages = num_candidate_messages
self.num_total_messages = num_total_messages
self.preserve_N = preserve_N
super().__init__(self.message)
class ContextWindowExceededError(LettaError):
"""Error raised when the context window is exceeded but further summarization fails."""

def __str__(self):
return f"{self.message} (num_candidate_messages={self.num_candidate_messages}, num_total_messages={self.num_total_messages}, preserve_N={self.preserve_N})"
def __init__(self, message: str, details: dict = {}):
error_message = f"{message} ({details})"
super().__init__(
message=error_message,
code=ErrorCode.CONTEXT_WINDOW_EXCEEDED,
details=details,
)


class LettaMessageError(LettaError):
Expand Down
43 changes: 20 additions & 23 deletions letta/server/rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi import Header
from pydantic import BaseModel

from letta.errors import ContextWindowExceededError
from letta.schemas.usage import LettaUsageStatistics
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.server import SyncServer
Expand Down Expand Up @@ -61,34 +62,17 @@ async def sse_async_generator(
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
yield sse_formatter({"usage": usage.model_dump()})
except Exception as e:
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)
except ContextWindowExceededError as e:
log_error_to_sentry(e)
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})

except Exception as e:
log_error_to_sentry(e)
yield sse_formatter({"error": f"Stream failed (internal error occured)"})

except Exception as e:
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)

log_error_to_sentry(e)
yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})

finally:
Expand All @@ -113,3 +97,16 @@ def get_user_id(user_id: Optional[str] = Header(None, alias="user_id")) -> Optio

def get_current_interface() -> StreamingServerInterface:
return StreamingServerInterface

def log_error_to_sentry(e):
import traceback

traceback.print_exc()
warnings.warn(f"SSE stream generator failed: {e}")

# Log the error, since the exception handler upstack (in FastAPI) won't catch it, because this may be a 200 response
# Print the stack trace
if (os.getenv("SENTRY_DSN") is not None) and (os.getenv("SENTRY_DSN") != ""):
import sentry_sdk

sentry_sdk.capture_exception(e)

0 comments on commit 8059183

Please sign in to comment.