diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index 1aaace9..6c6bd9b 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -36,9 +36,9 @@ CompletionRequest, CompletionResponse, CompletionResponseStreamItem, - StreamChunk, stream_item_from_json, ) +from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, ChatStreamChunk from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse from aleph_alpha_client.detokenization import ( @@ -99,6 +99,7 @@ def _check_api_version(version_str: str): AnyRequest = Union[ CompletionRequest, + ChatRequest, EmbeddingRequest, EvaluationRequest, TokenizationRequest, @@ -302,6 +303,34 @@ def complete( response = self._post_request("complete", request, model) return CompletionResponse.from_json(response) + def chat( + self, + request: ChatRequest, + model: str, + ) -> ChatResponse: + """Chat with a model. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = client.chat(request, model=model_name) + """ + response = self._post_request("chat/completions", request, model) + return ChatResponse.from_json(response) + def tokenize( self, request: TokenizationRequest, @@ -797,7 +826,11 @@ async def _post_request_with_streaming( f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}" ) - yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :]) + payload = stream_item_as_str[len(self.SSE_DATA_PREFIX) :] + if payload == "[DONE]": + continue + + yield json.loads(payload) def _build_query_parameters(self) -> Mapping[str, str]: return { @@ -864,6 +897,38 @@ async def complete( ) return CompletionResponse.from_json(response) + async def chat( + self, + request: ChatRequest, + model: str, + ) -> ChatResponse: + """Chat with a model. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = await client.chat(request, model=model_name) + """ + response = await self._post_request( + "chat/completions", + request, + model, + ) + return ChatResponse.from_json(response) + async def complete_with_streaming( self, request: CompletionRequest, @@ -905,6 +970,46 @@ async def complete_with_streaming( ): yield stream_item_from_json(stream_item_json) + async def chat_with_streaming( + self, + request: ChatRequest, + model: str, + ) -> AsyncGenerator[ChatStreamChunk, None]: + """Generates streamed chat completions. + + The first yielded chunk contains the role, while subsequent chunks only contain the content delta. + + Parameters: + request (ChatRequest, required): + Parameters for the requested chat. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a chat request + >>> request = ChatRequest( + messages=[Message(role="user", content="Hello, how are you?")], + model=model, + ) + >>> + >>> # chat with the model + >>> result = await client.chat_with_streaming(request, model=model_name) + >>> + >>> # consume the chat stream + >>> async for stream_item in result: + >>> do_something_with(stream_item) + """ + async for stream_item_json in self._post_request_with_streaming( + "chat/completions", + request, + model, + ): + chunk = ChatStreamChunk.from_json(stream_item_json) + if chunk is not None: + yield chunk + async def tokenize( self, request: TokenizationRequest, diff --git a/aleph_alpha_client/chat.py b/aleph_alpha_client/chat.py new file mode 100644 index 0000000..8105140 --- /dev/null +++ b/aleph_alpha_client/chat.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, asdict +from typing import List, Optional, Mapping, Any, Dict +from enum import Enum + + +class Role(str, Enum): + """A role used for a message in a chat.""" + User = "user" + Assistant = "assistant" + System = "system" + + +@dataclass(frozen=True) +class Message: + """ + Describes a message in a chat. + + Parameters: + role (Role, required): + The role of the message. + + content (str, required): + The content of the message. + """ + role: Role + content: str + + def to_json(self) -> Mapping[str, Any]: + return asdict(self) + + @staticmethod + def from_json(json: Dict[str, Any]) -> "Message": + return Message( + role=Role(json["role"]), + content=json["content"], + ) + + +@dataclass(frozen=True) +class ChatRequest: + """ + Describes a chat request. + + Only supports a subset of the parameters of `CompletionRequest` for simplicity. + See `CompletionRequest` for documentation on the parameters. + """ + model: str + messages: List[Message] + maximum_tokens: Optional[int] = None + temperature: float = 0.0 + top_k: int = 0 + top_p: float = 0.0 + + def to_json(self) -> Mapping[str, Any]: + payload = {k: v for k, v in asdict(self).items() if v is not None} + payload["messages"] = [message.to_json() for message in self.messages] + return payload + + +@dataclass(frozen=True) +class ChatResponse: + """ + A simplified version of the chat response. + + As the `ChatRequest` does not support the `n` parameter (allowing for multiple return values), + the `ChatResponse` assumes there to be only one choice. + """ + finish_reason: str + message: Message + + @staticmethod + def from_json(json: Dict[str, Any]) -> "ChatResponse": + first_choice = json["choices"][0] + return ChatResponse( + finish_reason=first_choice["finish_reason"], + message=Message.from_json(first_choice["message"]), + ) + + +@dataclass(frozen=True) +class ChatStreamChunk: + """ + A streamed chat completion chunk. + + Parameters: + content (str, required): + The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks. + + role (Role, optional): + The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks. + """ + content: str + role: Optional[Role] + + @staticmethod + def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]: + """ + Returns a ChatStreamChunk if the chunk contains a message, otherwise None. + """ + if not (delta := json["choices"][0]["delta"]): + return None + + return ChatStreamChunk( + content=delta["content"], + role=Role(delta.get("role")) if delta.get("role") else None, + ) \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index 6a30287..9390638 100644 --- a/tests/common.py +++ b/tests/common.py @@ -32,6 +32,11 @@ def model_name() -> str: return "luminous-base" +@pytest.fixture(scope="session") +def chat_model_name() -> str: + return "llama-3.1-70b-instruct" + + @pytest.fixture(scope="session") def prompt_image() -> Image: image_source_path = Path(__file__).parent / "dog-and-cat-cover.jpg" diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..0bec74a --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,58 @@ +import pytest + +from aleph_alpha_client import AsyncClient, Client +from aleph_alpha_client.chat import ChatRequest, Message, Role +from tests.common import async_client, sync_client, model_name, chat_model_name + + +@pytest.mark.system_test +async def test_can_not_chat_with_all_models(async_client: AsyncClient, model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=model_name, + ) + + with pytest.raises(ValueError): + await async_client.chat(request, model=model_name) + + +@pytest.mark.system_test +def test_can_chat_with_chat_model(sync_client: Client, chat_model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=chat_model_name, + ) + + response = sync_client.chat(request, model=chat_model_name) + assert response.message.role == Role.Assistant + assert response.message.content is not None + + +@pytest.mark.system_test +async def test_can_chat_with_async_client(async_client: AsyncClient, chat_model_name: str): + system_msg = Message(role=Role.System, content="You are a helpful assistant.") + user_msg = Message(role=Role.User, content="Hello, how are you?") + request = ChatRequest( + messages=[system_msg, user_msg], + model=chat_model_name, + ) + + response = await async_client.chat(request, model=chat_model_name) + assert response.message.role == Role.Assistant + assert response.message.content is not None + + +@pytest.mark.system_test +async def test_can_chat_with_streaming_support(async_client: AsyncClient, chat_model_name: str): + request = ChatRequest( + messages=[Message(role=Role.User, content="Hello, how are you?")], + model=chat_model_name, + ) + + stream_items = [ + stream_item async for stream_item in async_client.chat_with_streaming(request, model=chat_model_name) + ] + + assert len(stream_items) >= 3 + assert stream_items[0].role is not None + assert all(item.content is not None for item in stream_items[1:])