Add RouteLLM
collindutter committed Jul 19, 2024
1 parent e17d88c commit f741e6b
Showing 3 changed files with 682 additions and 76 deletions.
34 changes: 29 additions & 5 deletions griptapecli/core/skatepark.py
@@ -38,13 +38,21 @@
 )
 from .state import RunProcess, State
 
+from routellm.controller import Controller
+
 app = FastAPI()
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
 state = State()
 
+controller = Controller(
+    routers=["mf"],
+    strong_model="gpt-4o",
+    weak_model="llama3-8b-instruct",
+)
+
 DEFAULT_QUEUE_DELAY = "2"
 
 model_to_internal_model_map: dict[str, str] = {
@@ -66,7 +74,8 @@
     "command-text": "cohere.command-text-v14",
     "command-light-text": "cohere.command-light-text-v14",
     "llama3-8b-instruct": "meta.llama3-8b-instruct-v1:0",
-    "meta.llama2-13b-chat-v1": "meta.llama2-13b-chat-v1",
+    "llama3-70b-instruct": "meta.llama3-70b-instruct-v1:0",
+    "llama2-13b-chat": "meta.llama2-13b-chat-v1",
     "llama2-70b-chat": "meta.llama2-70b-chat-v1",
     "mistral-7b-instruct": "mistral.mistral-7b-instruct-v0:2",
     "mixtral-8x7b-instruct": "mistral.mixtral-8x7b-instruct-v0:1",
@@ -314,7 +323,7 @@ def list_run_logs(structure_run_id: str):
 
 @app.post("/api/drivers/prompt", response_model=None, status_code=status.HTTP_200_OK)
 def prompt_driver(value: PromptDriverRequestModel) -> dict:
-    driver = _get_prompt_driver_from_model(value.params)
+    driver = _get_prompt_driver_from_model(value.messages, value.params)
 
     prompt_stack = PromptStack(
         messages=[Message.from_dict(message) for message in value.messages]
@@ -332,7 +341,7 @@ def prompt_driver(value: PromptDriverRequestModel) -> dict:
     "/api/drivers/prompt-stream", response_model=None, status_code=status.HTTP_200_OK
 )
 async def prompt_driver_stream(value: PromptDriverRequestModel) -> StreamingResponse:
-    driver = _get_prompt_driver_from_model(value.params)
+    driver = _get_prompt_driver_from_model(value.messages, value.params)
 
     prompt_stack = PromptStack(
         messages=[Message.from_dict(message) for message in value.messages]
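
Both endpoints now forward value.messages into _get_prompt_driver_from_model, which lets a caller opt into routing by sending "auto" as the model name. A hypothetical request against the non-streaming endpoint (host, port, and the exact payload shape are assumptions inferred from the handler and the artifact-extraction code in the next hunk, not taken verbatim from the commit):

import requests

response = requests.post(
    "http://127.0.0.1:5000/api/drivers/prompt",  # hypothetical host/port
    json={
        "params": {"model": "auto"},  # "auto" triggers RouteLLM routing
        "messages": [
            {
                "role": "user",
                "content": [{"artifact": {"value": "What is 2 + 2?"}}],
            }
        ],
    },
)
print(response.json())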
@@ -410,9 +419,24 @@ def _set_structure_run_to_running(structure_run: StructureRun) -> StructureRun:
     return structure_run
 
 
-def _get_prompt_driver_from_model(params: dict) -> BasePromptDriver:
+def _get_prompt_driver_from_model(
+    messages: list[dict], params: dict
+) -> BasePromptDriver:
     model = params["model"]
     logger.info("Model %s", model)
 
+    if model == "auto":
+        if messages:
+            logger.info(messages[-1]["content"][0]["artifact"]["value"])
+            prompt = messages[-1]["content"][0]["artifact"]["value"]
+        else:
+            raise HTTPException(status_code=400, detail="No messages provided")
+
+        model = controller.route(
+            prompt=prompt,
+            router="mf",
+            threshold=0.11593,
+        )
+        logger.info("Routed model %s", model)
 
     internal_model = model_to_internal_model_map.get(model)
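
Note the guard in the "auto" branch: with no messages there is nothing to route on, so the handler fails fast with a 400 rather than passing "auto" into the model map. A small sketch of exercising that guard directly (the test name and bare call are illustrative, not part of the commit):

from fastapi import HTTPException

from griptapecli.core.skatepark import _get_prompt_driver_from_model

def test_auto_model_requires_messages():
    try:
        # empty message list plus model "auto" should be rejected
        _get_prompt_driver_from_model([], {"model": "auto"})
        raise AssertionError("expected HTTPException")
    except HTTPException as exc:
        assert exc.status_code == 400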
