Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Dec 18, 2024
2 parents 4fcd939 + c1f7b32 commit f2caeb4
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 138 deletions.
77 changes: 32 additions & 45 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,6 @@ def __init__(

self.user = user

# link tools
self.link_tools(agent_state.tools)

# initialize a tool rules solver
if agent_state.tool_rules:
# if there are tool rules, print out a warning
Expand Down Expand Up @@ -425,11 +422,21 @@ def update_memory_if_change(self, new_memory: Memory) -> bool:
return True
return False

def execute_tool_and_persist_state(self, function_name, function_to_call, function_args):
def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool):
"""
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
"""
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
env = {}
env.update(globals())
exec(target_letta_tool.source_code, env)
callable_func = env[target_letta_tool.json_schema["name"]]
spec = inspect.getfullargspec(callable_func).annotations
for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

# TODO: add agent manager here
orig_memory_str = self.agent_state.memory.compile()

Expand All @@ -442,11 +449,11 @@ def execute_tool_and_persist_state(self, function_name, function_to_call, functi
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
# base tools are allowed to access the `Agent` object and run on the database
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
function_response = callable_func(**function_args)
else:
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run(
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(
agent_state=self.agent_state.__deepcopy__()
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
Expand All @@ -471,27 +478,6 @@ def messages(self) -> List[dict]:
def messages(self, value):
raise Exception("Modifying message list directly not allowed")

def link_tools(self, tools: List[Tool]):
"""Bind a tool object (schema + python function) to the agent object"""

# Store the functions schemas (this is passed as an argument to ChatCompletion)
self.functions = []
self.functions_python = {}
env = {}
env.update(globals())
for tool in tools:
try:
# WARNING: name may not be consistent?
# if tool.module: # execute the whole module
# exec(tool.module, env)
# else:
exec(tool.source_code, env)
self.functions_python[tool.json_schema["name"]] = env[tool.json_schema["name"]]
self.functions.append(tool.json_schema)
except Exception:
warnings.warn(f"WARNING: tool {tool.name} failed to link")
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python

def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
"""Load a list of messages from recall storage"""

Expand Down Expand Up @@ -600,8 +586,12 @@ def _get_ai_reply(
"""Get response from LLM API with robust retry mechanism."""

allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

allowed_functions = (
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
agent_state_tool_jsons
if not allowed_tool_names
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
)

# For the first message, force the initial tool if one is specified
Expand All @@ -621,7 +611,7 @@ def _get_ai_reply(
messages=message_sequence,
user_id=self.agent_state.created_by_id,
functions=allowed_functions,
functions_python=self.functions_python,
# functions_python=self.functions_python, do we need this?
function_call=function_call,
first_message=first_message,
force_tool_call=force_tool_call,
Expand Down Expand Up @@ -730,10 +720,13 @@ def _handle_ai_response(
function_name = function_call.name
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")

# Failure case 1: function name is wrong
try:
function_to_call = self.functions_python[function_name]
except KeyError:
# Failure case 1: function name is wrong (not in agent_state.tools)
target_letta_tool = None
for t in self.agent_state.tools:
if t.name == function_name:
target_letta_tool = t

if not target_letta_tool:
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
Expand Down Expand Up @@ -801,14 +794,8 @@ def _handle_ai_response(
# this is because the function/tool role message is only created once the function/tool has executed/returned
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
try:
spec = inspect.getfullargspec(function_to_call).annotations

for name, arg in function_args.items():
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

# handle tool execution (sandbox) and state updates
function_response = self.execute_tool_and_persist_state(function_name, function_to_call, function_args)
function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)

# handle trunction
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
Expand All @@ -820,8 +807,7 @@ def _handle_ai_response(
truncate = True

# get the function response limit
tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0]
return_char_limit = tool_obj.return_char_limit
return_char_limit = target_letta_tool.return_char_limit
function_response_string = validate_function_response(
function_response, return_char_limit=return_char_limit, truncate=truncate
)
Expand Down Expand Up @@ -1565,9 +1551,10 @@ def get_context_window(self) -> ContextWindowOverview:
num_tokens_external_memory_summary = count_tokens(external_memory_summary)

# tokens taken up by function definitions
if self.functions:
available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in self.functions]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=self.functions, model=self.model)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
if agent_state_tool_jsons:
available_functions_definitions = [ChatCompletionRequestTool(type="function", function=f) for f in agent_state_tool_jsons]
num_tokens_available_functions_definitions = num_tokens_from_functions(functions=agent_state_tool_jsons, model=self.model)
else:
available_functions_definitions = []
num_tokens_available_functions_definitions = 0
Expand Down
2 changes: 1 addition & 1 deletion letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def run_tool_from_source(
tool_source_type=request.source_type,
tool_args=request.args,
tool_name=request.name,
user_id=actor.id,
actor=actor,
)
except LettaToolCreateError as e:
# HTTP 400 == Bad Request
Expand Down
66 changes: 31 additions & 35 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,39 +825,35 @@ def create_agent(
in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor)
return in_memory_agent_state

# TODO: This is not good!
# TODO: Ideally, this should ALL be handled by the ORM
# TODO: The main blocker here IS the _message updates
# def update_agent(
# self,
# agent_id: str,
# request: UpdateAgent,
# actor: User,
# ) -> AgentState:
# """Update the agents core memory block, return the new state"""
# # Update agent state in the db first
# agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)

# # Get the agent object (loaded in memory)
# letta_agent = self.load_agent(agent_id=agent_id, actor=actor)

# # TODO: Everything below needs to get removed, no updating anything in memory
# # update the system prompt
# if request.system:
# letta_agent.update_system_prompt(request.system)

# # update in-context messages
# if request.message_ids:
# # This means the user is trying to change what messages are in the message buffer
# # Internally this requires (1) pulling from recall,
# # then (2) setting the attributes ._messages and .state.message_ids
# letta_agent.set_message_buffer(message_ids=request.message_ids)

# # tools
# if request.tool_ids:
# letta_agent.link_tools(letta_agent.agent_state.tools)

# letta_agent.update_state()
# TODO: This is not good!
# TODO: Ideally, this should ALL be handled by the ORM
# TODO: The main blocker here IS the _message updates
# def update_agent(
# self,
# agent_id: str,
# request: UpdateAgent,
# actor: User,
# ) -> AgentState:
# """Update the agents core memory block, return the new state"""
# # Update agent state in the db first
# agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)

# # Get the agent object (loaded in memory)
# letta_agent = self.load_agent(agent_id=agent_id, actor=actor)

# # TODO: Everything below needs to get removed, no updating anything in memory
# # update the system prompt
# if request.system:
# letta_agent.update_system_prompt(request.system)

# # update in-context messages
# if request.message_ids:
# # This means the user is trying to change what messages are in the message buffer
# # Internally this requires (1) pulling from recall,
# # then (2) setting the attributes ._messages and .state.message_ids
# letta_agent.set_message_buffer(message_ids=request.message_ids)

letta_agent.update_state()

# return agent_state

Expand Down Expand Up @@ -1327,7 +1323,7 @@ def get_agent_context_window(

def run_tool_from_source(
self,
user_id: str,
actor: User,
tool_args: str,
tool_source: str,
tool_source_type: Optional[str] = None,
Expand Down Expand Up @@ -1355,7 +1351,7 @@ def run_tool_from_source(

# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, actor, tool_object=tool).run(agent_state=agent_state)
return FunctionReturn(
id="null",
function_call_id="null",
Expand Down
27 changes: 11 additions & 16 deletions letta/services/tool_execution_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
from letta.utils import get_friendly_error_msg

Expand All @@ -38,14 +38,10 @@ class ToolExecutionSandbox:
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"

def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
def __init__(self, tool_name: str, args: dict, user: User, force_recreate=False, tool_object: Optional[Tool] = None):
self.tool_name = tool_name
self.args = args

# Get the user
# This user corresponds to the agent_state's user_id field
# agent_state is the state of the agent that invoked this run
self.user = UserManager().get_user_by_id(user_id=user_id)
self.user = user

# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
Expand Down Expand Up @@ -184,7 +180,9 @@ def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, s
except subprocess.CalledProcessError as e:
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e),
function_name=self.tool_name,
exception_name=type(e).__name__,
exception_message=str(e),
)
return SandboxRunResult(
func_return=func_return,
Expand All @@ -202,9 +200,7 @@ def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, s
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
raise e

def run_local_dir_sandbox_runpy(
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str
) -> SandboxRunResult:
def run_local_dir_sandbox_runpy(self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
status = "success"
agent_state, stderr = None, None

Expand All @@ -225,9 +221,7 @@ def run_local_dir_sandbox_runpy(
func_return, agent_state = self.parse_best_effort(func_result)

except Exception as e:
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e)
)
func_return = get_friendly_error_msg(function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e))
traceback.print_exc(file=sys.stderr)
status = "error"

Expand All @@ -248,7 +242,7 @@ def run_local_dir_sandbox_runpy(

def parse_out_function_results_markers(self, text: str):
if self.LOCAL_SANDBOX_RESULT_START_MARKER not in text:
return '', text
return "", text
marker_len = len(self.LOCAL_SANDBOX_RESULT_START_MARKER)
start_index = text.index(self.LOCAL_SANDBOX_RESULT_START_MARKER) + marker_len
end_index = text.index(self.LOCAL_SANDBOX_RESULT_END_MARKER)
Expand Down Expand Up @@ -293,6 +287,7 @@ def run_e2b_sandbox(self, agent_state: AgentState) -> SandboxRunResult:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)

if execution.results:
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
elif execution.error:
Expand All @@ -303,7 +298,7 @@ def run_e2b_sandbox(self, agent_state: AgentState) -> SandboxRunResult:
execution.logs.stderr.append(execution.error.traceback)
else:
raise ValueError(f"Tool {self.tool_name} returned execution with None")

return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,
Expand Down
3 changes: 1 addition & 2 deletions tests/helpers/endpoints_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=agent._messages,
functions=agent.functions,
functions_python=agent.functions_python,
functions=[t.json_schema for t in agent.agent_state.tools],
)

# Basic check
Expand Down
Loading

0 comments on commit f2caeb4

Please sign in to comment.