Skip to content

Commit

Permalink
Merge branch 'main' into fix-anthropic-kevin
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Dec 17, 2024
2 parents 38c4250 + 75d40fc commit fde6836
Show file tree
Hide file tree
Showing 15 changed files with 328 additions and 67 deletions.
40 changes: 22 additions & 18 deletions .github/workflows/docker-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ on:
workflow_dispatch:

jobs:

build:

runs-on: ubuntu-latest

steps:
- name: Login to Docker Hub
uses: docker/login-action@v3
Expand All @@ -19,19 +17,25 @@ jobs:
password: ${{ secrets.DOCKERHUB_TOKEN }}

- uses: actions/checkout@v3
- name: Build and push the Docker image (memgpt)
run: |
# Extract the version number from pyproject.toml using awk
CURRENT_VERSION=$(awk -F '"' '/version =/ { print $2 }' pyproject.toml | head -n 1)
docker build . --file Dockerfile --tag memgpt/letta:$CURRENT_VERSION --tag memgpt/letta:latest
docker push memgpt/letta:$CURRENT_VERSION
docker push memgpt/letta:latest

- name: Set up QEMU
uses: docker/setup-qemu-action@v3

- uses: actions/checkout@v3
- name: Build and push the Docker image (letta)
run: |
# Extract the version number from pyproject.toml using awk
CURRENT_VERSION=$(awk -F '"' '/version =/ { print $2 }' pyproject.toml | head -n 1)
docker build . --file Dockerfile --tag letta/letta:$CURRENT_VERSION --tag letta/letta:latest
docker push letta/letta:$CURRENT_VERSION
docker push letta/letta:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Extract version number
id: extract_version
run: echo "CURRENT_VERSION=$(awk -F '\"' '/version =/ { print $2 }' pyproject.toml | head -n 1)" >> $GITHUB_ENV

- name: Build and push
uses: docker/build-push-action@v6
with:
platforms: linux/amd64,linux/arm64
push: true
tags: |
letta/letta:${{ env.CURRENT_VERSION }}
letta/letta:latest
memgpt/letta:${{ env.CURRENT_VERSION }}
memgpt/letta:latest
26 changes: 26 additions & 0 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS
)
from letta.errors import LLMError
from letta.helpers import ToolRulesSolver
Expand Down Expand Up @@ -276,6 +277,7 @@ def __init__(

# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
self.check_tool_rules()

# state managers
self.block_manager = BlockManager()
Expand Down Expand Up @@ -381,6 +383,14 @@ def __init__(
# Create the agent in the DB
self.update_state()

def check_tool_rules(self):
if self.model not in STRUCTURED_OUTPUT_MODELS:
if len(self.tool_rules_solver.init_tool_rules) > 1:
raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
self.supports_structured_output = False
else:
self.supports_structured_output = True

def update_memory_if_change(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.
Expand Down Expand Up @@ -588,6 +598,7 @@ def _get_ai_reply(
empty_response_retry_limit: int = 3,
backoff_factor: float = 0.5, # delay multiplier for exponential backoff
max_delay: float = 10.0, # max delay between retries
step_count: Optional[int] = None,
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

Expand All @@ -596,6 +607,16 @@ def _get_ai_reply(
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
)

# For the first message, force the initial tool if one is specified
force_tool_call = None
if (
step_count is not None
and step_count == 0
and not self.supports_structured_output
and len(self.tool_rules_solver.init_tool_rules) > 0
):
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name

for attempt in range(1, empty_response_retry_limit + 1):
try:
response = create(
Expand All @@ -606,6 +627,7 @@ def _get_ai_reply(
functions_python=self.functions_python,
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
)
Expand Down Expand Up @@ -897,6 +919,7 @@ def step(
step_count = 0
while True:
kwargs["first_message"] = False
kwargs["step_count"] = step_count
step_response = self.inner_step(
messages=next_input_message,
**kwargs,
Expand Down Expand Up @@ -972,6 +995,7 @@ def inner_step(
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
stream: bool = False, # TODO move to config?
step_count: Optional[int] = None,
) -> AgentStepResponse:
"""Runs a single step in the agent loop (generates at most one LLM call)"""

Expand Down Expand Up @@ -1014,7 +1038,9 @@ def inner_step(
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=first_message,
stream=stream,
step_count=step_count,
)

# Step 3: check if LLM wanted to call a function
Expand Down
1 change: 1 addition & 0 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2156,6 +2156,7 @@ def create_agent(
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"include_base_tools": include_base_tools,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
Expand Down
3 changes: 3 additions & 0 deletions letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL_KWARG = "message"

# Structured output models
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}

# LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level
LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}

Expand Down
12 changes: 3 additions & 9 deletions letta/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,16 +234,10 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
)
elif endpoint_type == "ollama":

from llama_index.embeddings.ollama import OllamaEmbedding

ollama_additional_kwargs = {}
callback_manager = None

model = OllamaEmbedding(
model_name=config.embedding_model,
model = OllamaEmbeddings(
model=config.embedding_model,
base_url=config.embedding_endpoint,
ollama_additional_kwargs=ollama_additional_kwargs or {},
callback_manager=callback_manager or None,
ollama_additional_kwargs={},
)
return model

Expand Down
20 changes: 12 additions & 8 deletions letta/llm_api/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,20 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
- 1 level less of nesting
- "parameters" -> "input_schema"
"""
tools_dict_list = []
formatted_tools = []
for tool in tools:
tools_dict_list.append(
{
"name": tool.function.name,
"description": tool.function.description,
"input_schema": tool.function.parameters,
formatted_tool = {
"name" : tool.function.name,
"description" : tool.function.description,
"input_schema" : tool.function.parameters or {
"type": "object",
"properties": {},
"required": []
}
)
return tools_dict_list
}
formatted_tools.append(formatted_tool)

return formatted_tools


def merge_tool_results_into_user_messages(messages: List[dict]):
Expand Down
13 changes: 12 additions & 1 deletion letta/llm_api/llm_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def create(
function_call: str = "auto",
# hint
first_message: bool = False,
force_tool_call: Optional[str] = None, # Force a specific tool to be called
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming: bool = True,
Expand Down Expand Up @@ -252,14 +253,24 @@ def create(
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")

tool_call = None
if force_tool_call is not None:
tool_call = {
"type": "function",
"function": {
"name": force_tool_call
}
}
assert functions is not None

return anthropic_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.anthropic_api_key,
data=ChatCompletionRequest(
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
tool_choice=tool_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
Expand Down
8 changes: 8 additions & 0 deletions letta/orm/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ class UniqueConstraintViolationError(ValueError):

class ForeignKeyConstraintViolationError(ValueError):
"""Custom exception for foreign key constraint violations."""


class DatabaseTimeoutError(Exception):
"""Custom exception for database timeout issues."""

def __init__(self, message="Database operation timed out", original_exception=None):
super().__init__(message)
self.original_exception = original_exception
41 changes: 24 additions & 17 deletions letta/orm/sqlalchemy_base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, List, Literal, Optional

from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.orm import Mapped, Session, mapped_column

from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.errors import (
DatabaseTimeoutError,
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
Expand All @@ -23,6 +25,20 @@
logger = get_logger(__name__)


def handle_db_timeout(func):
"""Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)

return wrapper


class AccessType(str, Enum):
ORGANIZATION = "organization"
USER = "user"
Expand All @@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
id: Mapped[str] = mapped_column(String, primary_key=True)

@classmethod
def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
"""Get a record by ID.
Args:
db_session: SQLAlchemy session
id: Record ID to retrieve
Returns:
Optional[SqlalchemyBase]: The record if found, None otherwise
"""
try:
return db_session.query(cls).filter(cls.id == id).first()
except DBAPIError:
return None

@classmethod
@handle_db_timeout
def list(
cls,
*,
Expand Down Expand Up @@ -180,6 +181,7 @@ def list(
return list(session.execute(query).scalars())

@classmethod
@handle_db_timeout
def read(
cls,
db_session: "Session",
Expand Down Expand Up @@ -231,6 +233,7 @@ def read(
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")

@handle_db_timeout
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")

Expand All @@ -245,6 +248,7 @@ def create(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)

@handle_db_timeout
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")

Expand All @@ -254,6 +258,7 @@ def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
self.is_deleted = True
return self.update(db_session)

@handle_db_timeout
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
Expand All @@ -269,6 +274,7 @@ def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) ->
else:
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")

@handle_db_timeout
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
Expand All @@ -281,6 +287,7 @@ def update(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
return self

@classmethod
@handle_db_timeout
def size(
cls,
*,
Expand Down
Loading

0 comments on commit fde6836

Please sign in to comment.