Add num_tokens_generated to completion response
martinreinhardt01 committed Dec 5, 2023
1 parent 6e7571a commit 9bd7cc9
Showing 2 changed files with 52 additions and 2 deletions.
12 changes: 11 additions & 1 deletion aleph_alpha_client/completion.py
@@ -255,6 +255,14 @@ class CompletionResponse:
Model name and version (if any) of the used model for inference.
completions:
List of completions; may contain only one entry if no more are requested (see parameter n).
num_tokens_prompt_total:
Number of prompt tokens combined across all completion tasks.
In particular, if you set best_of or n to a number larger than 1, then we report the
combined prompt token count for all best_of or n tasks.
num_tokens_generated:
Number of generated tokens combined across all completion tasks.
If multiple completions are returned or best_of is set to a value greater than 1, then
this value contains the combined generated token count.
optimized_prompt:
Describes the prompt after optimizations. This field is only returned if the
`disable_optimizations` flag is not set and the prompt has actually changed.
@@ -263,6 +271,7 @@ class CompletionResponse:
model_version: str
completions: Sequence[CompletionResult]
num_tokens_prompt_total: int
num_tokens_generated: int
optimized_prompt: Optional[Prompt] = None

@staticmethod
@@ -274,10 +283,11 @@ def from_json(json: Dict[str, Any]) -> "CompletionResponse":
completions=[
CompletionResult.from_json(item) for item in json["completions"]
],
num_tokens_prompt_total=json["num_tokens_prompt_total"],
num_tokens_generated=json["num_tokens_generated"],
optimized_prompt=Prompt.from_json(optimized_prompt_json)
if optimized_prompt_json
else None,
num_tokens_prompt_total=json["num_tokens_prompt_total"],
)

def to_json(self) -> Mapping[str, Any]:
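Both new fields are aggregate counts: a caller that fans out a request with n or best_of greater than 1 reads one combined prompt token count and one combined generated token count from the response. Below is a minimal usage sketch of the new attribute; the token handling via the AA_TOKEN environment variable, the prompt text, and the maximum_tokens value are illustrative assumptions, not part of this commit.

import os

from aleph_alpha_client import Client, CompletionRequest, Prompt

# Assumes a valid API token in the AA_TOKEN environment variable.
client = Client(token=os.environ["AA_TOKEN"])

# Two completions are requested (n=2), so both counters aggregate over two tasks.
request = CompletionRequest(
    prompt=Prompt.from_text("Hello world"),
    maximum_tokens=1,
    n=2,
)
response = client.complete(request, model="luminous-base")

print(response.num_tokens_prompt_total)  # prompt tokens summed over both tasks
print(response.num_tokens_generated)     # generated tokens summed over both tasks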
42 changes: 41 additions & 1 deletion tests/test_complete.py
@@ -129,7 +129,7 @@ def test_complete_with_echo(sync_client: Client, model_name: str, prompt_image:
assert len(completion_result.log_probs) > 0

@pytest.mark.system_test
def test_num_tokes_prompt_total_with_best_of(sync_client: Client, model_name: str):
def test_num_tokens_prompt_total_with_best_of(sync_client: Client, model_name: str):
tokens = [49222, 2998] # Hello world
best_of = 2
request = CompletionRequest(
@@ -140,3 +140,43 @@ def test_num_tokes_prompt_total_with_best_of(sync_client: Client, model_name: st

response = sync_client.complete(request, model=model_name)
assert response.num_tokens_prompt_total == len(tokens) * best_of

"""
curl https://api.aleph-alpha.com/complete -X POST -H "Authorization: Bearer $AA_API_TOKEN" -H "Content-Type: application/json" \
-d '{ "model": "luminous-base", "prompt": [{ "type": "text", "data": "Hello world"}], "maximum_tokens": 1, "n": 2, "tokens": true }'
{"completions":
[
{
"completion":"!",
"raw_completion":"!",
"completion_tokens":["!"],
"finish_reason":"maximum_tokens"
},
{
"completion":"!",
"raw_completion":"!",
"completion_tokens":["!"],
"finish_reason":"maximum_tokens"
}
],
"model_version":"2022-04",
"num_tokens_prompt_total":4,
"num_tokens_generated":2}
"""

@pytest.mark.system_test
def test_num_tokens_generated_with_best_of(sync_client: Client, model_name: str):
hello_world = [49222, 2998] # Hello world
best_of = 2
request = CompletionRequest(
prompt=Prompt.from_tokens(hello_world),
best_of=best_of,
maximum_tokens=1,
tokens=True,
)

response = sync_client.complete(request, model=model_name)
completion_result = response.completions[0]
number_tokens_completion = len(completion_result.completion_tokens)

assert response.num_tokens_generated == best_of * number_tokens_completion
