From db972b8fe4a9f80bbfc6011506cd8cbb87252398 Mon Sep 17 00:00:00 2001 From: Julian Wiesler Date: Thu, 11 Jan 2024 10:49:50 +0100 Subject: [PATCH] Remove some explicit json formattings from test code --- aleph_alpha_client/embedding.py | 6 ++++++ tests/test_embed.py | 8 ++++---- tests/test_error_handling.py | 9 ++------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/aleph_alpha_client/embedding.py b/aleph_alpha_client/embedding.py index fa9a870..d565bd7 100644 --- a/aleph_alpha_client/embedding.py +++ b/aleph_alpha_client/embedding.py @@ -328,6 +328,12 @@ def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse": num_tokens_prompt_total=json["num_tokens_prompt_total"], ) + def to_json(self) -> Mapping[str, Any]: + return { + **asdict(self), + "embeddings": [embedding for embedding in self.embeddings], + } + @staticmethod def _from_model_version_and_embeddings( model_version: str, diff --git a/tests/test_embed.py b/tests/test_embed.py index 5b44834..dc4751c 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -5,12 +5,12 @@ import pytest from pytest_httpserver import HTTPServer -from aleph_alpha_client import EmbeddingRequest, TokenizationRequest +from aleph_alpha_client import EmbeddingRequest from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client from aleph_alpha_client.embedding import ( BatchSemanticEmbeddingRequest, SemanticEmbeddingRequest, - SemanticRepresentation, + SemanticRepresentation, BatchSemanticEmbeddingResponse, ) from aleph_alpha_client.prompt import Prompt from tests.common import ( @@ -127,7 +127,7 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ } httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) - ).respond_with_json({"model_version": "1", "embeddings": []}) + ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json()) async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1) await async_client.batch_semantic_embed(request, model=model_name) @@ -226,6 +226,6 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer): expected_body = {**request.to_json(), "model": model_name} httpserver.expect_ordered_request( "/batch_semantic_embed", method="POST", data=json.dumps(expected_body) - ).respond_with_json({"model_version": "1", "embeddings": []}) + ).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json()) sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1) sync_client.batch_semantic_embed(request, model=model_name) diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 783fb72..4c38c61 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -9,7 +9,7 @@ Client, _raise_for_status, ) -from aleph_alpha_client.completion import CompletionRequest +from aleph_alpha_client.completion import CompletionRequest, CompletionResponse from aleph_alpha_client.prompt import Prompt import pytest from pytest_httpserver import HTTPServer @@ -111,12 +111,7 @@ def expect_retryable_error( def expect_valid_completion(httpserver: HTTPServer) -> None: httpserver.expect_ordered_request("/complete").respond_with_json( - { - "model_version": "1", - "completions": [], - "num_tokens_prompt_total": 0, - "num_tokens_generated": 0, - } + CompletionResponse(model_version="1", completions=[], num_tokens_prompt_total=0, num_tokens_generated=0).to_json() )