Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supported batching for Titan text and image embeddings #193

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions aidial_adapter_bedrock/dial_api/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from aidial_sdk.embeddings import Usage
from pydantic import BaseModel

from aidial_adapter_bedrock.embedding.encoding import vector_to_base64


class ModelObject(BaseModel):
object: Literal["model"] = "model"
Expand All @@ -16,13 +18,31 @@ class ModelsResponse(BaseModel):
data: List[ModelObject]


def _encode_vector(
encoding_format: Literal["float", "base64"], vector: List[float]
) -> List[float] | str:
return vector_to_base64(vector) if encoding_format == "base64" else vector


def make_embeddings_response(
    model: str,
    encoding_format: Literal["float", "base64"],
    vectors: List[List[float]],
    prompt_tokens: int,
) -> EmbeddingsResponse:
    """Assemble an OpenAI-style embeddings response.

    Args:
        model: deployment/model identifier echoed back in the response.
        encoding_format: ``"float"`` returns raw float lists; ``"base64"``
            packs each vector into a base64 string.
        vectors: one embedding vector per input, in input order.
        prompt_tokens: total tokens consumed by all inputs; Titan-style
            embeddings report no completion tokens, so total == prompt.

    Returns:
        EmbeddingsResponse with one ``Embedding`` per input vector,
        indexed by position.
    """
    # Encode every vector in the caller-requested format before wrapping.
    embeddings = [_encode_vector(encoding_format, v) for v in vectors]

    data: List[Embedding] = [
        Embedding(index=index, embedding=embedding)
        for index, embedding in enumerate(embeddings)
    ]

    return EmbeddingsResponse(
        model=model,
        data=data,
        usage=Usage(
            prompt_tokens=prompt_tokens,
            total_tokens=prompt_tokens,
        ),
    )
39 changes: 17 additions & 22 deletions aidial_adapter_bedrock/embedding/amazon/titan_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/titan-multimodal-embeddings/rag/1_multimodal_rag.ipynb
"""

from typing import AsyncIterator, List, Self
import asyncio
from typing import AsyncIterator, List, Self, Tuple

from aidial_sdk.chat_completion import Attachment
from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from aidial_sdk.embeddings.request import EmbeddingsRequest
from pydantic import BaseModel

Expand All @@ -30,7 +30,6 @@
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
from aidial_adapter_bedrock.embedding.validation import (
validate_embeddings_request,
)
Expand Down Expand Up @@ -155,31 +154,27 @@ async def embeddings(
supports_dimensions=True,
)

vectors: List[List[float] | str] = []
token_count = 0

# NOTE: Amazon Titan doesn't support batched inputs
# TODO: create multiple tasks
async for sub_request in get_requests(self.storage, request):
async def compute_embeddings(
req: AmazonRequest,
) -> Tuple[List[float], int]:
embedding, text_tokens = await call_embedding_model(
self.client,
self.model,
create_titan_request(sub_request, request.dimensions),
)

image_tokens = sub_request.get_image_tokens()

vector = (
vector_to_base64(embedding)
if request.encoding_format == "base64"
else embedding
create_titan_request(req, request.dimensions),
)
image_tokens = req.get_image_tokens()
return embedding, text_tokens + image_tokens

vectors.append(vector)
token_count += text_tokens + image_tokens
# NOTE: Amazon Titan doesn't support batched inputs
tasks = [
asyncio.create_task(compute_embeddings(req))
async for req in get_requests(self.storage, request)
]
results = await asyncio.gather(*tasks)

return make_embeddings_response(
model=self.model,
vectors=vectors,
usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
encoding_format=request.encoding_format,
vectors=[r[0] for r in results],
prompt_tokens=sum(r[1] for r in results),
)
37 changes: 16 additions & 21 deletions aidial_adapter_bedrock/embedding/amazon/titan_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/embeddings/v2/Titan-V2-Embeddings.ipynb
"""

from typing import AsyncIterator, List, Self
import asyncio
from typing import AsyncIterator, List, Self, Tuple

from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.bedrock import Bedrock
Expand All @@ -23,7 +23,6 @@
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
from aidial_adapter_bedrock.embedding.validation import (
validate_embeddings_request,
)
Expand Down Expand Up @@ -74,27 +73,23 @@ async def embeddings(
supports_dimensions=self.supports_dimensions,
)

vectors: List[List[float] | str] = []
token_count = 0

# NOTE: Amazon Titan doesn't support batched inputs
async for text_input in get_text_inputs(request):
sub_request = create_titan_request(text_input, request.dimensions)
embedding, tokens = await call_embedding_model(
self.client, self.model, sub_request
async def compute_embeddings(req: str) -> Tuple[List[float], int]:
return await call_embedding_model(
self.client,
self.model,
create_titan_request(req, request.dimensions),
)

vector = (
vector_to_base64(embedding)
if request.encoding_format == "base64"
else embedding
)

vectors.append(vector)
token_count += tokens
# NOTE: Amazon Titan doesn't support batched inputs
tasks = [
asyncio.create_task(compute_embeddings(req))
async for req in get_text_inputs(request)
]
results = await asyncio.gather(*tasks)

return make_embeddings_response(
model=self.model,
vectors=vectors,
usage=Usage(prompt_tokens=token_count, total_tokens=token_count),
encoding_format=request.encoding_format,
vectors=[r[0] for r in results],
prompt_tokens=sum(r[1] for r in results),
)
16 changes: 3 additions & 13 deletions aidial_adapter_bedrock/embedding/cohere/embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import AsyncIterator, List, Self

from aidial_sdk.embeddings import Response as EmbeddingsResponse
from aidial_sdk.embeddings import Usage
from aidial_sdk.embeddings.request import EmbeddingsRequest

from aidial_adapter_bedrock.bedrock import Bedrock
Expand All @@ -24,7 +23,6 @@
from aidial_adapter_bedrock.embedding.embeddings_adapter import (
EmbeddingsAdapter,
)
from aidial_adapter_bedrock.embedding.encoding import vector_to_base64
from aidial_adapter_bedrock.embedding.validation import (
validate_embeddings_request,
)
Expand Down Expand Up @@ -92,17 +90,9 @@ async def embeddings(
self.client, self.model, embedding_request
)

vectors: List[List[float] | str] = [
(
vector_to_base64(embedding)
if request.encoding_format == "base64"
else embedding
)
for embedding in embeddings
]

return make_embeddings_response(
model=self.model,
vectors=vectors,
usage=Usage(prompt_tokens=tokens, total_tokens=tokens),
encoding_format=request.encoding_format,
vectors=embeddings,
prompt_tokens=tokens,
)
3 changes: 2 additions & 1 deletion aidial_adapter_bedrock/llm/model/llama/v3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Awaitable, Callable

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.converse.adapter import (
Expand Down Expand Up @@ -26,7 +27,7 @@ def is_stream(self, params: ModelParameters) -> bool:

def input_tokenizer_factory(
deployment: ConverseDeployment, params: ConverseRequestWrapper
):
) -> Callable[[ConverseMessages], Awaitable[int]]:
tool_tokens = default_tokenize_string(json.dumps(params.toolConfig))
system_tokens = default_tokenize_string(json.dumps(params.system))

Expand Down
Loading