Skip to content

Commit

Permalink
fix: Allow ChildToolRule to work without support for structured out…
Browse files Browse the repository at this point in the history
…puts (#2270)

Co-authored-by: Mindy Long <[email protected]>
  • Loading branch information
mlong93 and Mindy Long authored Dec 18, 2024
1 parent c1f7b32 commit 264068d
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 6 deletions.
3 changes: 3 additions & 0 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,9 @@ def _get_ai_reply(
and len(self.tool_rules_solver.init_tool_rules) > 0
):
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
# Force a tool call if exactly one tool is specified
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
force_tool_call = allowed_tool_names[0]

for attempt in range(1, empty_response_retry_limit + 1):
try:
Expand Down
31 changes: 26 additions & 5 deletions letta/llm_api/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,24 @@ def convert_anthropic_response_to_chatcompletion(
),
)
]
else:
# Just inner mono
content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
tool_calls = None
elif len(response_json["content"]) == 1:
if response_json["content"][0]["type"] == "tool_use":
# function call only
content = None
tool_calls = [
ToolCall(
id=response_json["content"][0]["id"],
type="function",
function=FunctionCall(
name=response_json["content"][0]["name"],
arguments=json.dumps(response_json["content"][0]["input"], indent=2),
),
)
]
else:
# inner mono only
content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag)
tool_calls = None
else:
raise RuntimeError("Unexpected type for content in response_json.")

Expand Down Expand Up @@ -327,6 +341,14 @@ def anthropic_chat_completions_request(
if anthropic_tools is not None:
data["tools"] = anthropic_tools

# TODO: Add support for other tool_choice options like "auto", "any"
if len(anthropic_tools) == 1:
data["tool_choice"] = {
"type": "tool", # Changed from "function" to "tool"
"name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object
"disable_parallel_tool_use": True # Force single tool use
}

# Move 'system' to the top level
# 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.'
assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}"
Expand Down Expand Up @@ -362,7 +384,6 @@ def anthropic_chat_completions_request(
data.pop("top_p", None)
data.pop("presence_penalty", None)
data.pop("user", None)
data.pop("tool_choice", None)

response_json = make_post_request(url, headers, data)
return convert_anthropic_response_to_chatcompletion(response_json=response_json, inner_thoughts_xml_tag=inner_thoughts_xml_tag)
3 changes: 2 additions & 1 deletion tests/helpers/endpoints_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def setup_agent(
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
agent_uuid: str = agent_uuid,
include_base_tools: bool = True,
) -> AgentState:
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
Expand All @@ -77,7 +78,7 @@ def setup_agent(

memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
agent_state = client.create_agent(
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools,
)

return agent_state
Expand Down
48 changes: 48 additions & 0 deletions tests/integration_test_agent_tool_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,51 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
if i < 2:
backoff_time = 10 * (2 ** i)
time.sleep(backoff_time)

@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)

send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user)
archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user)

# Make tool rules
tool_rules = [
InitToolRule(tool_name="archival_memory_search"),
ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]),
ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]),
TerminalToolRule(tool_name="send_message"),
]
tools = [send_message, archival_memory_search, archival_memory_insert]

config_files = [
"tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
]

for config in config_files:
agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search")

# Make checks
assert_sanity_checks(response)

# Assert the tools were called
assert_invoked_function_call(response.messages, "archival_memory_search")
assert_invoked_function_call(response.messages, "archival_memory_insert")
assert_invoked_function_call(response.messages, "send_message")

# Check ordering of tool calls
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
for m in response.messages:
if isinstance(m, FunctionCallMessage):
# Check that it's equal to the first one
assert m.function_call.name == tool_names[0]

# Pop out first one
tool_names = tool_names[1:]

print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)

0 comments on commit 264068d

Please sign in to comment.