-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add chat endpoint to sync and async client
While this commit supports streaming for the chat endpoint, it offers only a simplified version and does not forward all parameters.
- Loading branch information
Showing
4 changed files
with
276 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from dataclasses import dataclass, asdict | ||
from typing import List, Optional, Mapping, Any, Dict | ||
from enum import Enum | ||
|
||
|
||
class Role(str, Enum):
    """The speaker a chat message is attributed to.

    Because ``str`` is mixed in, every member compares equal to its string
    value (for example ``Role.User == "user"``).
    """

    User = "user"
    Assistant = "assistant"
    System = "system"
|
||
|
||
@dataclass(frozen=True)
class Message:
    """
    A single message of a chat conversation.

    Parameters:
        role (Role, required):
            The role the message is attributed to.

        content (str, required):
            The text of the message.
    """

    role: Role
    content: str

    def to_json(self) -> Mapping[str, Any]:
        """Return the fields of this message as a dictionary keyed by field name."""
        serialized = asdict(self)
        return serialized

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "Message":
        """
        Build a message from a mapping holding ``role`` and ``content`` keys.

        A missing key raises ``KeyError``; a ``role`` value that is not a
        member of ``Role`` raises ``ValueError``.
        """
        role = Role(json["role"])
        content = json["content"]
        return Message(role=role, content=content)
|
||
|
||
@dataclass(frozen=True)
class ChatRequest:
    """
    The parameters of a chat request.

    For simplicity only a subset of the `CompletionRequest` parameters is supported.
    See `CompletionRequest` for documentation on the parameters.
    """

    model: str
    messages: List[Message]
    maximum_tokens: Optional[int] = None
    temperature: float = 0.0
    top_k: int = 0
    top_p: float = 0.0

    def to_json(self) -> Mapping[str, Any]:
        """Return the request as a mapping, leaving out every field set to ``None``."""
        body: Dict[str, Any] = {}
        for name, value in asdict(self).items():
            if value is None:
                continue
            body[name] = value

        # Messages are always serialized through their own ``to_json``,
        # replacing whatever ``asdict`` produced for them.
        body["messages"] = [msg.to_json() for msg in self.messages]
        return body
|
||
|
||
@dataclass(frozen=True)
class ChatResponse:
    """
    A simplified version of the chat response.

    `ChatRequest` has no `n` parameter (which would allow multiple return values),
    so `ChatResponse` only ever looks at a single choice.
    """

    finish_reason: str
    message: Message

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "ChatResponse":
        """
        Build a response from the first entry of ``json["choices"]``.

        Any further choices are ignored. The entry must provide the keys
        ``finish_reason`` and ``message``.
        """
        choice = json["choices"][0]
        finish_reason = choice["finish_reason"]
        message = Message.from_json(choice["message"])
        return ChatResponse(finish_reason=finish_reason, message=message)
|
||
|
||
@dataclass(frozen=True)
class ChatStreamChunk:
    """
    A streamed chat completion chunk.

    Parameters:
        content (str, required):
            The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks.

        role (Role, optional):
            The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks.
    """

    content: str
    role: Optional[Role]

    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
        """
        Returns a ChatStreamChunk if the chunk contains a message, otherwise None.

        None is returned when ``json["choices"]`` is empty or when the first
        choice carries an empty ``delta``. A delta without ``content`` (or with
        a null ``content``) yields a chunk whose ``content`` is the empty
        string, so ``content`` is always a ``str``.

        Raises:
            KeyError: if ``choices`` or the first choice's ``delta`` is missing.
            ValueError: if ``role`` is present but not a member of ``Role``.
        """
        choices = json["choices"]
        # A chunk may legitimately carry no choices at all; treat it like a
        # chunk without a message instead of failing with IndexError.
        if not choices:
            return None

        delta = choices[0]["delta"]
        if not delta:
            return None

        role = delta.get("role")
        return ChatStreamChunk(
            # Role-only deltas may omit the content or send it as null.
            content=delta.get("content") or "",
            role=Role(role) if role else None,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
|
||
from aleph_alpha_client import AsyncClient, Client | ||
from aleph_alpha_client.chat import ChatRequest, Message, Role | ||
from tests.common import async_client, sync_client, model_name, chat_model_name | ||
|
||
|
||
@pytest.mark.system_test
async def test_can_not_chat_with_all_models(async_client: AsyncClient, model_name: str):
    """Chatting with the model named by ``model_name`` is expected to raise ``ValueError``."""
    greeting = Message(role=Role.User, content="Hello, how are you?")
    chat_request = ChatRequest(model=model_name, messages=[greeting])

    with pytest.raises(ValueError):
        await async_client.chat(chat_request, model=model_name)
|
||
|
||
@pytest.mark.system_test
def test_can_chat_with_chat_model(sync_client: Client, chat_model_name: str):
    """The sync client answers a user message with an assistant message."""
    greeting = Message(role=Role.User, content="Hello, how are you?")
    chat_request = ChatRequest(model=chat_model_name, messages=[greeting])

    reply = sync_client.chat(chat_request, model=chat_model_name)

    assert reply.message.role == Role.Assistant
    assert reply.message.content is not None
|
||
|
||
@pytest.mark.system_test
async def test_can_chat_with_async_client(async_client: AsyncClient, chat_model_name: str):
    """The async client answers a system + user conversation with an assistant message."""
    conversation = [
        Message(role=Role.System, content="You are a helpful assistant."),
        Message(role=Role.User, content="Hello, how are you?"),
    ]
    chat_request = ChatRequest(model=chat_model_name, messages=conversation)

    reply = await async_client.chat(chat_request, model=chat_model_name)

    assert reply.message.role == Role.Assistant
    assert reply.message.content is not None
|
||
|
||
@pytest.mark.system_test
async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_model_name: str):
    """Streaming yields at least three items: the first has a role, the rest have content."""
    greeting = Message(role=Role.User, content="Hello, how are you?")
    chat_request = ChatRequest(model=chat_model_name, messages=[greeting])

    # Drain the whole stream before asserting on it.
    chunks = []
    async for chunk in async_client.chat_with_streaming(chat_request, model=chat_model_name):
        chunks.append(chunk)

    assert len(chunks) >= 3
    assert chunks[0].role is not None
    for later_chunk in chunks[1:]:
        assert later_chunk.content is not None