Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ConditionalToolRules #2279

Merged
21 changes: 20 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import inspect
import json
import time
import traceback
import warnings
Expand Down Expand Up @@ -371,6 +372,9 @@ def __init__(
self._append_to_messages(added_messages=init_messages_objs)
self._validate_message_buffer_is_utc()

# Load last function response from message history
self.last_function_response = self.load_last_function_response()

# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
self.messages_total_init = len(self._messages) - 1
Expand All @@ -389,6 +393,19 @@ def check_tool_rules(self):
else:
self.supports_structured_output = True

def load_last_function_response(self):
"""Load the last function response from message history"""
for i in range(len(self._messages) - 1, -1, -1):
msg = self._messages[i]
if msg.role == MessageRole.tool and msg.text:
try:
response_json = json.loads(msg.text)
if response_json.get("message"):
return response_json["message"]
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON format in message: {msg.text}")
return None

def update_memory_if_change(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.
Expand Down Expand Up @@ -586,7 +603,7 @@ def _get_ai_reply(
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

allowed_functions = (
Expand Down Expand Up @@ -826,6 +843,7 @@ def _handle_ai_response(
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
self.last_function_response = function_response
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
Expand Down Expand Up @@ -861,6 +879,7 @@ def _handle_ai_response(
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
self.last_function_response = function_response

else:
# Standard non-function reply
Expand Down
133 changes: 82 additions & 51 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, List, Optional, Set
import json
from typing import List, Optional, Union

from pydantic import BaseModel, Field

from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
Expand All @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
)
tool_rules: List[ChildToolRule] = Field(
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
terminal_tool_rules: List[TerminalToolRule] = Field(
Expand All @@ -35,21 +37,25 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
# Separate the provided tool rules into init, standard, and terminal categories
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)

# Validate the tool rules to ensure they form a DAG
if not self.validate_tool_rules():
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")

def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
self.last_tool_name = tool_name

def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
"""Get a list of tool names allowed based on the last tool called."""
if self.last_tool_name is None:
# Use initial tool rules if no tool has been called yet
Expand All @@ -58,18 +64,21 @@ def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
# Find a matching ToolRule for the last tool used
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)

# Return children which must exist on ToolRule
if current_rule:
return current_rule.children

# Default to empty if no rule matches
message = "User provided tool rules and execution state resolved to no more possible tool calls."
if error_on_empty:
raise RuntimeError(message)
else:
# warnings.warn(message)
if current_rule is None:
if error_on_empty:
raise ValueError(f"No tool rule found for {self.last_tool_name}")
return []

# If the current rule is a conditional tool rule, use the LLM response to
# determine which child tool to use
if isinstance(current_rule, ConditionalToolRule):
if not last_function_response:
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
next_tool = self.evaluate_conditional_tool(current_rule, last_function_response)
return [next_tool] if next_tool else []

return current_rule.children if current_rule.children else []

def is_terminal_tool(self, tool_name: str) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
Expand All @@ -78,38 +87,60 @@ def has_children_tools(self, tool_name):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.tool_rules)

def validate_tool_rules(self) -> bool:
"""
Validate that the tool rules define a directed acyclic graph (DAG).
Returns True if valid (no cycles), otherwise False.
"""
# Build adjacency list for the tool graph
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

# Track visited nodes
visited: Set[str] = set()
path_stack: Set[str] = set()

# Define DFS helper function
def dfs(tool_name: str) -> bool:
if tool_name in path_stack:
return False # Cycle detected
if tool_name in visited:
return True # Already validated

# Mark the node as visited in the current path
path_stack.add(tool_name)
for child in adjacency_list.get(tool_name, []):
if not dfs(child):
return False # Cycle detected in DFS
path_stack.remove(tool_name) # Remove from current path
visited.add(tool_name)
return True

# Run DFS from each tool in `tool_rules`
for rule in self.tool_rules:
if rule.tool_name not in visited:
if not dfs(rule.tool_name):
return False # Cycle found, invalid tool rules

return True # No cycles, valid DAG
def validate_conditional_tool(self, rule: ConditionalToolRule):
'''
Validate a conditional tool rule

Args:
rule (ConditionalToolRule): The conditional tool rule to validate

Raises:
ToolRuleValidationError: If the rule is invalid
'''
if len(rule.child_output_mapping) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
return True

def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
Parse function response to determine which child tool to use based on the mapping

Args:
tool (ConditionalToolRule): The conditional tool rule
last_function_response (str): The function response in JSON format

Returns:
str: The name of the child tool to use next
'''
json_response = json.loads(last_function_response)
function_output = json_response["message"]

# Try to match the function output with a mapping key
for key in tool.child_output_mapping:

# Convert function output to match key type for comparison
if isinstance(key, bool):
typed_output = function_output.lower() == "true"
elif isinstance(key, int):
try:
typed_output = int(function_output)
except (ValueError, TypeError):
continue
elif isinstance(key, float):
try:
typed_output = float(function_output)
except (ValueError, TypeError):
continue
else: # string
if function_output == "True" or function_output == "False":
typed_output = function_output.lower()
elif function_output == "None":
typed_output = None
else:
typed_output = function_output

if typed_output == key:
return tool.child_output_mapping[key]

# If no match found, use default
return tool.default_child
7 changes: 5 additions & 2 deletions letta/orm/custom_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
Expand Down Expand Up @@ -80,7 +80,7 @@ def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, Init
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
Expand All @@ -90,6 +90,9 @@ def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, Term
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
elif rule_type == ToolRuleType.conditional:
rule = ConditionalToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")

Expand Down
1 change: 1 addition & 0 deletions letta/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
run_first = "InitToolRule"
exit_loop = "TerminalToolRule" # reasoning loop should exit
continue_loop = "continue_loop" # reasoning loop should continue
conditional = "conditional"
constrain_child_tools = "ToolRule"
require_parent_tools = "require_parent_tools"
14 changes: 12 additions & 2 deletions letta/schemas/tool_rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Any, Dict, List, Optional, Union

from pydantic import Field

Expand All @@ -21,6 +21,16 @@ class ChildToolRule(BaseToolRule):
children: List[str] = Field(..., description="The children tools that can be invoked.")


class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
"""
type: ToolRuleType = ToolRuleType.conditional
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")


class InitToolRule(BaseToolRule):
"""
Represents the initial tool rule configuration.
Expand All @@ -37,4 +47,4 @@ class TerminalToolRule(BaseToolRule):
type: ToolRuleType = ToolRuleType.exit_loop


ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule]
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]
6 changes: 3 additions & 3 deletions tests/helpers/endpoints_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
from letta.schemas.tool_rule import BaseToolRule
Expand Down Expand Up @@ -373,7 +373,7 @@ def assert_sanity_checks(response: LettaResponse):
assert len(response.messages) > 0, response


def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
# Find first instance of send_message
target_message = None
for message in messages:
Expand Down Expand Up @@ -406,7 +406,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")


def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None:
def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name: str) -> None:
for message in messages:
if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name:
# Found it, do nothing
Expand Down
Loading
Loading