Skip to content

Commit

Permalink
feat: enable include_usage flag for chat stream endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
moldhouse committed Oct 28, 2024
1 parent 0ec0501 commit e1eab3d
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 8 deletions.
6 changes: 3 additions & 3 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
CompletionResponseStreamItem,
stream_item_from_json,
)
from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk
from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk, Usage, stream_chat_item_from_json
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
Expand Down Expand Up @@ -974,7 +974,7 @@ async def chat_with_streaming(
self,
request: ChatRequest,
model: str,
) -> AsyncGenerator[ChatStreamChunk, None]:
) -> AsyncGenerator[Union[ChatStreamChunk, Usage], None]:
"""Generates streamed chat completions.
The first yielded chunk contains the role, while subsequent chunks only contain the content delta.
Expand Down Expand Up @@ -1006,7 +1006,7 @@ async def chat_with_streaming(
request,
model,
):
chunk = ChatStreamChunk.from_json(stream_item_json)
chunk = stream_chat_item_from_json(stream_item_json)
if chunk is not None:
yield chunk

Expand Down
51 changes: 49 additions & 2 deletions aleph_alpha_client/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Mapping, Any, Dict
from typing import List, Optional, Mapping, Any, Dict, Union
from enum import Enum


Expand Down Expand Up @@ -36,6 +36,17 @@ def from_json(json: Dict[str, Any]) -> "Message":
)


@dataclass(frozen=True)
class StreamOptions:
    """
    Extra settings that influence how a chat response is streamed.

    Attributes:
        include_usage: When set, one additional chunk is streamed before the
            ``data: [DONE]`` message. The ``usage`` field on that chunk shows
            the token usage statistics for the entire request, and its
            ``choices`` field is always an empty array.
    """

    include_usage: bool


@dataclass(frozen=True)
class ChatRequest:
"""
Expand All @@ -50,6 +61,7 @@ class ChatRequest:
temperature: float = 0.0
top_k: int = 0
top_p: float = 0.0
stream_options: Optional[StreamOptions] = None

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in asdict(self).items() if v is not None}
Expand Down Expand Up @@ -77,6 +89,34 @@ def from_json(json: Dict[str, Any]) -> "ChatResponse":
)



@dataclass(frozen=True)
class Usage:
    """
    Token usage statistics for a completion request.

    In streaming mode the usage field is null by default. Setting
    ``stream_options.include_usage`` to true makes the response stream carry
    one additional usage-only message.

    Attributes:
        completion_tokens: Number of tokens in the generated completion.
        prompt_tokens: Number of tokens in the prompt.
        total_tokens: Total number of tokens used in the request
            (prompt + completion).
    """

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "Usage":
        """
        Build a ``Usage`` from the ``usage`` object of an API response.

        All three token counts are required; a missing key raises ``KeyError``.
        Any other keys in ``json`` are ignored.
        """
        # Keys are read in the same order as the fields are looked up in the
        # API payload, so the first missing key is the one reported.
        required_keys = ("completion_tokens", "prompt_tokens", "total_tokens")
        counts = {key: json[key] for key in required_keys}
        return Usage(**counts)



@dataclass(frozen=True)
class ChatStreamChunk:
"""
Expand All @@ -103,4 +143,11 @@ def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
return ChatStreamChunk(
content=delta["content"],
role=Role(delta.get("role")) if delta.get("role") else None,
)
)


def stream_chat_item_from_json(json: Dict[str, Any]) -> Union[Usage, ChatStreamChunk, None]:
    """
    Parse one item of a chat stream into its typed representation.

    An item whose ``usage`` entry is present and not null is turned into a
    ``Usage``. Every other item is handed to ``ChatStreamChunk.from_json``,
    whose result (possibly ``None``) is returned unchanged.
    """
    usage_payload = json.get("usage")
    if usage_payload is None:
        return ChatStreamChunk.from_json(json)

    return Usage.from_json(usage_payload)
82 changes: 79 additions & 3 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.chat import ChatRequest, Message, Role
from aleph_alpha_client.chat import ChatRequest, Message, Role, StreamOptions, stream_chat_item_from_json, Usage, ChatStreamChunk
from tests.common import async_client, sync_client, model_name, chat_model_name


Expand Down Expand Up @@ -50,5 +50,81 @@ async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_m
stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name)
]

assert stream_items[0].role is not None
assert all(item.content is not None for item in stream_items[1:])
first = stream_items[0]
assert isinstance(first, ChatStreamChunk) and first.role is not None
assert all(isinstance(item, ChatStreamChunk) and item.content is not None for item in stream_items[1:])


def test_usage_response_is_parsed():
    """A stream item that carries usage data and no choices parses into a Usage."""
    # This test awaits nothing, so it is a plain function: it runs regardless of
    # how the async test plugin is configured, like the sibling chunk-parsing test.

    # Given an API response with usage data and no choice
    data = {
        "choices": [],
        "created": 1730133402,
        "model": "llama-3.1-70b-instruct",
        "system_fingerprint": ".unknown.",
        "object": "chat.completion.chunk",
        "usage": {
            "prompt_tokens": 31,
            "completion_tokens": 88,
            "total_tokens": 119,
        },
    }

    # When parsing it
    result = stream_chat_item_from_json(data)

    # Then a usage instance is returned with every token count populated
    assert isinstance(result, Usage)
    assert result.prompt_tokens == 31
    assert result.completion_tokens == 88
    assert result.total_tokens == 119


def test_chunk_response_is_parsed():
    """A stream item whose usage is null parses into a ChatStreamChunk."""
    expected_content = " way, those clothes you're wearing"

    # Given an API response without usage data
    choice = {
        "finish_reason": None,
        "index": 0,
        "delta": {"content": expected_content},
        "logprobs": None,
    }
    payload = {
        "choices": [choice],
        "created": 1730133401,
        "model": "llama-3.1-70b-instruct",
        "system_fingerprint": None,
        "object": "chat.completion.chunk",
        "usage": None,
    }

    # When parsing it
    parsed = stream_chat_item_from_json(payload)

    # Then a ChatStreamChunk instance carrying the delta content is returned
    assert isinstance(parsed, ChatStreamChunk)
    assert parsed.content == expected_content



async def test_stream_options(async_client: AsyncClient, chat_model_name: str):
    """With include_usage set, the chat stream ends with a Usage item."""
    # Given a request with the include-usage option set
    stream_options = StreamOptions(include_usage=True)
    request = ChatRequest(
        messages=[Message(role=Role.User, content="Hello, how are you?")],
        model=chat_model_name,
        stream_options=stream_options,
    )

    # When receiving the chunks
    stream_items = [
        stream_item
        async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name)
    ]

    # Then every item but the last is a content chunk and the last one carries
    # the usage information.
    # Guard against an empty stream: `all(...)` over an empty slice is vacuously
    # true and `stream_items[-1]` would raise an unhelpful IndexError.
    assert stream_items
    assert all(isinstance(item, ChatStreamChunk) for item in stream_items[:-1])
    # Index the final element; a `[:-1]` slice is a list and never a Usage.
    assert isinstance(stream_items[-1], Usage)



0 comments on commit e1eab3d

Please sign in to comment.