-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Enable basic sandboxed tool run functionality #1938
Changes from 2 commits
b41a659
2233cce
ea8f140
1936a53
a57d2d5
5a2fb5c
8b66e50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
from abc import ABC, abstractmethod | ||
from typing import List, Literal, Optional, Tuple, Union | ||
|
||
from e2b_code_interpreter import Sandbox | ||
from tqdm import tqdm | ||
|
||
from letta.agent_store.storage import StorageConnector | ||
|
@@ -56,8 +57,10 @@ | |
from letta.utils import ( | ||
count_tokens, | ||
get_local_time, | ||
get_source_code_for_execution, | ||
get_tool_call_id, | ||
get_utc_time, | ||
is_foreign_tool, | ||
is_utc_datetime, | ||
json_dumps, | ||
json_loads, | ||
|
@@ -651,6 +654,29 @@ def _handle_ai_response( | |
|
||
function_args["self"] = self # need to attach self to arg since it's dynamically linked | ||
|
||
matching_tools = [tool for tool in self.tools if tool.name == function_name] | ||
tool = matching_tools[0] if matching_tools else None | ||
# Execute tool in sandbox | ||
if tool and is_foreign_tool(tool): | ||
sbx = Sandbox() | ||
|
||
# TODO: install dependencies | ||
# sbx.commands.run(f"pip3 install {package}") | ||
|
||
code = get_source_code_for_execution(function_name, function_args, tool) | ||
|
||
execution = sbx.run_code(code) | ||
if execution.error is not None: | ||
raise execution.error | ||
elif len(execution.results) == 0: | ||
function_response = "" | ||
else: | ||
function_response = execution.results[0].text | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there any issue with typing here - e.g. if the function returns a list? does this assume the function is always returning a string? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. e2b will return lists as a string and store it in the text field! i.e.
It seems like the primitive types will show up here, but there are more complex types that have their own field. Should be straightforward to extend as needed in the future. |
||
|
||
sbx.kill() | ||
else: | ||
function_response = function_to_call(**function_args) | ||
|
||
function_response = function_to_call(**function_args) | ||
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: | ||
# with certain functions we rely on the paging mechanism to handle overflow | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
import demjson3 as demjson | ||
import pytz | ||
from letta.schemas.tool import Tool | ||
import tiktoken | ||
|
||
import letta | ||
|
@@ -1071,3 +1072,30 @@ def safe_serializer(obj): | |
|
||
def json_loads(data): | ||
return json.loads(data, strict=False) | ||
|
||
|
||
def is_foreign_tool(tool: Tool): | ||
return "foreign" in tool.tags | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just adding this function as a placeholder for now - is there a good way to programmatically determine whether a tool is unvetted right now? In the long term, we proposed having a column in the tools table that is a signature created by letta server so that it can be generated for imported tools from trusted sources |
||
|
||
|
||
def get_source_code_for_execution(function_name: str, function_args: dict, tool: Tool) -> str: | ||
code = "" | ||
# 1. Set params | ||
for param in function_args: | ||
if param != "self": | ||
code += param + ' = "' + function_args[param] + '"\n' | ||
|
||
# 2. Add function source code | ||
code += tool.source_code + "\n" | ||
|
||
# 3. Add function call | ||
code += function_name + "(" | ||
|
||
# 4. Populate params for function call | ||
for param in function_args: | ||
if param != "self": | ||
code += param + "," | ||
code += ")" | ||
|
||
# 5. Admire result | ||
return code |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,3 +162,68 @@ def core_memory_clear(self: Agent): | |
|
||
def test_custom_import_tool(client): | ||
pass | ||
|
||
|
||
def test_run_basic_tool_in_sandbox(client: Union[LocalClient, RESTClient]): | ||
"""Test creation of a simple tool with no input params""" | ||
|
||
def print_hello_world(): | ||
""" | ||
Returns: | ||
str: A static string "Hello world". | ||
|
||
""" | ||
print("hello world") | ||
return "hello world" | ||
|
||
tools = client.list_tools() | ||
print(f"Original tools {[t.name for t in tools]}") | ||
|
||
tool = client.create_tool(print_hello_world, name="print_hello_world", tags=["extras", "foreign"]) | ||
|
||
tools = client.list_tools() | ||
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}" | ||
print(f"Updated tools {[t.name for t in tools]}") | ||
|
||
# check tool id | ||
tool = client.get_tool(tool.id) | ||
assert tool is not None, "Expected tool to be created" | ||
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}" | ||
|
||
# create agent with tool | ||
agent_state = client.create_agent(tools=[tool.name]) | ||
response = client.user_message(agent_id=agent_state.id, message="hi please use the tool called print_hello_world") | ||
|
||
|
||
def test_run_tool_with_str_params_in_sandbox(client: Union[LocalClient, RESTClient]): | ||
"""Test creation of a simple tool that relies on a provided string param""" | ||
|
||
def print_message(message: str): | ||
""" | ||
Args: | ||
message (str): The message to print. | ||
|
||
Returns: | ||
str: A static string "Hello world". | ||
|
||
""" | ||
print(message) | ||
return message | ||
|
||
tools = client.list_tools() | ||
print(f"Original tools {[t.name for t in tools]}") | ||
|
||
tool = client.create_tool(print_message, name="print_message", tags=["extras", "foreign"]) | ||
|
||
tools = client.list_tools() | ||
assert tool in tools, f"Expected {tool.name} in {[t.name for t in tools]}" | ||
print(f"Updated tools {[t.name for t in tools]}") | ||
|
||
# check tool id | ||
tool = client.get_tool(tool.id) | ||
assert tool is not None, "Expected tool to be created" | ||
assert tool.id == tool.id, f"Expected {tool.id} to be {tool.id}" | ||
|
||
# create agent with tool | ||
agent_state = client.create_agent(tools=[tool.name]) | ||
response = client.user_message(agent_id=agent_state.id, message="hi please use the tool called print_message") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now, on an individual test run I can manually verify that the function is running in the sandboxed environment using the e2b debug logs:
Not ideal because if the function doesn't run on the sandbox for whatever reason the test suite will still consider this as passed. One option is leaving the sandbox running with a timeout so that I can still interact with it during the test after user_message returns, but I'd prefer if we consistently kill the server after execution to prevent future bugs so looking into other e2b suggested options! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you know what file I'd need to modify if I wanted to ensure
pip install e2b-code-interpreter
gets run during the poetry install step?