From 264068dbb965528163e2966afc07168b79d89b66 Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:04:26 -0800 Subject: [PATCH] fix: Allow `ChildToolRule` to work without support for structured outputs (#2270) Co-authored-by: Mindy Long --- letta/agent.py | 3 ++ letta/llm_api/anthropic.py | 31 +++++++++++--- tests/helpers/endpoints_helper.py | 3 +- tests/integration_test_agent_tool_graph.py | 48 ++++++++++++++++++++++ 4 files changed, 79 insertions(+), 6 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 532c4e13a7..485f2112b9 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -604,6 +604,9 @@ def _get_ai_reply( and len(self.tool_rules_solver.init_tool_rules) > 0 ): force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name + # Force a tool call if exactly one tool is specified + elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1: + force_tool_call = allowed_tool_names[0] for attempt in range(1, empty_response_retry_limit + 1): try: diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 912ac4567f..4cca920a5c 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -262,10 +262,24 @@ def convert_anthropic_response_to_chatcompletion( ), ) ] - else: - # Just inner mono - content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag) - tool_calls = None + elif len(response_json["content"]) == 1: + if response_json["content"][0]["type"] == "tool_use": + # function call only + content = None + tool_calls = [ + ToolCall( + id=response_json["content"][0]["id"], + type="function", + function=FunctionCall( + name=response_json["content"][0]["name"], + arguments=json.dumps(response_json["content"][0]["input"], indent=2), + ), + ) + ] + else: + # inner mono only + content = strip_xml_tags(string=response_json["content"][0]["text"], tag=inner_thoughts_xml_tag) + tool_calls = None else: raise RuntimeError("Unexpected type for content in response_json.") @@ -327,6 +341,14 @@ def anthropic_chat_completions_request( if anthropic_tools is not None: data["tools"] = anthropic_tools + # TODO: Add support for other tool_choice options like "auto", "any" + if len(anthropic_tools) == 1: + data["tool_choice"] = { + "type": "tool", # Changed from "function" to "tool" + "name": anthropic_tools[0]["name"], # Directly specify name without nested "function" object + "disable_parallel_tool_use": True # Force single tool use + } + # Move 'system' to the top level # 'messages: Unexpected role "system". The Messages API accepts a top-level `system` parameter, not "system" as an input message role.' assert data["messages"][0]["role"] == "system", f"Expected 'system' role in messages[0]:\n{data['messages'][0]}" @@ -362,7 +384,6 @@ def anthropic_chat_completions_request( data.pop("top_p", None) data.pop("presence_penalty", None) data.pop("user", None) - data.pop("tool_choice", None) response_json = make_post_request(url, headers, data) return convert_anthropic_response_to_chatcompletion(response_json=response_json, inner_thoughts_xml_tag=inner_thoughts_xml_tag) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index ddaa1d960d..8f1aa99c74 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -64,6 +64,7 @@ def setup_agent( tool_ids: Optional[List[str]] = None, tool_rules: Optional[List[BaseToolRule]] = None, agent_uuid: str = agent_uuid, + include_base_tools: bool = True, ) -> AgentState: config_data = json.load(open(filename, "r")) llm_config = LLMConfig(**config_data) @@ -77,7 +78,7 @@ def setup_agent( memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) agent_state = client.create_agent( - name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules + name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools, ) return agent_state diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 19c7dbd6cb..336777215d 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -234,3 +234,51 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): if i < 2: backoff_time = 10 * (2 ** i) time.sleep(backoff_time) + +@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely +def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user) + archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user) + archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user) + + # Make tool rules + tool_rules = [ + InitToolRule(tool_name="archival_memory_search"), + ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]), + ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]), + TerminalToolRule(tool_name="send_message"), + ] + tools = [send_message, archival_memory_search, archival_memory_insert] + + config_files = [ + "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json", + "tests/configs/llm_model_configs/openai-gpt-4o.json", + ] + + for config in config_files: + agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search") + + # Make checks + assert_sanity_checks(response) + + # Assert the tools were called + assert_invoked_function_call(response.messages, "archival_memory_search") + assert_invoked_function_call(response.messages, "archival_memory_insert") + assert_invoked_function_call(response.messages, "send_message") + + # Check ordering of tool calls + tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] + for m in response.messages: + if isinstance(m, FunctionCallMessage): + # Check that it's equal to the first one + assert m.function_call.name == tool_names[0] + + # Pop out first one + tool_names = tool_names[1:] + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid)