From e6ffbc26325c687e01f59a593048b93833793761 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 5 Mar 2024 22:11:24 -0800 Subject: [PATCH] feat: add `GET` REST API route for listing tools (#1100) --- memgpt/metadata.py | 23 +++++++++++++++- memgpt/models/pydantic_models.py | 10 ++++++- memgpt/server/rest_api/server.py | 2 ++ memgpt/server/rest_api/tools/__init__.py | 0 memgpt/server/rest_api/tools/index.py | 35 ++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 memgpt/server/rest_api/tools/__init__.py create mode 100644 memgpt/server/rest_api/tools/index.py diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 36946be54f..0bc1cf615d 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -1,6 +1,7 @@ """ Metadata store for user/agent/data_source information""" import os +import inspect as python_inspect import uuid import secrets from typing import Optional, List @@ -9,8 +10,9 @@ from memgpt.utils import get_local_time, enforce_types from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset from memgpt.config import MemGPTConfig +from memgpt.functions.functions import load_all_function_sets -from memgpt.models.pydantic_models import PersonaModel, HumanModel +from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean from sqlalchemy import func @@ -517,6 +519,25 @@ def list_presets(self, user_id: uuid.UUID) -> List[Preset]: results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() return [r.to_record() for r in results] + @enforce_types + def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: + with self.session_maker() as session: + available_functions = load_all_function_sets() + print(available_functions) + results = [ + ToolModel( + name=k, + json_schema=v["json_schema"], + source_type="python", + source_code=python_inspect.getsource(v["python_function"]), + ) + for k, v in available_functions.items() + ] + print(results) + return results + # results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() + # return [r.to_record() for r in results] + @enforce_types def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: with self.session_maker() as session: diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index eb8fac97af..32f9317d89 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Literal from pydantic import BaseModel, Field, Json import uuid from datetime import datetime @@ -36,6 +36,14 @@ class PresetModel(BaseModel): functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") +class ToolModel(BaseModel): + # TODO move into database + name: str = Field(..., description="The name of the function.") + json_schema: dict = Field(..., description="The JSON schema of the function.") + source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.") + source_code: Optional[str] = Field(..., description="The source code of the function.") + + class AgentStateModel(BaseModel): id: uuid.UUID = Field(..., description="The unique identifier of the agent.") name: str = Field(..., description="The name of the agent.") diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py index 4fdfab8c1a..6afb738f22 100644 --- a/memgpt/server/rest_api/server.py +++ b/memgpt/server/rest_api/server.py @@ -20,6 +20,7 @@ from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router from memgpt.server.rest_api.personas.index import setup_personas_index_router from memgpt.server.rest_api.static_files import mount_static_files +from memgpt.server.rest_api.tools.index import setup_tools_index_router from memgpt.server.server import SyncServer """ @@ -92,6 +93,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) +app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX) # /api/config endpoints app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX) diff --git a/memgpt/server/rest_api/tools/__init__.py b/memgpt/server/rest_api/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/memgpt/server/rest_api/tools/index.py b/memgpt/server/rest_api/tools/index.py new file mode 100644 index 0000000000..8137ae61ae --- /dev/null +++ b/memgpt/server/rest_api/tools/index.py @@ -0,0 +1,35 @@ +import uuid +from functools import partial +from typing import List + +from fastapi import APIRouter, Depends, Body +from pydantic import BaseModel, Field + +from memgpt.models.pydantic_models import ToolModel +from memgpt.server.rest_api.auth_token import get_current_user +from memgpt.server.rest_api.interface import QueuingInterface +from memgpt.server.server import SyncServer + +router = APIRouter() + + +class ListToolsResponse(BaseModel): + tools: List[ToolModel] = Field(..., description="List of tools (functions).") + + +def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str): + get_current_user_with_server = partial(partial(get_current_user, server), password) + + @router.get("/tools", tags=["tools"], response_model=ListToolsResponse) + async def list_tools( + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """ + Get a list of all tools available to agents created by a user + """ + # Clear the interface + interface.clear() + tools = server.ms.list_tools(user_id=user_id) + return ListToolsResponse(tools=tools) + + return router