Skip to content

Commit

Permalink
feat: Add ConditionalToolRules (#2279)
Browse files Browse the repository at this point in the history
Co-authored-by: Mindy Long <[email protected]>
  • Loading branch information
mlong93 and Mindy Long authored Dec 19, 2024
1 parent a336092 commit 8cc6870
Show file tree
Hide file tree
Showing 10 changed files with 564 additions and 96 deletions.
21 changes: 20 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import inspect
import json
import time
import traceback
import warnings
Expand Down Expand Up @@ -371,6 +372,9 @@ def __init__(
self._append_to_messages(added_messages=init_messages_objs)
self._validate_message_buffer_is_utc()

# Load last function response from message history
self.last_function_response = self.load_last_function_response()

# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
self.messages_total_init = len(self._messages) - 1
Expand All @@ -389,6 +393,19 @@ def check_tool_rules(self):
else:
self.supports_structured_output = True

def load_last_function_response(self):
"""Load the last function response from message history"""
for i in range(len(self._messages) - 1, -1, -1):
msg = self._messages[i]
if msg.role == MessageRole.tool and msg.text:
try:
response_json = json.loads(msg.text)
if response_json.get("message"):
return response_json["message"]
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON format in message: {msg.text}")
return None

def update_memory_if_change(self, new_memory: Memory) -> bool:
"""
Update internal memory object and system prompt if there have been modifications.
Expand Down Expand Up @@ -586,7 +603,7 @@ def _get_ai_reply(
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

allowed_functions = (
Expand Down Expand Up @@ -826,6 +843,7 @@ def _handle_ai_response(
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
self.last_function_response = function_response
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
Expand Down Expand Up @@ -861,6 +879,7 @@ def _handle_ai_response(
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
self.last_function_response = function_response

else:
# Standard non-function reply
Expand Down
133 changes: 82 additions & 51 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, List, Optional, Set
import json
from typing import List, Optional, Union

from pydantic import BaseModel, Field

from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
Expand All @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
)
tool_rules: List[ChildToolRule] = Field(
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
terminal_tool_rules: List[TerminalToolRule] = Field(
Expand All @@ -35,21 +37,25 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
# Separate the provided tool rules into init, standard, and terminal categories
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)

# Validate the tool rules to ensure they form a DAG
if not self.validate_tool_rules():
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")

def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
self.last_tool_name = tool_name

def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
"""Get a list of tool names allowed based on the last tool called."""
if self.last_tool_name is None:
# Use initial tool rules if no tool has been called yet
Expand All @@ -58,18 +64,21 @@ def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
# Find a matching ToolRule for the last tool used
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)

# Return children which must exist on ToolRule
if current_rule:
return current_rule.children

# Default to empty if no rule matches
message = "User provided tool rules and execution state resolved to no more possible tool calls."
if error_on_empty:
raise RuntimeError(message)
else:
# warnings.warn(message)
if current_rule is None:
if error_on_empty:
raise ValueError(f"No tool rule found for {self.last_tool_name}")
return []

# If the current rule is a conditional tool rule, use the LLM response to
# determine which child tool to use
if isinstance(current_rule, ConditionalToolRule):
if not last_function_response:
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
next_tool = self.evaluate_conditional_tool(current_rule, last_function_response)
return [next_tool] if next_tool else []

return current_rule.children if current_rule.children else []

def is_terminal_tool(self, tool_name: str) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
Expand All @@ -78,38 +87,60 @@ def has_children_tools(self, tool_name):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.tool_rules)

def validate_tool_rules(self) -> bool:
"""
Validate that the tool rules define a directed acyclic graph (DAG).
Returns True if valid (no cycles), otherwise False.
"""
# Build adjacency list for the tool graph
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

# Track visited nodes
visited: Set[str] = set()
path_stack: Set[str] = set()

# Define DFS helper function
def dfs(tool_name: str) -> bool:
if tool_name in path_stack:
return False # Cycle detected
if tool_name in visited:
return True # Already validated

# Mark the node as visited in the current path
path_stack.add(tool_name)
for child in adjacency_list.get(tool_name, []):
if not dfs(child):
return False # Cycle detected in DFS
path_stack.remove(tool_name) # Remove from current path
visited.add(tool_name)
return True

# Run DFS from each tool in `tool_rules`
for rule in self.tool_rules:
if rule.tool_name not in visited:
if not dfs(rule.tool_name):
return False # Cycle found, invalid tool rules

return True # No cycles, valid DAG
def validate_conditional_tool(self, rule: ConditionalToolRule):
'''
Validate a conditional tool rule
Args:
rule (ConditionalToolRule): The conditional tool rule to validate
Raises:
ToolRuleValidationError: If the rule is invalid
'''
if len(rule.child_output_mapping) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
return True

def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
Parse function response to determine which child tool to use based on the mapping
Args:
tool (ConditionalToolRule): The conditional tool rule
last_function_response (str): The function response in JSON format
Returns:
str: The name of the child tool to use next
'''
json_response = json.loads(last_function_response)
function_output = json_response["message"]

# Try to match the function output with a mapping key
for key in tool.child_output_mapping:

# Convert function output to match key type for comparison
if isinstance(key, bool):
typed_output = function_output.lower() == "true"
elif isinstance(key, int):
try:
typed_output = int(function_output)
except (ValueError, TypeError):
continue
elif isinstance(key, float):
try:
typed_output = float(function_output)
except (ValueError, TypeError):
continue
else: # string
if function_output == "True" or function_output == "False":
typed_output = function_output.lower()
elif function_output == "None":
typed_output = None
else:
typed_output = function_output

if typed_output == key:
return tool.child_output_mapping[key]

# If no match found, use default
return tool.default_child
7 changes: 5 additions & 2 deletions letta/orm/custom_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
Expand Down Expand Up @@ -80,7 +80,7 @@ def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, Init
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
Expand All @@ -90,6 +90,9 @@ def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, Term
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
elif rule_type == ToolRuleType.conditional:
rule = ConditionalToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")

Expand Down
1 change: 1 addition & 0 deletions letta/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
run_first = "InitToolRule"
exit_loop = "TerminalToolRule" # reasoning loop should exit
continue_loop = "continue_loop" # reasoning loop should continue
conditional = "conditional"
constrain_child_tools = "ToolRule"
require_parent_tools = "require_parent_tools"
14 changes: 12 additions & 2 deletions letta/schemas/tool_rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Any, Dict, List, Optional, Union

from pydantic import Field

Expand All @@ -21,6 +21,16 @@ class ChildToolRule(BaseToolRule):
children: List[str] = Field(..., description="The children tools that can be invoked.")


class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
"""
type: ToolRuleType = ToolRuleType.conditional
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")


class InitToolRule(BaseToolRule):
"""
Represents the initial tool rule configuration.
Expand All @@ -37,4 +47,4 @@ class TerminalToolRule(BaseToolRule):
type: ToolRuleType = ToolRuleType.exit_loop


ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule]
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]
6 changes: 3 additions & 3 deletions tests/helpers/endpoints_helper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
from letta.schemas.tool_rule import BaseToolRule
Expand Down Expand Up @@ -373,7 +373,7 @@ def assert_sanity_checks(response: LettaResponse):
assert len(response.messages) > 0, response


def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
# Find first instance of send_message
target_message = None
for message in messages:
Expand Down Expand Up @@ -406,7 +406,7 @@ def assert_invoked_send_message_with_keyword(messages: List[LettaMessage], keywo
raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")


def assert_invoked_function_call(messages: List[LettaMessage], function_name: str) -> None:
def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name: str) -> None:
for message in messages:
if isinstance(message, ToolCallMessage) and message.tool_call.name == function_name:
# Found it, do nothing
Expand Down
Loading

0 comments on commit 8cc6870

Please sign in to comment.