Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adding init tool rule for Anthropic endpoint #2262

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MESSAGE_SUMMARY_WARNING_FRAC,
O1_BASE_TOOLS,
REQ_HEARTBEAT_MESSAGE,
STRUCTURED_OUTPUT_MODELS
)
from letta.errors import LLMError
from letta.helpers import ToolRulesSolver
Expand Down Expand Up @@ -276,6 +277,7 @@ def __init__(

# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
self.check_tool_rules()

# state managers
self.block_manager = BlockManager()
Expand Down Expand Up @@ -381,6 +383,13 @@ def __init__(
# Create the agent in the DB
self.update_state()

def check_tool_rules(self):
    """Validate the agent's tool rules against the model's capabilities.

    Models not listed in STRUCTURED_OUTPUT_MODELS cannot use structured
    output to force an arbitrary tool call, so they can support at most
    one initial tool rule (that single tool is later forced explicitly on
    the first message). Sets ``self.supports_structured_output`` so the
    request path knows which mechanism to use.

    Raises:
        ValueError: if the model lacks structured-output support but more
            than one initial tool rule is configured.
    """
    if self.model not in STRUCTURED_OUTPUT_MODELS:
        # Use an explicit raise instead of `assert`: asserts are stripped
        # under `python -O`, which would silently disable this validation.
        if len(self.tool_rules_solver.init_tool_rules) > 1:
            raise ValueError("Multiple initial tools not supported for non-structured models")
        self.supports_structured_output = False
    else:
        self.supports_structured_output = True

def update_memory_if_change(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.
Expand Down Expand Up @@ -596,6 +605,11 @@ def _get_ai_reply(
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
)

# For the first message, force the initial tool if one is specified
force_tool_call = None
if first_message and not self.supports_structured_output and len(self.tool_rules_solver.init_tool_rules) > 0:
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name

for attempt in range(1, empty_response_retry_limit + 1):
try:
response = create(
Expand All @@ -606,6 +620,7 @@ def _get_ai_reply(
functions_python=self.functions_python,
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
)
Expand Down Expand Up @@ -896,7 +911,10 @@ def step(
total_usage = UsageStatistics()
step_count = 0
while True:
kwargs["first_message"] = False
mlong93 marked this conversation as resolved.
Show resolved Hide resolved
if step_count > 0:
kwargs["first_message"] = False
else:
kwargs["first_message"] = True
step_response = self.inner_step(
messages=next_input_message,
**kwargs,
Expand Down Expand Up @@ -1014,6 +1032,7 @@ def inner_step(
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=first_message,
stream=stream,
)

Expand Down
1 change: 1 addition & 0 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2156,6 +2156,7 @@ def create_agent(
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"include_base_tools": include_base_tools,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
Expand Down
3 changes: 3 additions & 0 deletions letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
DEFAULT_MESSAGE_TOOL = "send_message"
DEFAULT_MESSAGE_TOOL_KWARG = "message"

# Structured output models
STRUCTURED_OUTPUT_MODELS = {"gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"}
mlong93 marked this conversation as resolved.
Show resolved Hide resolved

# LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level
LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}

Expand Down
20 changes: 12 additions & 8 deletions letta/llm_api/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,20 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
- 1 level less of nesting
- "parameters" -> "input_schema"
"""
tools_dict_list = []
formatted_tools = []
for tool in tools:
tools_dict_list.append(
{
"name": tool.function.name,
"description": tool.function.description,
"input_schema": tool.function.parameters,
formatted_tool = {
"name" : tool.function.name,
"description" : tool.function.description,
"input_schema" : tool.function.parameters or {
"type": "object",
"properties": {},
"required": []
}
)
return tools_dict_list
}
formatted_tools.append(formatted_tool)

return formatted_tools


def merge_tool_results_into_user_messages(messages: List[dict]):
Expand Down
13 changes: 12 additions & 1 deletion letta/llm_api/llm_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def create(
function_call: str = "auto",
# hint
first_message: bool = False,
force_tool_call: Optional[str] = None, # Force a specific tool to be called
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming: bool = True,
Expand Down Expand Up @@ -252,14 +253,24 @@ def create(
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")

tool_call = None
if force_tool_call is not None:
tool_call = {
"type": "function",
"function": {
"name": force_tool_call
}
}
assert functions is not None

return anthropic_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.anthropic_api_key,
data=ChatCompletionRequest(
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
tool_choice=tool_call,
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
Expand Down
9 changes: 9 additions & 0 deletions tests/configs/llm_model_configs/claude-3-sonnet-20240229.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
    "context_window": 200000,
    "model": "claude-3-5-sonnet-20241022",
    "model_endpoint_type": "anthropic",
    "model_endpoint": "https://api.anthropic.com/v1",
    "model_wrapper": null,
    "put_inner_thoughts_in_kwargs": true
}
55 changes: 54 additions & 1 deletion tests/integration_test_agent_tool_graph.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import time
import uuid

import pytest

from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
Expand Down Expand Up @@ -127,3 +128,55 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):

print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)


def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
    """Test that the initial tool rule is enforced for the first message."""
    client = create_client()
    # NOTE(review): this default config (claude-3-opus) appears to be
    # superseded by the config file passed to setup_agent below — confirm
    # whether setting it here is still required.
    client.set_default_llm_config(
        LLMConfig(
            model="claude-3-opus-20240229",
            model_endpoint_type="anthropic",
            model_endpoint="https://api.anthropic.com/v1",
            context_window=200000,  # NOTE: can be set to <= 200000
        )
    )
    # Remove any agent left over from a previous run before creating a new one.
    cleanup(client=client, agent_uuid=agent_uuid)

    # Create tool rules that require tool_a to be called first
    t1 = client.create_or_update_tool(first_secret_word)
    t2 = client.create_or_update_tool(second_secret_word)
    tool_rules = [
        # InitToolRule pins the very first tool call; ChildToolRule chains
        # second_secret_word immediately after first_secret_word.
        InitToolRule(tool_name="first_secret_word"),
        ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]),
    ]
    tools = [t1, t2]

    # Make agent state
    # NOTE(review): the config file is named claude-3-sonnet-20240229 but its
    # "model" field is claude-3-5-sonnet-20241022 — confirm which is intended.
    anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
    agent_state = setup_agent(client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
    # Run three iterations to check the ordering is enforced consistently,
    # not just on a single lucky completion.
    for i in range(3):
        response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?")

        assert_sanity_checks(response)
        messages = response.messages

        assert_invoked_function_call(messages, "first_secret_word")
        assert_invoked_function_call(messages, "second_secret_word")

        # Expected call order: first_secret_word -> second_secret_word -> send_message.
        tool_names = [t.name for t in [t1, t2]]
        tool_names += ["send_message"]
        for m in messages:
            if isinstance(m, FunctionCallMessage):
                # Check that it's equal to the first one
                assert m.function_call.name == tool_names[0]

                # Pop out first one
                tool_names = tool_names[1:]

        print(f"Passed iteration {i}")

        # Implement exponential backoff with initial time of 10 seconds
        # (10s, 20s) to avoid hitting Anthropic rate limits between runs;
        # skipped after the final iteration.
        if i < 2:
            backoff_time = 10 * (2 ** i)
            time.sleep(backoff_time)
1 change: 1 addition & 0 deletions tests/integration_test_offline_memory_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
)
assert chat_only_agent is not None
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
assert len(chat_only_agent.tools) == 1

for message in ["hello", "my name is not chad, my name is swoodily"]:
client.send_message(agent_id=chat_only_agent.id, message=message, role="user")
Expand Down
Loading