diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py
index f3f0231..b951b34 100644
--- a/aleph_alpha_client/aleph_alpha_client.py
+++ b/aleph_alpha_client/aleph_alpha_client.py
@@ -434,6 +434,7 @@ def batch_semantic_embed(
         responses: List[EmbeddingVector] = []
         model_version = ""
+        num_prompt_tokens_total = 0
 
         # The API currently only supports batch semantic embedding requests with up to 100
         # prompts per batch. As a convenience for users, this function chunks larger requests.
         for batch_request in _generate_semantic_embedding_batches(request):
@@ -445,9 +446,10 @@ def batch_semantic_embed(
             response = BatchSemanticEmbeddingResponse.from_json(raw_response)
             model_version = response.model_version
             responses.extend(response.embeddings)
+            num_prompt_tokens_total += response.num_prompt_tokens_total
 
-        return BatchSemanticEmbeddingResponse._from_model_version_and_embeddings(
-            model_version, responses
+        return BatchSemanticEmbeddingResponse(
+            model_version=model_version, embeddings=responses, num_prompt_tokens_total=num_prompt_tokens_total
         )
 
     def evaluate(
@@ -971,13 +973,15 @@ async def batch_semantic_embed(
             _generate_semantic_embedding_batches(request, batch_size),
             progress_bar,
         )
+        num_prompt_tokens_total = 0
         for result in results:
             resp = BatchSemanticEmbeddingResponse.from_json(result)
             model_version = resp.model_version
             responses.extend(resp.embeddings)
+            num_prompt_tokens_total += resp.num_prompt_tokens_total
 
-        return BatchSemanticEmbeddingResponse._from_model_version_and_embeddings(
-            model_version, responses
+        return BatchSemanticEmbeddingResponse(
+            model_version=model_version, embeddings=responses, num_prompt_tokens_total=num_prompt_tokens_total
         )
 
     async def evaluate(
diff --git a/aleph_alpha_client/embedding.py b/aleph_alpha_client/embedding.py
index d0fdb41..a7a9301 100644
--- a/aleph_alpha_client/embedding.py
+++ b/aleph_alpha_client/embedding.py
@@ -88,6 +88,7 @@ def _asdict(self) -> Mapping[str, Any]:
 @dataclass(frozen=True)
 class EmbeddingResponse:
     model_version: str
+    num_prompt_tokens_total: int
     embeddings: Optional[Dict[Tuple[str, str], List[float]]]
     tokens: Optional[List[str]]
     message: Optional[str] = None
@@ -103,6 +104,7 @@ def from_json(json: Dict[str, Any]) -> "EmbeddingResponse":
             },
             tokens=json.get("tokens"),
             message=json.get("message"),
+            num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
         )
 
 
@@ -289,6 +291,7 @@ class SemanticEmbeddingResponse:
 
     model_version: str
     embedding: EmbeddingVector
+    num_prompt_tokens_total: int
     message: Optional[str] = None
 
     @staticmethod
@@ -297,6 +300,7 @@ def from_json(json: Dict[str, Any]) -> "SemanticEmbeddingResponse":
             model_version=json["model_version"],
             embedding=json["embedding"],
             message=json.get("message"),
+            num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
         )
 
 
@@ -314,17 +318,18 @@ class BatchSemanticEmbeddingResponse:
 
     model_version: str
     embeddings: Sequence[EmbeddingVector]
+    num_prompt_tokens_total: int
 
     @staticmethod
     def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse":
         return BatchSemanticEmbeddingResponse(
-            model_version=json["model_version"], embeddings=json["embeddings"]
+            model_version=json["model_version"], embeddings=json["embeddings"], num_prompt_tokens_total=json.get("num_prompt_tokens_total", 0)
         )
 
     @staticmethod
     def _from_model_version_and_embeddings(
-        model_version: str, embeddings: Sequence[EmbeddingVector]
+        model_version: str, embeddings: Sequence[EmbeddingVector], num_prompt_tokens_total: int
     ) -> "BatchSemanticEmbeddingResponse":
         return BatchSemanticEmbeddingResponse(
-            model_version=model_version, embeddings=embeddings
+            model_version=model_version, embeddings=embeddings, num_prompt_tokens_total=num_prompt_tokens_total
         )
diff --git a/tests/test_embed.py b/tests/test_embed.py
index fdd7750..4e1523c 100644
--- a/tests/test_embed.py
+++ b/tests/test_embed.py
@@ -5,7 +5,7 @@
 import pytest
 from pytest_httpserver import HTTPServer
 
-from aleph_alpha_client import EmbeddingRequest
+from aleph_alpha_client import EmbeddingRequest, TokenizationRequest
 from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
 from aleph_alpha_client.embedding import (
     BatchSemanticEmbeddingRequest,
@@ -34,6 +34,7 @@ async def test_can_embed_with_async_client(async_client: AsyncClient, model_name
         request.pooling
     ) * len(request.layers)
     assert response.tokens is not None
+    assert response.num_prompt_tokens_total == 1
 
 
 @pytest.mark.system_test
@@ -50,6 +51,7 @@ async def test_can_semantic_embed_with_async_client(
     assert response.model_version is not None
     assert response.embedding
     assert len(response.embedding) == 128
+    assert response.num_prompt_tokens_total == 1
 
 
 @pytest.mark.parametrize("num_prompts", [1, 100, 101])
@@ -58,10 +60,11 @@ async def test_batch_embed_semantic_with_async_client(
     async_client: AsyncClient, sync_client: Client, num_prompts: int, batch_size: int
 ):
     words = ["car", "elephant", "kitchen sink", "rubber", "sun"]
+    prompts = [Prompt.from_text(words[random.randint(0, 4)]) for i in range(num_prompts)]
+    tokens = [async_client.tokenize(TokenizationRequest(prompt=p.items[0].text, tokens=True, token_ids=False), "luminous-base") for p in prompts]
+
     request = BatchSemanticEmbeddingRequest(
-        prompts=[
-            Prompt.from_text(words[random.randint(0, 4)]) for i in range(num_prompts)
-        ],
+        prompts=prompts,
         representation=SemanticRepresentation.Symmetric,
         compress_to_size=128,
     )
@@ -69,6 +72,8 @@ async def test_batch_embed_semantic_with_async_client(
     result = await async_client.batch_semantic_embed(
         request=request, num_concurrent_requests=10, batch_size=batch_size
     )
+    num_tokens = sum([len((await t).tokens) for t in tokens])
+    assert result.num_prompt_tokens_total == num_tokens
     assert len(result.embeddings) == num_prompts
 
     # To make sure that the ordering of responses is preserved,
@@ -142,6 +147,7 @@ def test_embed(sync_client: Client, model_name: str):
         request.layers
     )
     assert result.tokens is None
+    assert result.num_prompt_tokens_total == 1
 
 
 @pytest.mark.system_test
@@ -178,6 +184,7 @@ def test_embed_with_tokens(sync_client: Client, model_name: str):
         request.layers
    )
     assert result.tokens is not None
+    assert result.num_prompt_tokens_total == 1
 
 
 @pytest.mark.system_test
@@ -193,6 +200,7 @@ def test_embed_semantic(sync_client: Client):
     assert result.model_version is not None
     assert result.embedding
     assert len(result.embedding) == 128
+    assert result.num_prompt_tokens_total == 1
 
 
 @pytest.mark.parametrize("num_prompts", [1, 100, 101, 200, 1000])
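
For reference, a minimal usage sketch of the new field (not part of the patch). It assumes a valid API token in an AA_TOKEN environment variable and mirrors the import paths used in the tests above; the prompt texts are illustrative. Because batch_semantic_embed chunks large requests internally, num_prompt_tokens_total on the returned response is the sum over all sub-batches.

import os

from aleph_alpha_client import Prompt
from aleph_alpha_client.aleph_alpha_client import Client
from aleph_alpha_client.embedding import (
    BatchSemanticEmbeddingRequest,
    SemanticRepresentation,
)

client = Client(token=os.environ["AA_TOKEN"])

texts = ["car", "elephant", "kitchen sink"]
request = BatchSemanticEmbeddingRequest(
    prompts=[Prompt.from_text(t) for t in texts],
    representation=SemanticRepresentation.Symmetric,
    compress_to_size=128,
)

response = client.batch_semantic_embed(request=request)

# One embedding vector per prompt; the new field reports the prompt-token
# count aggregated across all internal sub-batches.
assert len(response.embeddings) == len(texts)
print(response.num_prompt_tokens_total)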