From 5cab79364af06f3219594e15c73b76adf83facb6 Mon Sep 17 00:00:00 2001 From: Reinder Vos de Wael Date: Thu, 14 Nov 2024 14:45:57 -0500 Subject: [PATCH] feat: Add Azure OpenAI embedding (#7) * feat: Add Azure OpenAI embedding * fix: set environment variables in testing * fix: set azure embedding endpoint to fake url --- .github/workflows/proxy_test.yaml | 4 ++ proxy/app/core/config.py | 13 ++-- proxy/app/routers/embeddings/controller.py | 41 +++++++++++- proxy/app/routers/embeddings/schemas.py | 2 +- proxy/pyproject.toml | 1 + proxy/uv.lock | 74 ++++++++++++++++++++++ 6 files changed, 129 insertions(+), 6 deletions(-) diff --git a/.github/workflows/proxy_test.yaml b/.github/workflows/proxy_test.yaml index 283d773..b1b2c07 100644 --- a/.github/workflows/proxy_test.yaml +++ b/.github/workflows/proxy_test.yaml @@ -13,6 +13,10 @@ jobs: AWS_REGION: fake AWS_ACCESS_KEY: fake AWS_SECRET_ACCESS_KEY: fake + AZURE_EMBEDDING_ENDPOINT: https://fake.com + AZURE_EMBEDDING_DEPLOYMENT: fake + AZURE_EMBEDDING_API_KEY: fake + AZURE_EMBEDDING_API_VERSION: fake runs-on: ubuntu-latest steps: diff --git a/proxy/app/core/config.py b/proxy/app/core/config.py index ced2a17..9e9c68f 100644 --- a/proxy/app/core/config.py +++ b/proxy/app/core/config.py @@ -10,11 +10,16 @@ class Settings(pydantic_settings.BaseSettings): """App settings.""" - PROXY_KEY: pydantic.SecretStr = pydantic.Field(...) + PROXY_KEY: pydantic.SecretStr - AWS_REGION: str = pydantic.Field(...) - AWS_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...) - AWS_SECRET_ACCESS_KEY: pydantic.SecretStr = pydantic.Field(...) + AWS_REGION: str + AWS_ACCESS_KEY: pydantic.SecretStr + AWS_SECRET_ACCESS_KEY: pydantic.SecretStr + + AZURE_EMBEDDING_ENDPOINT: pydantic.HttpUrl + AZURE_EMBEDDING_DEPLOYMENT: str + AZURE_EMBEDDING_API_KEY: pydantic.SecretStr + AZURE_EMBEDDING_API_VERSION: str LOGGER_VERBOSITY: int = logging.DEBUG diff --git a/proxy/app/routers/embeddings/controller.py b/proxy/app/routers/embeddings/controller.py index 5ffb508..89c10a6 100644 --- a/proxy/app/routers/embeddings/controller.py +++ b/proxy/app/routers/embeddings/controller.py @@ -5,6 +5,7 @@ import boto3 import fastapi +import openai import pydantic from fastapi import status @@ -27,8 +28,11 @@ def post_embedding( The embedding response. """ if payload.provider == "aws": - logger.debug("Running Azure Embedding.") + logger.debug("Running AWS Embedding.") return _run_aws_embedding(payload) + if payload.provider == "azure": + logger.debug("Running Azure Embedding.") + return _run_azure_embedding(payload) raise fastapi.HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Unknown provider.", @@ -120,3 +124,38 @@ def _get_cohere_response(inputs: Iterable[str], model: str) -> CohereEmbeddingRe response_body = json.loads(response.get("body").read()) return CohereEmbeddingResponse(**response_body) + + +def _run_azure_embedding( + payload: schemas.PostEmbeddingRequest, +) -> schemas.PostEmbeddingResponse: + """Gets the Azure response for an OpenAI embedding model. + + Args: + payload: The request sent by the user. + + Returns: + The embedding response. + """ + client = openai.AzureOpenAI( + azure_endpoint=str(settings.AZURE_EMBEDDING_ENDPOINT), + azure_deployment=settings.AZURE_EMBEDDING_DEPLOYMENT, + api_key=settings.AZURE_EMBEDDING_API_KEY.get_secret_value(), + api_version=settings.AZURE_EMBEDDING_API_VERSION, + ) + + openai_response = client.embeddings.create( + input=payload.input, + model=payload.model_name, + ) + + return schemas.PostEmbeddingResponse( + model=payload.model, + data=[ + schemas.EmbeddingData( + index=embedding.index, + embedding=embedding.embedding, + ) + for embedding in openai_response.data + ], + ) diff --git a/proxy/app/routers/embeddings/schemas.py b/proxy/app/routers/embeddings/schemas.py index 0ef5593..21950b9 100644 --- a/proxy/app/routers/embeddings/schemas.py +++ b/proxy/app/routers/embeddings/schemas.py @@ -4,7 +4,7 @@ import pydantic -EMBEDDING_MODEL = Literal["aws/cohere.embed-english-v3"] +EMBEDDING_MODEL = Literal["aws/cohere.embed-english-v3", "azure/text-embedding-3-large"] class PostEmbeddingRequest(pydantic.BaseModel): diff --git a/proxy/pyproject.toml b/proxy/pyproject.toml index 66323dd..9e965fb 100644 --- a/proxy/pyproject.toml +++ b/proxy/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "pydantic-settings>=2.6.1", "boto3>=1.35.54", "pytest-cov>=6.0.0", + "openai>=1.54.4", ] [tool.uv] diff --git a/proxy/uv.lock b/proxy/uv.lock index e6f8278..ce0d449 100644 --- a/proxy/uv.lock +++ b/proxy/uv.lock @@ -34,6 +34,7 @@ source = { virtual = "." } dependencies = [ { name = "boto3" }, { name = "fastapi", extra = ["standard"] }, + { name = "openai" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pytest-cov" }, @@ -51,6 +52,7 @@ dev = [ requires-dist = [ { name = "boto3", specifier = ">=1.35.54" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.4" }, + { name = "openai", specifier = ">=1.54.4" }, { name = "pydantic", specifier = ">=2.9.2" }, { name = "pydantic-settings", specifier = ">=2.6.1" }, { name = "pytest-cov", specifier = ">=6.0.0" }, @@ -160,6 +162,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/77/03fc2979d1538884d921c2013075917fc927f41cd8526909852fe4494112/coverage-7.6.4-cp313-cp313t-win_amd64.whl", hash = "sha256:f3ddf056d3ebcf6ce47bdaf56142af51bb7fad09e4af310241e9db7a3a8022e1", size = 211502 }, ] +[[package]] +name = "distro" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, +] + [[package]] name = "dnspython" version = "2.7.0" @@ -314,6 +325,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/80/3a54838c3fb461f6fec263ebf3a3a41771bd05190238de3486aae8540c36/jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d", size = 133271 }, ] +[[package]] +name = "jiter" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/46/e5/50ff23c9bba2722d2f0f55ba51e57f7cbab9a4be758e6b9b263ef51e6024/jiter-0.7.1.tar.gz", hash = "sha256:448cf4f74f7363c34cdef26214da527e8eeffd88ba06d0b80b485ad0667baf5d", size = 162334 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/b3/de89eae8f57dc0ee5f6e3aa1ffcdee0364ef9ef85be81006fd17d7710ffa/jiter-0.7.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ad36a1155cbd92e7a084a568f7dc6023497df781adf2390c345dd77a120905ca", size = 291900 }, + { url = "https://files.pythonhosted.org/packages/c0/ff/0d804eff4751fceeabc6311d4b07e956daa06fa58f05931887dc7454466b/jiter-0.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7ba52e6aaed2dc5c81a3d9b5e4ab95b039c4592c66ac973879ba57c3506492bb", size = 304390 }, + { url = "https://files.pythonhosted.org/packages/e8/26/c258bef532d113a7ac26242893fc9760040a4846dec731098b7f5ac3fca7/jiter-0.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b7de0b6f6728b678540c7927587e23f715284596724be203af952418acb8a2d", size = 328710 }, + { url = "https://files.pythonhosted.org/packages/71/92/644dc215cbb9816112e28f3b43a8c8e769f083434a05fc3afd269c444f51/jiter-0.7.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9463b62bd53c2fb85529c700c6a3beb2ee54fde8bef714b150601616dcb184a6", size = 347569 }, + { url = "https://files.pythonhosted.org/packages/c6/02/795a3535262c54595bd97e375cc03b443717febb37723a7f9c077049825b/jiter-0.7.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:627164ec01d28af56e1f549da84caf0fe06da3880ebc7b7ee1ca15df106ae172", size = 373641 }, + { url = "https://files.pythonhosted.org/packages/7d/35/c7e9a06a49116e3618954f6c8a26816a7959c0f9e5617b0073e4145c5d6d/jiter-0.7.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:25d0e5bf64e368b0aa9e0a559c3ab2f9b67e35fe7269e8a0d81f48bbd10e8963", size = 388828 }, + { url = "https://files.pythonhosted.org/packages/fb/05/894144e4cbc1b9d46756db512268a90f84fc1d8bd28f1a17e0fef5aaf5c5/jiter-0.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c244261306f08f8008b3087059601997016549cb8bb23cf4317a4827f07b7d74", size = 325511 }, + { url = "https://files.pythonhosted.org/packages/19/d3/e6674ac34de53787504e4fb309084f824df321f24113121d94bf53808be3/jiter-0.7.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7ded4e4b75b68b843b7cea5cd7c55f738c20e1394c68c2cb10adb655526c5f1b", size = 365940 }, + { url = "https://files.pythonhosted.org/packages/e9/ca/c773f0ce186090cc69a2c97b8dab3dad14ae9988a657a20d879458a8407e/jiter-0.7.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:80dae4f1889b9d09e5f4de6b58c490d9c8ce7730e35e0b8643ab62b1538f095c", size = 515430 }, + { url = "https://files.pythonhosted.org/packages/16/5f/c98f6e6362fbc7c87ad384ba8506983fca9bb55ea0af7efcb23e7dd22817/jiter-0.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5970cf8ec943b51bce7f4b98d2e1ed3ada170c2a789e2db3cb484486591a176a", size = 497389 }, + { url = "https://files.pythonhosted.org/packages/30/60/f60e12469afc9096bac3df0fda53de707ed5105d84322a0d1bc4ad03ee3e/jiter-0.7.1-cp312-none-win32.whl", hash = "sha256:701d90220d6ecb3125d46853c8ca8a5bc158de8c49af60fd706475a49fee157e", size = 198546 }, + { url = "https://files.pythonhosted.org/packages/01/d2/d8ec257544f7991384a46fccee6abdc5065cfede26354bb2c86251858a92/jiter-0.7.1-cp312-none-win_amd64.whl", hash = "sha256:7824c3ecf9ecf3321c37f4e4d4411aad49c666ee5bc2a937071bdd80917e4533", size = 202792 }, + { url = "https://files.pythonhosted.org/packages/b5/cf/00a93a9968fc21b9ecfcabb130a8c822138594ac4a00b7bff9cbb38daa7f/jiter-0.7.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:097676a37778ba3c80cb53f34abd6943ceb0848263c21bf423ae98b090f6c6ba", size = 291039 }, + { url = "https://files.pythonhosted.org/packages/22/9a/0eb3eddffeca703f6adaaf117ba93ac3336fb323206259a86c2993cec9ad/jiter-0.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3298af506d4271257c0a8f48668b0f47048d69351675dd8500f22420d4eec378", size = 302468 }, + { url = "https://files.pythonhosted.org/packages/b1/95/b4da75e93752edfd6dd0df8f7723a6575e8a8bdce2e82f4458eb5564936a/jiter-0.7.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12fd88cfe6067e2199964839c19bd2b422ca3fd792949b8f44bb8a4e7d21946a", size = 328401 }, + { url = "https://files.pythonhosted.org/packages/28/af/7fa53804a2e7e309ce66822c9484fd7d4f8ef452be3937aab8a93a82c54b/jiter-0.7.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:dacca921efcd21939123c8ea8883a54b9fa7f6545c8019ffcf4f762985b6d0c8", size = 347237 }, + { url = "https://files.pythonhosted.org/packages/30/0c/0b89bd3dce7d330d8ee878b0a95899b73e30cb55d2b2c41998276350d4a0/jiter-0.7.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de3674a5fe1f6713a746d25ad9c32cd32fadc824e64b9d6159b3b34fd9134143", size = 373558 }, + { url = "https://files.pythonhosted.org/packages/24/96/c75633b99d57dd8b8457f88f51201805c93b314e369fba69829d726bc2a5/jiter-0.7.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65df9dbae6d67e0788a05b4bad5706ad40f6f911e0137eb416b9eead6ba6f044", size = 388251 }, + { url = "https://files.pythonhosted.org/packages/64/39/369e6ff198003f55acfcdb58169c774473082d3303cddcd24334af534c4e/jiter-0.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ba9a358d59a0a55cccaa4957e6ae10b1a25ffdabda863c0343c51817610501d", size = 325020 }, + { url = "https://files.pythonhosted.org/packages/80/26/0c386fa233a78997db5fa7b362e6f35a37d2656d09e521b0600f29933992/jiter-0.7.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:576eb0f0c6207e9ede2b11ec01d9c2182973986514f9c60bc3b3b5d5798c8f50", size = 365211 }, + { url = "https://files.pythonhosted.org/packages/21/4e/bfebe799924a39f181874b5e9041b792ee67768a8b160814e016a7c9a40d/jiter-0.7.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:e550e29cdf3577d2c970a18f3959e6b8646fd60ef1b0507e5947dc73703b5627", size = 514904 }, + { url = "https://files.pythonhosted.org/packages/a7/81/b3c72c6691acd29cf707df1a0b300e6726385b3c1ced8dc20424c4452699/jiter-0.7.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:81d968dbf3ce0db2e0e4dec6b0a0d5d94f846ee84caf779b07cab49f5325ae43", size = 497102 }, + { url = "https://files.pythonhosted.org/packages/1e/c3/766f9ec97df0441597878c7949da2b241a12a381c3affa7ca761734c8c74/jiter-0.7.1-cp313-none-win32.whl", hash = "sha256:f892e547e6e79a1506eb571a676cf2f480a4533675f834e9ae98de84f9b941ac", size = 198119 }, + { url = "https://files.pythonhosted.org/packages/76/01/cbc0136784a3ffefb5ca5326f8167780c5c3de0c81b6b81b773a973c571e/jiter-0.7.1-cp313-none-win_amd64.whl", hash = "sha256:0302f0940b1455b2a7fb0409b8d5b31183db70d2b07fd177906d83bf941385d1", size = 199236 }, +] + [[package]] name = "jmespath" version = "1.0.1" @@ -414,6 +457,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, ] +[[package]] +name = "openai" +version = "1.54.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/95/83845be5ddd46ce0a35fd602a3366ec2d7fd6b2be6fb760ca553e2488ea1/openai-1.54.4.tar.gz", hash = "sha256:50f3656e45401c54e973fa05dc29f3f0b0d19348d685b2f7ddb4d92bf7b1b6bf", size = 314159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/d8/3e4cf8a5f544bef575d3502fedd81a15e317f591022de940647bdd0cc017/openai-1.54.4-py3-none-any.whl", hash = "sha256:0d95cef99346bf9b6d7fbf57faf61a673924c3e34fa8af84c9ffe04660673a7e", size = 389581 }, +] + [[package]] name = "packaging" version = "24.1" @@ -688,6 +750,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/43/f185bfd0ca1d213beb4293bed51d92254df23d8ceaf6c0e17146d508a776/starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d", size = 73259 }, ] +[[package]] +name = "tqdm" +version = "4.67.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "platform_system == 'Windows'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e8/4f/0153c21dc5779a49a0598c445b1978126b1344bab9ee71e53e44877e14e0/tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a", size = 169739 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/78/57043611a16c655c8350b4c01b8d6abfb38cc2acb475238b62c2146186d7/tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be", size = 78590 }, +] + [[package]] name = "typer" version = "0.12.5"