Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ConditionalToolRules #2279

Merged
7 changes: 6 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ def __init__(

self.first_message_verify_mono = first_message_verify_mono

# State needed for conditional tool chaining
self.last_function_response = None

# Controls if the convo memory pressure warning is triggered
# When an alert is sent in the message queue, set this to True (to avoid repeat alerts)
# When the summarizer is run, set this back to False (to reset)
Expand Down Expand Up @@ -586,7 +589,7 @@ def _get_ai_reply(
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

allowed_functions = (
Expand Down Expand Up @@ -826,6 +829,7 @@ def _handle_ai_response(
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
self.last_function_response = function_response
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
Expand Down Expand Up @@ -861,6 +865,7 @@ def _handle_ai_response(
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
self.last_function_response = function_response

else:
# Standard non-function reply
Expand Down
129 changes: 78 additions & 51 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, List, Optional, Set
import json
from typing import List, Optional, Union

from pydantic import BaseModel, Field

from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
Expand All @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
)
tool_rules: List[ChildToolRule] = Field(
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
terminal_tool_rules: List[TerminalToolRule] = Field(
Expand All @@ -35,21 +37,25 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
# Separate the provided tool rules into init, standard, and terminal categories
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)

# Validate the tool rules to ensure they form a DAG
if not self.validate_tool_rules():
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")

def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
self.last_tool_name = tool_name

def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
"""Get a list of tool names allowed based on the last tool called."""
if self.last_tool_name is None:
# Use initial tool rules if no tool has been called yet
Expand All @@ -58,18 +64,20 @@ def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
# Find a matching ToolRule for the last tool used
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)

# Return children which must exist on ToolRule
if current_rule:
return current_rule.children

# Default to empty if no rule matches
message = "User provided tool rules and execution state resolved to no more possible tool calls."
if error_on_empty:
raise RuntimeError(message)
else:
# warnings.warn(message)
if current_rule is None:
if error_on_empty:
raise ValueError(f"No tool rule found for {self.last_tool_name}")
return []

# If the current rule is a conditional tool rule, use the LLM response to
# determine which child tool to use
if isinstance(current_rule, ConditionalToolRule):
if not last_function_response:
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
return [self.evaluate_conditional_tool(current_rule, last_function_response)]

return current_rule.children if current_rule.children else []

def is_terminal_tool(self, tool_name: str) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
Expand All @@ -78,38 +86,57 @@ def has_children_tools(self, tool_name):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.tool_rules)

def validate_tool_rules(self) -> bool:
"""
Validate that the tool rules define a directed acyclic graph (DAG).
Returns True if valid (no cycles), otherwise False.
"""
# Build adjacency list for the tool graph
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

# Track visited nodes
visited: Set[str] = set()
path_stack: Set[str] = set()

# Define DFS helper function
def dfs(tool_name: str) -> bool:
if tool_name in path_stack:
return False # Cycle detected
if tool_name in visited:
return True # Already validated

# Mark the node as visited in the current path
path_stack.add(tool_name)
for child in adjacency_list.get(tool_name, []):
if not dfs(child):
return False # Cycle detected in DFS
path_stack.remove(tool_name) # Remove from current path
visited.add(tool_name)
return True

# Run DFS from each tool in `tool_rules`
for rule in self.tool_rules:
if rule.tool_name not in visited:
if not dfs(rule.tool_name):
return False # Cycle found, invalid tool rules

return True # No cycles, valid DAG
def validate_conditional_tool(self, rule: ConditionalToolRule):
'''
Validate a conditional tool rule

Args:
rule (ConditionalToolRule): The conditional tool rule to validate

Raises:
ToolRuleValidationError: If the rule is invalid
'''
if rule.children is None or len(rule.children) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
if len(rule.children) != len(rule.child_output_mapping):
raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.")
if set(rule.children) != set(rule.child_output_mapping.values()):
raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.")
return True

def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
Parse function response to determine which child tool to use based on the mapping

Args:
tool (ConditionalToolRule): The conditional tool rule
last_function_response (str): The function response in JSON format

Returns:
str: The name of the child tool to use next
'''
json_response = json.loads(last_function_response)
function_output = json_response["message"]

# Try to match the function output with a mapping key
for key in tool.child_output_mapping:

# Convert function output to match key type for comparison
if isinstance(key, bool):
typed_output = function_output.lower() == "true"
elif isinstance(key, int):
try:
typed_output = int(function_output)
except (ValueError, TypeError):
continue
else: # string
if function_output == "True" or function_output == "False":
typed_output = function_output.lower()
else:
typed_output = function_output

if typed_output == key:
return tool.child_output_mapping[key]

# If no match found, use default
return tool.default_child
7 changes: 5 additions & 2 deletions letta/orm/custom_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
Expand Down Expand Up @@ -80,7 +80,7 @@ def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, Init
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
Expand All @@ -90,6 +90,9 @@ def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, Term
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
elif rule_type == ToolRuleType.conditional:
rule = ConditionalToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")

Expand Down
1 change: 1 addition & 0 deletions letta/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
run_first = "InitToolRule"
exit_loop = "TerminalToolRule" # reasoning loop should exit
continue_loop = "continue_loop" # reasoning loop should continue
conditional = "conditional"
constrain_child_tools = "ToolRule"
require_parent_tools = "require_parent_tools"
15 changes: 13 additions & 2 deletions letta/schemas/tool_rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Dict, List, Union

from pydantic import Field

Expand All @@ -21,6 +21,17 @@ class ChildToolRule(BaseToolRule):
children: List[str] = Field(..., description="The children tools that can be invoked.")


class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
"""
type: ToolRuleType = ToolRuleType.conditional
default_child: str = Field(..., description="The default child tool to be called")
child_output_mapping: Dict[Union[bool, str, int], str] = Field(..., description="The output case to check for mapping")
mlong93 marked this conversation as resolved.
Show resolved Hide resolved
children: List[str] = Field(..., description="The child tool to call when output matches the case")
mlong93 marked this conversation as resolved.
Show resolved Hide resolved
throw_error: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
mlong93 marked this conversation as resolved.
Show resolved Hide resolved

mlong93 marked this conversation as resolved.
Show resolved Hide resolved

class InitToolRule(BaseToolRule):
"""
Represents the initial tool rule configuration.
Expand All @@ -37,4 +48,4 @@ class TerminalToolRule(BaseToolRule):
type: ToolRuleType = ToolRuleType.exit_loop


ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule]
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]
Loading
Loading