Skip to content

Commit

Permalink
feat: add GET REST API route for listing tools (#1100)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Mar 6, 2024
1 parent 9280568 commit e6ffbc2
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 2 deletions.
23 changes: 22 additions & 1 deletion memgpt/metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
""" Metadata store for user/agent/data_source information"""

import os
import inspect as python_inspect
import uuid
import secrets
from typing import Optional, List
Expand All @@ -9,8 +10,9 @@
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets

from memgpt.models.pydantic_models import PersonaModel, HumanModel
from memgpt.models.pydantic_models import PersonaModel, HumanModel, ToolModel

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
Expand Down Expand Up @@ -517,6 +519,25 @@ def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:
with self.session_maker() as session:
available_functions = load_all_function_sets()
print(available_functions)
results = [
ToolModel(
name=k,
json_schema=v["json_schema"],
source_type="python",
source_code=python_inspect.getsource(v["python_function"]),
)
for k, v in available_functions.items()
]
print(results)
return results
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
# return [r.to_record() for r in results]

@enforce_types
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session:
Expand Down
10 changes: 9 additions & 1 deletion memgpt/models/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Literal
from pydantic import BaseModel, Field, Json
import uuid
from datetime import datetime
Expand Down Expand Up @@ -36,6 +36,14 @@ class PresetModel(BaseModel):
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")


class ToolModel(BaseModel):
# TODO move into database
name: str = Field(..., description="The name of the function.")
json_schema: dict = Field(..., description="The JSON schema of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")


class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
Expand Down
2 changes: 2 additions & 0 deletions memgpt/server/rest_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from memgpt.server.rest_api.openai_assistants.assistants import setup_openai_assistant_router
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.server import SyncServer

"""
Expand Down Expand Up @@ -92,6 +93,7 @@ def verify_password(credentials: HTTPAuthorizationCredentials = Depends(security
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)

# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions memgpt/server/rest_api/tools/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import uuid
from functools import partial
from typing import List

from fastapi import APIRouter, Depends, Body
from pydantic import BaseModel, Field

from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")


def setup_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)

@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_tools(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)

return router

0 comments on commit e6ffbc2

Please sign in to comment.