Add num_prompt_tokens_total to the embed responses #157

Merged (8 commits) on Jan 11, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -138,4 +138,5 @@ dmypy.json
cython_debug/

# IDEs
.vscode/
.vscode/
.idea/
13 changes: 11 additions & 2 deletions Changelog.md
@@ -1,14 +1,23 @@
# Changelog

## 6.0.0

- Added `num_tokens_prompt_total` to the types below.
This is a breaking change since `num_tokens_prompt_total` is mandatory.
- `EmbeddingResponse`
- `SemanticEmbeddingResponse`
- `BatchSemanticEmbeddingResponse`
- HTTP API version 1.15.0 or higher is required.

## 5.0.0

- Added `num_tokens_prompt_total` and `num_tokens_generated` to `CompletionResponse`. This is a
breaking change as these were introduced as mandatory parameters rather than optional ones.
HTTP API version 1.14.0 or higher is required.
- HTTP API version 1.14.0 or higher is required.

## 4.1.0

- Added `verify_ssl` flag so you can disable SSL checking for your sessions.
- Added `verify_ssl` flag, so you can disable SSL checking for your sessions.

## 4.0.0

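For context on the 6.0.0 breaking change above: `num_tokens_prompt_total` is now a mandatory field on the embedding response types, so callers can read it directly. A minimal usage sketch (illustrative, not part of the diff; it assumes the client's existing `semantic_embed` method and uses a placeholder token, the public API host, and the `luminous-base` model name that also appears in the tests below):

from aleph_alpha_client.aleph_alpha_client import Client
from aleph_alpha_client.embedding import SemanticEmbeddingRequest, SemanticRepresentation
from aleph_alpha_client.prompt import Prompt

# Placeholder credentials; host and token are assumptions for this sketch.
client = Client(host="https://api.aleph-alpha.com", token="AA_TOKEN")
request = SemanticEmbeddingRequest(
    prompt=Prompt.from_text("An example prompt"),
    representation=SemanticRepresentation.Symmetric,
    compress_to_size=128,
)
response = client.semantic_embed(request=request, model="luminous-base")
# With 6.0.0 the field is always present on the response.
print(response.num_tokens_prompt_total)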
36 changes: 32 additions & 4 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -1,4 +1,6 @@
import warnings

from packaging import version
from tokenizers import Tokenizer # type: ignore
from types import TracebackType
from typing import (
@@ -48,6 +50,7 @@
SemanticEmbeddingRequest,
SemanticEmbeddingResponse,
)
from aleph_alpha_client.version import MIN_API_VERSION

POOLING_OPTIONS = ["mean", "max", "last_token", "abs_max"]
RETRY_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})
@@ -80,6 +83,16 @@ def _raise_for_status(status_code: int, text: str):
raise RuntimeError(status_code, text)


def _check_api_version(version_str: str):
api_ver = version.parse(MIN_API_VERSION)
ver = version.parse(version_str)
valid = api_ver.major == ver.major and api_ver <= ver
if not valid:
raise RuntimeError(
f"The aleph alpha client requires at least api version {api_ver}, found version {ver}"
)


AnyRequest = Union[
CompletionRequest,
EmbeddingRequest,
@@ -179,6 +192,10 @@ def __init__(
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)

def validate_version(self) -> None:
"""Gets version of the AlephAlpha HTTP API."""
_check_api_version(self.get_version())

def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
return self._get_request("version").text
@@ -434,6 +451,7 @@ def batch_semantic_embed(

responses: List[EmbeddingVector] = []
model_version = ""
num_tokens_prompt_total = 0
# The API currently only supports batch semantic embedding requests with up to 100
# prompts per batch. As a convenience for users, this function chunks larger requests.
for batch_request in _generate_semantic_embedding_batches(request):
@@ -445,9 +463,12 @@
response = BatchSemanticEmbeddingResponse.from_json(raw_response)
model_version = response.model_version
responses.extend(response.embeddings)
num_tokens_prompt_total += response.num_tokens_prompt_total

return BatchSemanticEmbeddingResponse._from_model_version_and_embeddings(
model_version, responses
return BatchSemanticEmbeddingResponse(
model_version=model_version,
embeddings=responses,
num_tokens_prompt_total=num_tokens_prompt_total,
)

def evaluate(
@@ -683,6 +704,9 @@ async def __aexit__(
):
await self.session.__aexit__(exc_type=exc_type, exc_val=exc_val, exc_tb=exc_tb)

async def validate_version(self) -> None:
_check_api_version(await self.get_version())

async def get_version(self) -> str:
"""Gets version of the AlephAlpha HTTP API."""
return await self._get_request_text("version")
@@ -971,13 +995,17 @@ async def batch_semantic_embed(
_generate_semantic_embedding_batches(request, batch_size),
progress_bar,
)
num_tokens_prompt_total = 0
for result in results:
resp = BatchSemanticEmbeddingResponse.from_json(result)
model_version = resp.model_version
responses.extend(resp.embeddings)
num_tokens_prompt_total += resp.num_tokens_prompt_total

return BatchSemanticEmbeddingResponse._from_model_version_and_embeddings(
model_version, responses
return BatchSemanticEmbeddingResponse(
model_version=model_version,
embeddings=responses,
num_tokens_prompt_total=num_tokens_prompt_total,
)

async def evaluate(
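A note on the new version gate added above: `_check_api_version` accepts the server's reported version only if it shares a major version with `MIN_API_VERSION` and is at least that new, and `batch_semantic_embed` now sums `num_tokens_prompt_total` across the request chunks it sends. A standalone sketch of the version rule (same logic as `_check_api_version`, using the `packaging` dependency this PR adds to `setup.py`):

from packaging import version

MIN_API_VERSION = "1.15.0"

def is_compatible(api_version_str: str) -> bool:
    # Same rule as _check_api_version: identical major version, and the
    # running API must be at least as new as the pinned minimum.
    minimum = version.parse(MIN_API_VERSION)
    actual = version.parse(api_version_str)
    return minimum.major == actual.major and minimum <= actual

assert is_compatible("1.15.0")
assert is_compatible("1.16.3")
assert not is_compatible("0.0.0")  # rejected, as exercised in tests/test_clients.py
assert not is_compatible("2.0.0")  # a newer major version is also rejected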
23 changes: 20 additions & 3 deletions aleph_alpha_client/embedding.py
@@ -88,6 +88,7 @@ def _asdict(self) -> Mapping[str, Any]:
@dataclass(frozen=True)
class EmbeddingResponse:
model_version: str
num_tokens_prompt_total: int
embeddings: Optional[Dict[Tuple[str, str], List[float]]]
tokens: Optional[List[str]]
message: Optional[str] = None
@@ -103,6 +104,7 @@ def from_json(json: Dict[str, Any]) -> "EmbeddingResponse":
},
tokens=json.get("tokens"),
message=json.get("message"),
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)


@@ -289,6 +291,7 @@ class SemanticEmbeddingResponse:

model_version: str
embedding: EmbeddingVector
num_tokens_prompt_total: int
message: Optional[str] = None

@staticmethod
@@ -297,6 +300,7 @@ def from_json(json: Dict[str, Any]) -> "SemanticEmbeddingResponse":
model_version=json["model_version"],
embedding=json["embedding"],
message=json.get("message"),
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)


@@ -314,17 +318,30 @@ class BatchSemanticEmbeddingResponse:

model_version: str
embeddings: Sequence[EmbeddingVector]
num_tokens_prompt_total: int

@staticmethod
def from_json(json: Dict[str, Any]) -> "BatchSemanticEmbeddingResponse":
return BatchSemanticEmbeddingResponse(
model_version=json["model_version"], embeddings=json["embeddings"]
model_version=json["model_version"],
embeddings=json["embeddings"],
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)

def to_json(self) -> Mapping[str, Any]:
return {
**asdict(self),
"embeddings": [embedding for embedding in self.embeddings],
}

@staticmethod
def _from_model_version_and_embeddings(
model_version: str, embeddings: Sequence[EmbeddingVector]
model_version: str,
embeddings: Sequence[EmbeddingVector],
num_tokens_prompt_total: int,
) -> "BatchSemanticEmbeddingResponse":
return BatchSemanticEmbeddingResponse(
model_version=model_version, embeddings=embeddings
model_version=model_version,
embeddings=embeddings,
num_tokens_prompt_total=num_tokens_prompt_total,
)
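To make the stricter parsing concrete: each `from_json` above now indexes `num_tokens_prompt_total` directly, so a payload from an older API that lacks the key raises a `KeyError`. A small sketch (the payload values are made up for illustration):

from aleph_alpha_client.embedding import BatchSemanticEmbeddingResponse

payload = {
    "model_version": "1",
    "embeddings": [[0.1, 0.2, 0.3]],
    "num_tokens_prompt_total": 3,
}
response = BatchSemanticEmbeddingResponse.from_json(payload)
assert response.num_tokens_prompt_total == 3

# An older-style payload without the key no longer parses:
# BatchSemanticEmbeddingResponse.from_json({"model_version": "1", "embeddings": []})  # KeyError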
1 change: 1 addition & 0 deletions aleph_alpha_client/version.py
@@ -1 +1,2 @@
__version__ = "5.0.0"
MIN_API_VERSION = "1.15.0"
1 change: 1 addition & 0 deletions setup.py
@@ -52,6 +52,7 @@ def version():
"Pillow >= 9.2.0",
"tqdm >= v4.62.0",
"python-liquid >= 1.9.4",
"packaging >= 23.2"
],
tests_require=tests_require,
extras_require={
28 changes: 28 additions & 0 deletions tests/test_clients.py
@@ -1,6 +1,8 @@
from pytest_httpserver import HTTPServer
import os
import pytest

from aleph_alpha_client.version import MIN_API_VERSION
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.completion import (
CompletionRequest,
@@ -11,6 +13,32 @@
from tests.common import model_name, sync_client, async_client


def test_api_version_mismatch_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()


async def test_api_version_mismatch_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data("0.0.0")

with pytest.raises(RuntimeError):
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
await client.validate_version()


def test_api_version_correct_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
Client(host=httpserver.url_for(""), token="AA_TOKEN").validate_version()


async def test_api_version_correct_async_client(httpserver: HTTPServer):
httpserver.expect_request("/version").respond_with_data(MIN_API_VERSION)
async with AsyncClient(host=httpserver.url_for(""), token="AA_TOKEN") as client:
await client.validate_version()


@pytest.mark.system_test
async def test_can_use_async_client_without_context_manager(model_name: str):
request = CompletionRequest(
27 changes: 18 additions & 9 deletions tests/test_embed.py
@@ -10,7 +10,7 @@
from aleph_alpha_client.embedding import (
BatchSemanticEmbeddingRequest,
SemanticEmbeddingRequest,
SemanticRepresentation,
SemanticRepresentation, BatchSemanticEmbeddingResponse,
)
from aleph_alpha_client.prompt import Prompt
from tests.common import (
@@ -24,7 +24,7 @@

@pytest.mark.system_test
async def test_can_embed_with_async_client(async_client: AsyncClient, model_name: str):
request = request = EmbeddingRequest(
request = EmbeddingRequest(
prompt=Prompt.from_text("abc"), layers=[-1], pooling=["mean"], tokens=True
)

@@ -34,6 +34,7 @@ async def test_can_embed_with_async_client(async_client: AsyncClient, model_name
request.pooling
) * len(request.layers)
assert response.tokens is not None
assert response.num_tokens_prompt_total >= 1


@pytest.mark.system_test
Expand All @@ -50,6 +51,7 @@ async def test_can_semantic_embed_with_async_client(
assert response.model_version is not None
assert response.embedding
assert len(response.embedding) == 128
assert response.num_tokens_prompt_total >= 1


@pytest.mark.parametrize("num_prompts", [1, 100, 101])
@@ -58,18 +60,22 @@ async def test_batch_embed_semantic_with_async_client(
async_client: AsyncClient, sync_client: Client, num_prompts: int, batch_size: int
):
words = ["car", "elephant", "kitchen sink", "rubber", "sun"]
r = random.Random(4082)
prompts = [Prompt.from_text(words[r.randint(0, 4)]) for _ in range(num_prompts)]

request = BatchSemanticEmbeddingRequest(
prompts=[
Prompt.from_text(words[random.randint(0, 4)]) for i in range(num_prompts)
],
prompts=prompts,
representation=SemanticRepresentation.Symmetric,
compress_to_size=128,
)

result = await async_client.batch_semantic_embed(
request=request, num_concurrent_requests=10, batch_size=batch_size
request=request, num_concurrent_requests=10, batch_size=batch_size, model="luminous-base"
)

# We have no control over the exact tokenizer used in the backend, so we cannot know the exact
# number of tokens. We do know, however, that there must be at least one token per prompt.
assert result.num_tokens_prompt_total >= num_prompts

assert len(result.embeddings) == num_prompts
# To make sure that the ordering of responses is preserved,
# we compare the returned embeddings with those of the sync_client's
@@ -121,7 +127,7 @@ async def test_modelname_gets_passed_along_for_async_client(httpserver: HTTPServ
}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json({"model_version": "1", "embeddings": []})
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
async_client = AsyncClient(token="", host=httpserver.url_for(""), total_retries=1)
await async_client.batch_semantic_embed(request, model=model_name)

@@ -142,6 +148,7 @@ def test_embed(sync_client: Client, model_name: str):
request.layers
)
assert result.tokens is None
assert result.num_tokens_prompt_total >= 1


@pytest.mark.system_test
Expand Down Expand Up @@ -178,6 +185,7 @@ def test_embed_with_tokens(sync_client: Client, model_name: str):
request.layers
)
assert result.tokens is not None
assert result.num_tokens_prompt_total >= 1


@pytest.mark.system_test
Expand All @@ -193,6 +201,7 @@ def test_embed_semantic(sync_client: Client):
assert result.model_version is not None
assert result.embedding
assert len(result.embedding) == 128
assert result.num_tokens_prompt_total >= 1


@pytest.mark.parametrize("num_prompts", [1, 100, 101, 200, 1000])
@@ -217,6 +226,6 @@ def test_modelname_gets_passed_along_for_sync_client(httpserver: HTTPServer):
expected_body = {**request.to_json(), "model": model_name}
httpserver.expect_ordered_request(
"/batch_semantic_embed", method="POST", data=json.dumps(expected_body)
).respond_with_json({"model_version": "1", "embeddings": []})
).respond_with_json(BatchSemanticEmbeddingResponse(model_version="1", embeddings=[], num_tokens_prompt_total=1).to_json())
sync_client = Client(token="", host=httpserver.url_for(""), total_retries=1)
sync_client.batch_semantic_embed(request, model=model_name)
9 changes: 2 additions & 7 deletions tests/test_error_handling.py
@@ -9,7 +9,7 @@
Client,
_raise_for_status,
)
from aleph_alpha_client.completion import CompletionRequest
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.prompt import Prompt
import pytest
from pytest_httpserver import HTTPServer
@@ -111,12 +111,7 @@ def expect_retryable_error(

def expect_valid_completion(httpserver: HTTPServer) -> None:
httpserver.expect_ordered_request("/complete").respond_with_json(
{
"model_version": "1",
"completions": [],
"num_tokens_prompt_total": 0,
"num_tokens_generated": 0,
}
CompletionResponse(model_version="1", completions=[], num_tokens_prompt_total=0, num_tokens_generated=0).to_json()
)

