diff --git a/Changelog.md b/Changelog.md
index d67c35b..482c670 100644
--- a/Changelog.md
+++ b/Changelog.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## 5.0.0
+
+- Added `num_tokens_prompt_total` and `num_tokens_generated` to `CompletionResponse`. This is a
+  breaking change because the new fields are mandatory rather than optional.
+  HTTP API version 1.14.0 or higher is required.
+
 ## 4.1.0
 
 - Added `verify_ssl` flag so you can disable SSL checking for your sessions.
diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py
index ac48523..128d003 100644
--- a/aleph_alpha_client/completion.py
+++ b/aleph_alpha_client/completion.py
@@ -247,8 +247,31 @@ def _asdict(self) -> Mapping[str, Any]:
 
 @dataclass(frozen=True)
 class CompletionResponse:
+    """
+    Describes a completion response.
+
+    Parameters:
+        model_version:
+            Name and version (if any) of the model used for inference.
+        completions:
+            List of completions; may contain only one entry if no more are requested (see parameter n).
+        num_tokens_prompt_total:
+            Number of prompt tokens, combined across all completion tasks.
+            In particular, if you set best_of or n to a number larger than 1 then we report the
+            combined prompt token count for all best_of or n tasks.
+        num_tokens_generated:
+            Number of generated tokens, combined across all completion tasks.
+            If multiple completions are returned or best_of is set to a value greater than 1 then
+            this value contains the combined generated token count.
+        optimized_prompt:
+            Describes the prompt after optimizations. This field is only returned if the
+            `disable_optimizations` flag is not set and the prompt has actually changed.
+    """
+
     model_version: str
     completions: Sequence[CompletionResult]
+    num_tokens_prompt_total: int
+    num_tokens_generated: int
     optimized_prompt: Optional[Prompt] = None
 
     @staticmethod
@@ -259,6 +282,8 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
             completions=[
                 CompletionResult.from_json(item) for item in json["completions"]
             ],
+            num_tokens_prompt_total=json["num_tokens_prompt_total"],
+            num_tokens_generated=json["num_tokens_generated"],
             optimized_prompt=Prompt.from_json(optimized_prompt_json)
             if optimized_prompt_json
             else None,
diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py
index 7039708..ba7be38 100644
--- a/aleph_alpha_client/version.py
+++ b/aleph_alpha_client/version.py
@@ -1 +1 @@
-__version__ = "4.1.0"
+__version__ = "5.0.0"
diff --git a/tests/test_clients.py b/tests/test_clients.py
index 3327510..10da424 100644
--- a/tests/test_clients.py
+++ b/tests/test_clients.py
@@ -33,7 +33,9 @@ def test_nice_flag_on_client(httpserver: HTTPServer):
     ).respond_with_json(
         CompletionResponse(
             "model_version",
             [CompletionResult(log_probs=[], completion="foo")],
+            num_tokens_prompt_total=2,
+            num_tokens_generated=1,
         ).to_json()
     )
 
@@ -47,11 +49,14 @@ async def test_nice_flag_on_async_client(httpserver: HTTPServer):
     httpserver.expect_request("/version").respond_with_data("OK")
 
     httpserver.expect_request(
-        "/complete", query_string={"nice": "true"}
+        "/complete",
+        query_string={"nice": "true"},
     ).respond_with_json(
         CompletionResponse(
             "model_version",
             [CompletionResult(log_probs=[], completion="foo")],
+            num_tokens_prompt_total=2,
+            num_tokens_generated=1,
         ).to_json()
     )
 
diff --git a/tests/test_complete.py b/tests/test_complete.py
index 863a9fe..a248ced 100644
--- a/tests/test_complete.py
+++ b/tests/test_complete.py
@@ -127,3 +127,34 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
     assert len(completion_result.completion_tokens) > 0
     assert completion_result.log_probs is not None
     assert len(completion_result.log_probs) > 0
+
+@pytest.mark.system_test
+def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: str):
+    tokens = [49222, 2998]  # Hello world
+    best_of = 2
+    request = CompletionRequest(
+        prompt=Prompt.from_tokens(tokens),
+        best_of=best_of,
+        maximum_tokens=1,
+    )
+
+    response = sync_client.complete(request, model=model_name)
+    assert response.num_tokens_prompt_total == len(tokens) * best_of
+
+@pytest.mark.system_test
+def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str):
+    hello_world = [49222, 2998]  # Hello world
+    best_of = 2
+    request = CompletionRequest(
+        prompt=Prompt.from_tokens(hello_world),
+        best_of=best_of,
+        maximum_tokens=1,
+        tokens=True,
+    )
+
+    response = sync_client.complete(request, model=model_name)
+    completion_result = response.completions[0]
+    assert completion_result.completion_tokens is not None
+    number_tokens_completion = len(completion_result.completion_tokens)
+
+    assert response.num_tokens_generated == best_of * number_tokens_completion
\ No newline at end of file
diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py
index 834d204..485f27e 100644
--- a/tests/test_error_handling.py
+++ b/tests/test_error_handling.py
@@ -111,8 +111,8 @@ def expect_retryable_error(
 
 def expect_valid_completion(httpserver: HTTPServer) -> None:
     httpserver.expect_ordered_request("/complete").respond_with_json(
-        {"model_version": "1", "completions": []}
+        {"model_version": "1", "completions": [], "num_tokens_prompt_total": 0, "num_tokens_generated": 0}
     )
 
 
 def expect_valid_version(httpserver: HTTPServer) -> None:
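Reviewer note, not part of the patch: a minimal sketch of how callers would read the two new mandatory fields. It assumes a valid API token and uses "luminous-base" as a placeholder model name; everything else (Client, CompletionRequest, Prompt.from_text, complete) is the package's existing public API as used in the tests above.

from aleph_alpha_client import Client, CompletionRequest, Prompt

# Token value and model name below are illustrative placeholders.
client = Client(token="AA_TOKEN")
request = CompletionRequest(
    prompt=Prompt.from_text("Hello world"),
    maximum_tokens=16,
)
response = client.complete(request, model="luminous-base")

# With this release the fields are plain ints and always present on
# CompletionResponse (requires HTTP API 1.14.0 or higher).
print(response.num_tokens_prompt_total)  # prompt tokens summed over all tasks
print(response.num_tokens_generated)     # generated tokens summed over all completions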