From 9bd7cc9cf20c43a77a95ca7ce164324911209296 Mon Sep 17 00:00:00 2001
From: Martin Reinhardt
Date: Tue, 5 Dec 2023 14:37:44 +0100
Subject: [PATCH] Add num_tokens_generated to completion response

---
 aleph_alpha_client/completion.py | 12 ++++++++-
 tests/test_complete.py           | 42 +++++++++++++++++++++++++++++++-
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py
index f15813b..f72b6c9 100644
--- a/aleph_alpha_client/completion.py
+++ b/aleph_alpha_client/completion.py
@@ -255,6 +255,14 @@ class CompletionResponse:
             Model name and version (if any) of the used model for inference.
         completions:
             List of completions; may contain only one entry if no more are requested (see parameter n).
+        num_tokens_prompt_total:
+            Number of prompt tokens combined across all completion tasks.
+            In particular, if best_of or n is set to a value larger than 1, this is
+            the combined prompt token count of all best_of or n tasks.
+        num_tokens_generated:
+            Number of generated tokens combined across all completion tasks.
+            If multiple completions are returned or best_of is set to a value greater
+            than 1, this is the combined generated token count of all tasks.
         optimized_prompt:
             Describes prompt after optimizations. This field is only returned if the
             `disable_optimizations` flag is not set and the prompt has actually changed.
@@ -263,6 +271,7 @@
     model_version: str
     completions: Sequence[CompletionResult]
     num_tokens_prompt_total: int
+    num_tokens_generated: int
     optimized_prompt: Optional[Prompt] = None
 
     @staticmethod
@@ -274,10 +283,11 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
             completions=[
                 CompletionResult.from_json(item) for item in json["completions"]
             ],
+            num_tokens_prompt_total=json["num_tokens_prompt_total"],
+            num_tokens_generated=json["num_tokens_generated"],
             optimized_prompt=Prompt.from_json(optimized_prompt_json)
             if optimized_prompt_json
             else None,
-            num_tokens_prompt_total=json["num_tokens_prompt_total"],
         )
 
     def to_json(self) -> Mapping[str, Any]:
diff --git a/tests/test_complete.py b/tests/test_complete.py
index 5d72309..88a0dc3 100644
--- a/tests/test_complete.py
+++ b/tests/test_complete.py
@@ -129,7 +129,7 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
     assert len(completion_result.log_probs) > 0
 
 @pytest.mark.system_test
-def test_num_tokes_prompt_total_with_best_of(sync_client: Client, model_name: str):
+def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: str):
     tokens = [49222, 2998]  # Hello world
     best_of = 2
     request = CompletionRequest(
@@ -140,3 +140,43 @@ def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: st
 
     response = sync_client.complete(request, model=model_name)
     assert response.num_tokens_prompt_total == len(tokens) * best_of
+
+"""
+curl https://api.aleph-alpha.com/complete -X POST -H "Authorization: Bearer $AA_API_TOKEN" -H "Content-Type: application/json" \
+  -d '{ "model": "luminous-base", "prompt": [{ "type": "text", "data": "Hello world"}], "maximum_tokens": 1, "n": 2, "tokens": true }'
+{"completions":
+    [
+        {
+            "completion":"!",
+            "raw_completion":"!",
+            "completion_tokens":["!"],
+            "finish_reason":"maximum_tokens"
+        },
+        {
+            "completion":"!",
+            "raw_completion":"!",
+            "completion_tokens":["!"],
+            "finish_reason":"maximum_tokens"
+        }
+    ],
+    "model_version":"2022-04",
+    "num_tokens_prompt_total":4,
+    "num_tokens_generated":2}
+"""
+
+@pytest.mark.system_test
+def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str):
+    hello_world = [49222, 2998]  # Hello world
+    best_of = 2
+    request = CompletionRequest(
+        prompt=Prompt.from_tokens(hello_world),
+        best_of=best_of,
+        maximum_tokens=1,
+        tokens=True,
+    )
+
+    response = sync_client.complete(request, model=model_name)
+    completion_result = response.completions[0]
+    number_tokens_completion = len(completion_result.completion_tokens)
+
+    assert response.num_tokens_generated == best_of * number_tokens_completion
\ No newline at end of file
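
Usage sketch for the new field (illustrative, not part of the diff): assuming a
valid API token and access to the luminous-base model, a caller can read
num_tokens_generated alongside num_tokens_prompt_total. The prompt text,
maximum_tokens, and n values below are assumptions, not taken from the patch.

    from aleph_alpha_client import Client, CompletionRequest, Prompt

    # Sketch only: "AA_API_TOKEN" is a placeholder, not a real credential.
    client = Client(token="AA_API_TOKEN")
    request = CompletionRequest(
        prompt=Prompt.from_text("Hello world"),
        maximum_tokens=8,
        n=2,  # two completion tasks; generated tokens are summed across tasks
    )
    response = client.complete(request, model="luminous-base")

    # num_tokens_generated combines the tokens generated by all n (or best_of)
    # tasks, mirroring how num_tokens_prompt_total combines prompt tokens.
    print(response.num_tokens_prompt_total, response.num_tokens_generated)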