-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add a proxy for cohere-embed-english-v3 [DCH-126] (#5)
* feat: Add a proxy for cohere-embed-english-v3 * feat: Add a Docker release * feaet: Add CI testing * fix: Switch directory in testing * fix: Attempt to fix directory management * fix: Attempt to fix directory in CI * feat: Add mypy/ruff dependencies * fix: Formatting issues * fix formatting * fix: ruff and env variable fixes
- Loading branch information
1 parent
3679fae
commit db66052
Showing
20 changed files
with
1,410 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
name: Python Tests | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
|
||
jobs: | ||
unit: | ||
env: | ||
PROXY_KEY: fake | ||
AWS_REGION: fake | ||
AWS_ACCESS_KEY: fake | ||
AWS_SECRET_ACCESS_KEY: fake | ||
|
||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Install uv | ||
uses: astral-sh/setup-uv@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version-file: "./proxy/.python-version" | ||
- name: Install the project | ||
run: uv sync --all-extras --dev --directory proxy | ||
- name: Run tests | ||
id: run-tests | ||
run: > | ||
uv run --directory proxy pytest . \ | ||
--junitxml=pytest.xml \ | ||
--cov-report=term-missing:skip-covered \ | ||
--cov-report=xml:coverage.xml \ | ||
--cov=src tests \ | ||
--log-level=DEBUG \ | ||
--verbose | ||
- name: Upload coverage to Codecov | ||
uses: codecov/codecov-action@v4 | ||
with: | ||
token: ${{ secrets.CODECOV_TOKEN }} | ||
verbose: true | ||
|
||
ruff:
  runs-on: ubuntu-latest
  steps:
    - uses: actions/checkout@v4
    # Each step runs in its own shell, so a standalone `cd proxy` step does
    # not change the directory seen by the next step. Point the action at
    # the project through its `src` input instead.
    - uses: chartboost/ruff-action@v1
      with:
        src: "./proxy"
|
||
mypy: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Install uv | ||
uses: astral-sh/setup-uv@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version-file: "./proxy/.python-version" | ||
- name: Install the project | ||
run: uv sync --all-extras --dev --directory proxy | ||
- name: Run mypy | ||
run: uv run --directory proxy mypy . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,3 +45,40 @@ jobs: | |
tags: ${{ steps.meta.outputs.tags }} | ||
labels: ${{ steps.meta.outputs.labels }} | ||
|
||
push_docbot_proxy: | ||
name: Push Docker image to Docker Hub | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Check out the repo | ||
uses: actions/checkout@v4 | ||
with: | ||
lfs: true | ||
|
||
- name: Set up QEMU | ||
uses: docker/setup-qemu-action@v3 | ||
|
||
- name: Set up Docker Buildx | ||
uses: docker/setup-buildx-action@v3 | ||
|
||
- name: Log in to Docker Hub | ||
uses: docker/login-action@v3 | ||
with: | ||
username: ${{ secrets.DOCKER_USERNAME }} | ||
password: ${{ secrets.DOCKER_PASSWORD }} | ||
|
||
- name: Extract metadata (tags, labels) for Docker | ||
id: meta | ||
uses: docker/[email protected] | ||
with: | ||
images: cmidair/doc-bot-proxy | ||
|
||
- name: Build and push Docker image | ||
uses: docker/[email protected] | ||
with: | ||
platforms: linux/amd64,linux/arm64 | ||
context: ./proxy | ||
file: ./proxy/Dockerfile | ||
push: true | ||
tags: ${{ steps.meta.outputs.tags }} | ||
labels: ${{ steps.meta.outputs.labels }} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,6 @@ | |
.venv | ||
.pytest_cache | ||
.open-webui | ||
local | ||
local | ||
*.pyc | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.12 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim | ||
|
||
WORKDIR /app | ||
|
||
EXPOSE 8080 | ||
|
||
COPY pyproject.toml uv.lock .python-version ./ | ||
RUN uv sync --locked --no-cache | ||
|
||
COPY ./app ./app | ||
|
||
CMD ["uv", "run", "fastapi", "run", "app/main.py", "--port", "8080"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Proxy server for AI models with the OpenAI interface.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Core configurations and functionality for the server.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
"""Authentication functionality for the server.""" | ||
|
||
import fastapi | ||
from fastapi import status | ||
from fastapi.security import api_key | ||
|
||
from app.core import config | ||
|
||
settings = config.get_settings() | ||
|
||
|
||
async def check_api_key(
    api_key: str = fastapi.Depends(api_key.APIKeyHeader(name="Authorization")),
) -> None:
    """Checks whether a valid API key is provided in the Authorization header.

    The header value must be exactly ``"Bearer "`` followed by the configured
    ``PROXY_KEY``.

    Args:
        api_key: Raw value of the Authorization header.

    Raises:
        fastapi.HTTPException: With status 401 when the header value does not
            match the expected bearer token.
    """
    # Imported locally so this function stays self-contained.
    import secrets

    expected = "Bearer " + settings.PROXY_KEY.get_secret_value()
    # Constant-time comparison avoids leaking how much of the key matched via
    # response timing. Compare bytes because compare_digest rejects non-ASCII
    # str arguments, and header values are client controlled.
    if secrets.compare_digest(api_key.encode("utf-8"), expected.encode("utf-8")):
        return
    raise fastapi.HTTPException(status.HTTP_401_UNAUTHORIZED, "Unauthorized.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""Configurations for the application.""" | ||
|
||
import functools | ||
import logging | ||
|
||
import pydantic | ||
import pydantic_settings | ||
|
||
|
||
class Settings(pydantic_settings.BaseSettings):
    """App settings.

    Fields declared with ``pydantic.Field(...)`` carry no default value, so
    they have to be supplied when the settings object is built (the CI
    workflow provides them as environment variables).
    """

    # Shared secret; auth.check_api_key expects "Bearer <PROXY_KEY>" in the
    # Authorization header.
    PROXY_KEY: pydantic.SecretStr = pydantic.Field(...)

    # Region and credentials handed to the boto3 "bedrock-runtime" client in
    # the embeddings controller.
    AWS_REGION: str = pydantic.Field(...)
    AWS_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)
    AWS_SECRET_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)

    # Level applied to the "ai-proxy" logger by get_logger; defaults to INFO.
    LOGGER_VERBOSITY: int = logging.INFO
|
||
|
||
@functools.lru_cache
def get_settings() -> Settings:
    """Builds the app settings, reusing the cached instance on later calls.

    Returns:
        The application settings object.
    """
    app_settings = Settings()
    return app_settings
|
||
|
||
def get_logger() -> logging.Logger:
    """Gets the application logger, configuring it on first use.

    The "ai-proxy" logger gets a single stream handler, the level from
    ``Settings.LOGGER_VERBOSITY``, and has propagation disabled so records
    are not emitted a second time by ancestor loggers.

    Returns:
        The configured "ai-proxy" logger.
    """
    logger = logging.getLogger("ai-proxy")
    # Inspect this logger's own handlers. ``hasHandlers()`` also reports
    # handlers attached to ancestors (e.g. the root logger), which would make
    # this function skip configuration whenever the root logger was set up
    # first.
    if logger.handlers:
        return logger

    logger.setLevel(get_settings().LOGGER_VERBOSITY)
    logger.propagate = False

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)s - %(funcName)s - %(message)s",  # noqa: E501
    )

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    return logger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Entrypoint for the server.""" | ||
|
||
import fastapi | ||
from fastapi import status | ||
|
||
from app.core import auth, config | ||
from app.routers.embeddings import views as embeddings_views | ||
|
||
logger = config.get_logger() | ||
|
||
app = fastapi.FastAPI( | ||
title="Model Proxy", | ||
description=( | ||
"A proxy for AI models not well supported by LiteLLM. The goal is to " | ||
"convert these to LiteLLM if/when their support improves." | ||
), | ||
dependencies=[fastapi.Depends(auth.check_api_key)], | ||
responses={status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized."}}, | ||
) | ||
|
||
version_router = fastapi.APIRouter(prefix="/v1") | ||
version_router.include_router(embeddings_views.router) | ||
app.include_router(version_router) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Embeddings router.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Controller for the embeddings endpoints.""" | ||
|
||
import json | ||
from collections.abc import Iterable | ||
|
||
import boto3 | ||
import fastapi | ||
import pydantic | ||
from fastapi import status | ||
|
||
from app.core import config | ||
from app.routers.embeddings import schemas | ||
|
||
settings = config.get_settings() | ||
logger = config.get_logger() | ||
|
||
|
||
def post_embedding(
    payload: schemas.PostEmbeddingRequest,
) -> schemas.PostEmbeddingResponse:
    """Gets the embeddings for the input in the request.

    Dispatches on ``payload.provider``; only ``"aws"`` is handled.

    Args:
        payload: The request body.

    Returns:
        The embedding response.

    Raises:
        fastapi.HTTPException: With status 400 when the provider is unknown.
    """
    if payload.provider != "aws":
        raise fastapi.HTTPException(
            status.HTTP_400_BAD_REQUEST,
            detail="Unknown model provider.",
        )

    # The message previously said "Azure" although this is the AWS branch.
    logger.debug("Running AWS Embedding.")
    return _run_aws_embedding(payload)
|
||
|
||
@pydantic.dataclasses.dataclass
class CohereEmbeddingResponse:
    """Dataclass for the response of Cohere's embedding models.

    Instances are built in ``_get_cohere_response`` by unpacking the decoded
    JSON response body as keyword arguments, so the field names must match
    the keys of that body.
    """

    # One embedding vector per entry of `texts`; the two lists are indexed in
    # parallel by _run_aws_embedding.
    embeddings: list[list[float]]
    # NOTE(review): presumably the provider's identifier for the request -
    # confirm against the Bedrock/Cohere response documentation.
    id: str
    response_type: str
    # The texts reported back in the response body.
    texts: list[str]
|
||
|
||
def _run_aws_embedding(
    payload: schemas.PostEmbeddingRequest,
) -> schemas.PostEmbeddingResponse:
    """Runs an embedding on AWS.

    The input is sent to the model in batches and the resulting embeddings
    are returned in input order, numbered consecutively from zero.

    Args:
        payload: The payload as provided to the POST embedding endpoint.

    Returns:
        The embedding response.
    """
    # Normalize to a list locally instead of mutating the request payload.
    texts = [payload.input] if isinstance(payload.input, str) else payload.input

    # NOTE(review): presumably a per-request batch limit of the model - confirm.
    n_chunks_per_request = 64
    # NOTE(review): `model_name` is used for the request and `model` for the
    # response below; kept as in the original - confirm against the schema.
    responses = [
        # Slicing already clamps at the end of the list, so no min() is needed.
        _get_cohere_response(
            texts[start : start + n_chunks_per_request],
            payload.model_name,
        )
        for start in range(0, len(texts), n_chunks_per_request)
    ]

    # `index` is the position of the embedding within the returned list. It
    # previously advanced by the character length of each text, which gave
    # non-consecutive (and, for empty texts, duplicate) indices.
    embedding_data = [
        schemas.EmbeddingData(index=position, embedding=embedding)
        for position, embedding in enumerate(
            embedding for response in responses for embedding in response.embeddings
        )
    ]

    return schemas.PostEmbeddingResponse(
        data=embedding_data,
        model=payload.model,
    )
|
||
|
||
def _get_cohere_response(inputs: Iterable[str], model: str) -> CohereEmbeddingResponse:
    """Gets the AWS response for Cohere models.

    Args:
        inputs: Strings to embed.
        model: The model ID to invoke for embedding.

    Returns:
        The embedding response.
    """
    body = json.dumps(
        {
            # Materialize first: json.dumps cannot serialize arbitrary
            # iterables such as generators, which the annotation allows.
            "texts": list(inputs),
            "input_type": "search_document",
        },
    )

    # NOTE(review): a new client is created on every call; consider reusing
    # one if this shows up in profiles.
    bedrock = boto3.client(
        service_name="bedrock-runtime",
        region_name=settings.AWS_REGION,
        aws_access_key_id=settings.AWS_ACCESS_KEY.get_secret_value(),
        aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY.get_secret_value(),
    )

    response = bedrock.invoke_model(
        body=body,
        modelId=model,
        accept="application/json",
        contentType="application/json",
    )

    # The response body is a stream; read it fully before decoding the JSON.
    response_body = json.loads(response["body"].read())
    return CohereEmbeddingResponse(**response_body)
Oops, something went wrong.