feat: Add a proxy for cohere-embed-english-v3 [DCH-126] (#5)
* feat: Add a proxy for cohere-embed-english-v3

* feat: Add a Docker release

* feat: Add CI testing

* fix: Switch directory in testing

* fix: Attempt to fix directory management

* fix: Attempt to fix directory in CI

* feat: Add mypy/ruff dependencies

* fix: Formatting issues

* fix formatting

* fix: ruff and env variable fixes
ReinderVosDeWael authored Nov 7, 2024
1 parent 3679fae commit db66052
Showing 20 changed files with 1,410 additions and 4 deletions.
65 changes: 65 additions & 0 deletions .github/workflows/proxy_test.yaml
@@ -0,0 +1,65 @@
name: Python Tests

on:
push:
branches:
- main
pull_request:

jobs:
unit:
env:
PROXY_KEY: fake
AWS_REGION: fake
AWS_ACCESS_KEY: fake
AWS_SECRET_ACCESS_KEY: fake

runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version-file: "./proxy/.python-version"
- name: Install the project
run: uv sync --all-extras --dev --directory proxy
- name: Run tests
id: run-tests
run: >
uv run --directory proxy pytest . \
--junitxml=pytest.xml \
--cov-report=term-missing:skip-covered \
--cov-report=xml:coverage.xml \
--cov=src tests \
--log-level=DEBUG \
--verbose
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
verbose: true

ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
      - uses: chartboost/ruff-action@v1
        with:
          src: "./proxy"  # lint the proxy/ subdirectory; a separate "cd" step would not persist across steps

mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version-file: "./proxy/.python-version"
- name: Install the project
run: uv sync --all-extras --dev --directory proxy
- name: Run mypy
run: uv run --directory proxy mypy .
37 changes: 37 additions & 0 deletions .github/workflows/release.yaml
@@ -45,3 +45,40 @@ jobs:
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

push_docbot_proxy:
name: Push Docker image to Docker Hub
runs-on: ubuntu-latest
steps:
- name: Check out the repo
uses: actions/checkout@v4
with:
lfs: true

- name: Set up QEMU
uses: docker/setup-qemu-action@v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}

- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/[email protected]
with:
images: cmidair/doc-bot-proxy

- name: Build and push Docker image
uses: docker/[email protected]
with:
platforms: linux/amd64,linux/arm64
context: ./proxy
file: ./proxy/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
.venv
.pytest_cache
.open-webui
local
local
*.pyc
__pycache__
22 changes: 19 additions & 3 deletions docker-compose.yaml
@@ -15,16 +15,32 @@ services:
target: /app/config.yaml
- action: rebuild
path: .env
- action: rebuild
path: docker-compose.yaml

open_webui:
image: ghcr.io/open-webui/open-webui
image: ghcr.io/open-webui/open-webui:dev
ports:
- 3000:8080
env_file:
- .env
volumes:
- ./.open-webui:/app/backend/data

ai_proxy:
build:
context: proxy
dockerfile: Dockerfile
ports:
- 8080:8080
env_file:
- .env
develop:
watch:
- action: rebuild
path: ./proxy/app
target: /app/app
- action: rebuild
path: ./proxy/uv.lock




1 change: 1 addition & 0 deletions proxy/.python-version
@@ -0,0 +1 @@
3.12
12 changes: 12 additions & 0 deletions proxy/Dockerfile
@@ -0,0 +1,12 @@
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim

WORKDIR /app

EXPOSE 8080

COPY pyproject.toml uv.lock .python-version ./
RUN uv sync --locked --no-cache

COPY ./app ./app

CMD ["uv", "run", "fastapi", "run", "app/main.py", "--port", "8080"]
1 change: 1 addition & 0 deletions proxy/app/__init__.py
@@ -0,0 +1 @@
"""Proxy server for AI models with the OpenAI interface."""
1 change: 1 addition & 0 deletions proxy/app/core/__init__.py
@@ -0,0 +1 @@
"""Core configurations and functionality for the server."""
22 changes: 22 additions & 0 deletions proxy/app/core/auth.py
@@ -0,0 +1,22 @@
"""Authentication functionality for the server."""

import fastapi
from fastapi import status
from fastapi.security import api_key

from app.core import config

settings = config.get_settings()


async def check_api_key(
api_key: str = fastapi.Depends(api_key.APIKeyHeader(name="Authorization")),
) -> None:
"""Checks whether the API key is provided in the Authorizaion header.
Args:
api_key: API key, provided in the header.
"""
if api_key == "Bearer " + settings.PROXY_KEY.get_secret_value():
return
raise fastapi.HTTPException(status.HTTP_401_UNAUTHORIZED, "Unauthorized.")
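
For reference, a minimal sketch of the header format this dependency accepts. The helper name build_auth_header is illustrative and not part of the commit; the key point is that check_api_key compares the raw header value against the literal string "Bearer " plus the configured PROXY_KEY, so the scheme prefix and the single space are both required.

def build_auth_header(proxy_key: str) -> dict[str, str]:
    """Builds the Authorization header that check_api_key accepts."""
    # The server compares the raw header value against "Bearer " + PROXY_KEY,
    # so a missing prefix or extra whitespace results in a 401.
    return {"Authorization": f"Bearer {proxy_key}"}


assert build_auth_header("secret") == {"Authorization": "Bearer secret"}

A request with no Authorization header at all is rejected by fastapi.security's APIKeyHeader before check_api_key runs; a header that is present but does not match yields the 401 raised above.
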
45 changes: 45 additions & 0 deletions proxy/app/core/config.py
@@ -0,0 +1,45 @@
"""Configurations for the application."""

import functools
import logging

import pydantic
import pydantic_settings


class Settings(pydantic_settings.BaseSettings):
"""App settings."""

PROXY_KEY: pydantic.SecretStr = pydantic.Field(...)

AWS_REGION: str = pydantic.Field(...)
AWS_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)
AWS_SECRET_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)

LOGGER_VERBOSITY: int = logging.INFO


@functools.lru_cache
def get_settings() -> Settings:
"""Gets the app settings."""
return Settings()


def get_logger() -> logging.Logger:
"""Gets the logger."""
logger = logging.getLogger("ai-proxy")
if logger.hasHandlers():
return logger

logger.setLevel(get_settings().LOGGER_VERBOSITY)
logger.propagate = False

formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)s - %(funcName)s - %(message)s", # noqa: E501
)

handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)

return logger
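
A minimal usage sketch for these helpers, assuming the process environment (or the .env file passed in by docker-compose) supplies the four required variables; the values below are placeholders, not real credentials.

import os

# Settings is a pydantic-settings model, so missing variables raise a
# validation error when get_settings() first constructs it.
os.environ.setdefault("PROXY_KEY", "local-dev-key")
os.environ.setdefault("AWS_REGION", "us-east-1")
os.environ.setdefault("AWS_ACCESS_KEY", "fake")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "fake")

from app.core import config  # noqa: E402  (imported after the env is populated)

settings = config.get_settings()  # cached by functools.lru_cache
logger = config.get_logger()      # idempotent: handlers are attached only once
logger.info("Using AWS region %s", settings.AWS_REGION)
print(settings.PROXY_KEY)         # SecretStr masks the value as **********

Because get_settings() is cached, every module that imports it shares one Settings instance, and get_logger()'s hasHandlers() check keeps repeated calls from duplicating log lines.
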
23 changes: 23 additions & 0 deletions proxy/app/main.py
@@ -0,0 +1,23 @@
"""Entrypoint for the server."""

import fastapi
from fastapi import status

from app.core import auth, config
from app.routers.embeddings import views as embeddings_views

logger = config.get_logger()

app = fastapi.FastAPI(
title="Model Proxy",
description=(
"A proxy for AI models not well supported by LiteLLM. The goal is to "
"convert these to LiteLLM if/when their support improves."
),
dependencies=[fastapi.Depends(auth.check_api_key)],
responses={status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized."}},
)

version_router = fastapi.APIRouter(prefix="/v1")
version_router.include_router(embeddings_views.router)
app.include_router(version_router)
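
A short sketch of how the assembled app can be inspected without starting a server, assuming the four required environment variables are set before app.main is imported (the auth module instantiates Settings at import time).

import os

for name in ("PROXY_KEY", "AWS_REGION", "AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY"):
    os.environ.setdefault(name, "fake")

from app import main  # noqa: E402

# The OpenAPI metadata defined above is available directly from the app object.
schema = main.app.openapi()
print(schema["info"]["title"])  # Model Proxy

# Every route pulled in through version_router sits under the /v1 prefix and
# behind the check_api_key dependency declared on the FastAPI constructor.
print([route.path for route in main.app.routes if route.path.startswith("/v1")])
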
1 change: 1 addition & 0 deletions proxy/app/routers/embeddings/__init__.py
@@ -0,0 +1 @@
"""Embeddings router."""
122 changes: 122 additions & 0 deletions proxy/app/routers/embeddings/controller.py
@@ -0,0 +1,122 @@
"""Controller for the embeddings endpoints."""

import json
from collections.abc import Iterable

import boto3
import fastapi
import pydantic
from fastapi import status

from app.core import config
from app.routers.embeddings import schemas

settings = config.get_settings()
logger = config.get_logger()


def post_embedding(
payload: schemas.PostEmbeddingRequest,
) -> schemas.PostEmbeddingResponse:
"""Gets the embedding of a string.
Args:
payload: The request body.
Returns:
The embedding response.
"""
if payload.provider == "aws":
logger.debug("Running Azure Embedding.")
return _run_aws_embedding(payload)
raise fastapi.HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Unknown model provider.",
)


@pydantic.dataclasses.dataclass
class CohereEmbeddingResponse:
"""Dataclass for the response of Cohere's embedding models."""

embeddings: list[list[float]]
id: str
response_type: str
texts: list[str]


def _run_aws_embedding(
payload: schemas.PostEmbeddingRequest,
) -> schemas.PostEmbeddingResponse:
"""Runs an embedding on AWS.
Args:
payload: The payload as provided to the POST embedding endpoint.
Returns:
The embedding response.
"""
if isinstance(payload.input, str):
payload.input = [payload.input]

responses = []
n_chunks_per_request = 64
for index in range(0, len(payload.input), n_chunks_per_request):
inputs = payload.input[
index : min(index + n_chunks_per_request, len(payload.input))
]
responses.append(
_get_cohere_response(inputs, payload.model_name),
)

position = 0
embedding_data = []
for response in responses:
for index, text in enumerate(response.texts):
embedding_data.append(
schemas.EmbeddingData(
index=position,
embedding=response.embeddings[index],
),
)
            position += 1  # One embedding (and one index) per input text.

return schemas.PostEmbeddingResponse(
data=embedding_data,
model=payload.model,
)


def _get_cohere_response(inputs: Iterable[str], model: str) -> CohereEmbeddingResponse:
"""Gets the AWS response for Cohere models.
Args:
inputs: List of strings to embed.
model: The model to use for embedding.
Returns:
The embedding response.
"""
body = json.dumps(
{
"texts": inputs,
"input_type": "search_document",
},
)

bedrock = boto3.client(
service_name="bedrock-runtime",
region_name=settings.AWS_REGION,
aws_access_key_id=settings.AWS_ACCESS_KEY.get_secret_value(),
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY.get_secret_value(),
)

response = bedrock.invoke_model(
body=body,
modelId=model,
accept="application/json",
contentType="application/json",
)

response_body = json.loads(response.get("body").read())
return CohereEmbeddingResponse(**response_body)
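
A stand-alone sketch of the batching that _run_aws_embedding performs before calling Bedrock, using made-up input texts. The batch size of 64 comes straight from the controller above; note that Python slicing already clamps to the end of the list, so the min() in the original is a readability choice rather than a requirement.

texts = [f"document {i}" for i in range(150)]

n_chunks_per_request = 64
batches = [
    texts[index : index + n_chunks_per_request]
    for index in range(0, len(texts), n_chunks_per_request)
]

print([len(batch) for batch in batches])  # [64, 64, 22]

Each batch becomes one invoke_model call via _get_cohere_response, and the per-batch embeddings are then flattened back into a single list of EmbeddingData entries with sequential indices.
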
