Skip to content

Commit

Permalink
Remove some explicit JSON formatting from test code
Browse files Browse the repository at this point in the history
  • Loading branch information
WieslerTNG committed Jan 11, 2024
1 parent 90447e8 commit db972b8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
6 changes: 6 additions & 0 deletions aleph_alpha_client/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,12 @@ def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse":
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)

def to_json(self) -> Mapping[str, Any]:
    """Serialize this response to a JSON-compatible mapping.

    Starts from ``asdict(self)`` for all fields, then overrides
    ``embeddings`` with a shallow copy of the original sequence —
    presumably to bypass ``asdict``'s recursive conversion of the
    embedding entries (TODO confirm against the field's type).
    """
    return {
        **asdict(self),
        # list(...) is the idiomatic shallow copy; replaces the
        # identity comprehension (ruff PERF402).
        "embeddings": list(self.embeddings),
    }

@staticmethod
def _from_model_version_and_embeddings(
model_version: str,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import pytest
from pytest_httpserver import HTTPServer

from aleph_alpha_client import EmbeddingRequest, TokenizationRequest
from aleph_alpha_client import EmbeddingRequest
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.embedding import (
BatchSemanticEmbeddingRequest,
SemanticEmbeddingRequest,
SemanticRepresentation,
SemanticRepresentation, BatchSemanticEmbeddingResponse,
)
from aleph_alpha_client.prompt import Prompt
from tests.common import (
Expand Down Expand Up @@ -127,7 +127,7 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ
}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json({"model_version": "1", "embeddings": []})
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1)
await async_client.batch_semantic_embed(request, model=model_name)

Expand Down Expand Up @@ -226,6 +226,6 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer):
expected_body = {**request.to_json(), "model": model_name}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json({"model_version": "1", "embeddings": []})
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1)
sync_client.batch_semantic_embed(request, model=model_name)
9 changes: 2 additions & 7 deletions tests/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Client,
_raise_for_status,
)
from aleph_alpha_client.completion import CompletionRequest
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.prompt import Prompt
import pytest
from pytest_httpserver import HTTPServer
Expand Down Expand Up @@ -111,12 +111,7 @@ def expect_retryable_error(

def expect_valid_completion(httpserver: HTTPServer) -> None:
    """Queue a canned, well-formed response for the next /complete request."""
    # Build the payload through the client's own response type so this
    # fixture stays in sync with the real serialization format.
    canned = CompletionResponse(
        model_version="1",
        completions=[],
        num_tokens_prompt_total=0,
        num_tokens_generated=0,
    ).to_json()
    httpserver.expect_ordered_request("/complete").respond_with_json(canned)


Expand Down

0 comments on commit db972b8

Please sign in to comment.