Skip to content

Commit

Permalink
feat: Add Azure OpenAI embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
ReinderVosDeWael committed Nov 14, 2024
1 parent 8280645 commit b61687b
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 6 deletions.
13 changes: 9 additions & 4 deletions proxy/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@
class Settings(pydantic_settings.BaseSettings):
"""App settings."""

PROXY_KEY: pydantic.SecretStr = pydantic.Field(...)
PROXY_KEY: pydantic.SecretStr

AWS_REGION: str = pydantic.Field(...)
AWS_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)
AWS_SECRET_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...)
AWS_REGION: str
AWS_ACCESS_KEY: pydantic.SecretStr
AWS_SECRET_ACCESS_KEY: pydantic.SecretStr

AZURE_EMBEDDING_ENDPOINT: pydantic.HttpUrl
AZURE_EMBEDDING_DEPLOYMENT: str
AZURE_EMBEDDING_API_KEY: pydantic.SecretStr
AZURE_EMBEDDING_API_VERSION: str

LOGGER_VERBOSITY: int = logging.DEBUG

Expand Down
41 changes: 40 additions & 1 deletion proxy/app/routers/embeddings/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import boto3
import fastapi
import openai
import pydantic
from fastapi import status

Expand All @@ -27,8 +28,11 @@ def post_embedding(
The embedding response.
"""
if payload.provider == "aws":
logger.debug("Running Azure Embedding.")
logger.debug("Running AWS Embedding.")
return _run_aws_embedding(payload)
if payload.provider == "azure":
logger.debug("Running Azure Embedding.")
return _run_azure_embedding(payload)
raise fastapi.HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unknown provider.",
Expand Down Expand Up @@ -120,3 +124,38 @@ def _get_cohere_response(inputs: Iterable[str], model: str) -> CohereEmbeddingRe

response_body = json.loads(response.get("body").read())
return CohereEmbeddingResponse(**response_body)


def _run_azure_embedding(
    payload: schemas.PostEmbeddingRequest,
) -> schemas.PostEmbeddingResponse:
    """Requests embeddings from the configured Azure OpenAI deployment.

    The client is built from the AZURE_EMBEDDING_* application settings and
    the returned embeddings are repackaged into the proxy's response schema.

    Args:
        payload: The request sent by the user.

    Returns:
        The embedding response.
    """
    azure_client = openai.AzureOpenAI(
        azure_endpoint=str(settings.AZURE_EMBEDDING_ENDPOINT),
        azure_deployment=settings.AZURE_EMBEDDING_DEPLOYMENT,
        api_key=settings.AZURE_EMBEDDING_API_KEY.get_secret_value(),
        api_version=settings.AZURE_EMBEDDING_API_VERSION,
    )
    azure_result = azure_client.embeddings.create(
        input=payload.input,
        model=payload.model_name,
    )

    requested_model = payload.model

    # Copy only the index and vector of each returned item into the
    # proxy's own schema.
    embedding_items = []
    for item in azure_result.data:
        embedding_items.append(
            schemas.EmbeddingData(
                index=item.index,
                embedding=item.embedding,
            )
        )

    return schemas.PostEmbeddingResponse(
        model=requested_model,
        data=embedding_items,
    )
2 changes: 1 addition & 1 deletion proxy/app/routers/embeddings/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pydantic

EMBEDDING_MODEL = Literal["aws/cohere.embed-english-v3"]
EMBEDDING_MODEL = Literal["aws/cohere.embed-english-v3", "azure/text-embedding-3-large"]


class PostEmbeddingRequest(pydantic.BaseModel):
Expand Down
1 change: 1 addition & 0 deletions proxy/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies = [
"pydantic-settings>=2.6.1",
"boto3>=1.35.54",
"pytest-cov>=6.0.0",
"openai>=1.54.4",
]

[tool.uv]
Expand Down
74 changes: 74 additions & 0 deletions proxy/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b61687b

Please sign in to comment.