Skip to content

Commit

Permalink
tool-call: make agent async
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Sep 28, 2024
1 parent 05bbba9 commit ef2a020
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 87 deletions.
178 changes: 93 additions & 85 deletions examples/agent/run.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "aiohttp",
# "fastapi",
# "openai",
# "pydantic",
# "requests",
# "uvicorn",
# "typer",
# "uvicorn",
# ]
# ///
import json
import openai
import asyncio
import aiohttp
from functools import wraps
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolMessageParam, ChatCompletionUserMessageParam
from pydantic import BaseModel
import requests
import sys
import typer

Check failure on line 20 in examples/agent/run.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Import "typer" could not be resolved (reportMissingImports)
from typing import Annotated, Optional
import urllib.parse


class OpenAPIMethod:
def __init__(self, url, name, descriptor, catalog):
'''
Wraps a remote OpenAPI method as a Python function.
Wraps a remote OpenAPI method as an async Python function.
'''
self.url = url
self.__name__ = name
Expand Down Expand Up @@ -69,7 +70,7 @@ def __init__(self, url, name, descriptor, catalog):
required=[name for name, param in self.parameters.items() if param['required']] + ([self.body['name']] if self.body and self.body['required'] else [])
)

def __call__(self, **kwargs):
async def __call__(self, session: aiohttp.ClientSession, **kwargs):
if self.body:
body = kwargs.pop(self.body['name'], None)
if self.body['required']:
Expand All @@ -86,53 +87,65 @@ def __call__(self, **kwargs):
assert param['in'] == 'query', 'Only query parameters are supported'
query_params[name] = value

params = "&".join(f"{name}={urllib.parse.quote(value)}" for name, value in query_params.items())
params = "&".join(f"{name}={urllib.parse.quote(str(value))}" for name, value in query_params.items() if value is not None)
url = f'{self.url}?{params}'
response = requests.post(url, json=body)
response.raise_for_status()
response_json = response.json()
async with session.post(url, json=body) as response:
response.raise_for_status()
response_json = await response.json()

return response_json

async def discover_tools(tool_endpoints: list[str], verbose: bool = False) -> tuple[dict, list]:
tool_map = {}
tools = []

async with aiohttp.ClientSession() as session:
for url in tool_endpoints:
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'

catalog_url = f'{url}/openapi.json'
async with session.get(catalog_url) as response:
response.raise_for_status()
catalog = await response.json()

for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
tool_map[fn.__name__] = fn
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n')
tools.append(dict(
type="function",
function=dict(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=fn.parameters_schema,
)
)
)

return tool_map, tools

def main(
def typer_async_workaround():
'Adapted from https://github.com/fastapi/typer/issues/950#issuecomment-2351076467'
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return asyncio.run(f(*args, **kwargs))
return wrapper
return decorator

@typer_async_workaround()
async def main(
goal: Annotated[str, typer.Option()],
api_key: str = '<unset>',
tool_endpoint: Optional[list[str]] = None,
max_iterations: Optional[int] = 10,
verbose: bool = False,
endpoint: str = "http://localhost:8080/v1/",
):
client = AsyncOpenAI(api_key=api_key, base_url=endpoint)

openai.api_key = api_key
openai.base_url = endpoint

tool_map = {}
tools = []

# Discover tools using OpenAPI catalogs at the provided endpoints.
for url in (tool_endpoint or []):
assert url.startswith('http://') or url.startswith('https://'), f'Tools must be URLs, not local files: {url}'

catalog_url = f'{url}/openapi.json'
catalog_response = requests.get(catalog_url)
catalog_response.raise_for_status()
catalog = catalog_response.json()

for path, descriptor in catalog['paths'].items():
fn = OpenAPIMethod(url=f'{url}{path}', name=path.replace('/', ' ').strip().replace(' ', '_'), descriptor=descriptor, catalog=catalog)
tool_map[fn.__name__] = fn
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(fn.parameters_schema, indent=2)}\n')
tools.append(dict(
type="function",
function=dict(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=fn.parameters_schema,
)
)
)
tool_map, tools = await discover_tools(tool_endpoint or [], verbose)

sys.stdout.write(f'🛠️ {", ".join(tool_map.keys())}\n')

Expand All @@ -143,51 +156,46 @@ def main(
)
]

i = 0
while (max_iterations is None or i < max_iterations):

response = openai.chat.completions.create(
model="gpt-4o",
messages=messages,
tools=tools,
)

if verbose:
sys.stderr.write(f'# RESPONSE: {response}\n')

assert len(response.choices) == 1
choice = response.choices[0]

content = choice.message.content
if choice.finish_reason == "tool_calls":
messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls:
if content:
print(f'💭 {content}')

args = json.loads(tool_call.function.arguments)
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**args)
sys.stdout.write(f" → {tool_result}\n")
messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
# name=tool_call.function.name,
content=json.dumps(tool_result),
# content=f'{pretty_call} = {tool_result}',
))
else:
assert content
print(content)
return

i += 1
async with aiohttp.ClientSession() as session:
for i in range(max_iterations or sys.maxsize):
response = await client.chat.completions.create(
model="gpt-4o",
messages=messages,
tools=tools,
)

if max_iterations is not None:
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")
if verbose:
sys.stderr.write(f'# RESPONSE: {response}\n')

assert len(response.choices) == 1
choice = response.choices[0]

content = choice.message.content
if choice.finish_reason == "tool_calls":
messages.append(choice.message) # type: ignore
assert choice.message.tool_calls
for tool_call in choice.message.tool_calls:
if content:
print(f'💭 {content}')

args = json.loads(tool_call.function.arguments)
pretty_call = f'{tool_call.function.name}({", ".join(f"{k}={v.model_dump_json() if isinstance(v, BaseModel) else json.dumps(v)}" for k, v in args.items())})'
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = await tool_map[tool_call.function.name](session, **args)
sys.stdout.write(f" → {tool_result}\n")
messages.append(ChatCompletionToolMessageParam(
tool_call_id=tool_call.id,
role="tool",
content=json.dumps(tool_result),
))
else:
assert content
print(content)
return

if max_iterations is not None:
raise Exception(f"Failed to get a valid response after {max_iterations} tool calls")

if __name__ == '__main__':
typer.run(main)
2 changes: 1 addition & 1 deletion examples/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def python(code: str) -> str:
Returns:
str: The output of the executed code.
"""
from IPython import InteractiveShell
from IPython.core.interactiveshell import InteractiveShell

Check failure on line 92 in examples/agent/tools.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Import "IPython.core.interactiveshell" could not be resolved (reportMissingImports)
from io import StringIO
import sys

Expand Down
3 changes: 2 additions & 1 deletion requirements/requirements-agent.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
aiohttp
fastapi
ipython
openai
pydantic
requests
typer
uvicorn

0 comments on commit ef2a020

Please sign in to comment.