feat: add chat endpoint to sync and async client #187

Merged 2 commits on Oct 30, 2024
109 changes: 107 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
@@ -36,9 +36,9 @@
CompletionRequest,
CompletionResponse,
CompletionResponseStreamItem,
StreamChunk,
stream_item_from_json,
)
from aleph_alpha_client.chat import ChatRequest, ChatResponse, ChatStreamChunk, Usage, stream_chat_item_from_json
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
@@ -99,6 +99,7 @@ def _check_api_version(version_str: str):

AnyRequest = Union[
CompletionRequest,
ChatRequest,
EmbeddingRequest,
EvaluationRequest,
TokenizationRequest,
@@ -302,6 +303,34 @@ def complete(
response = self._post_request("complete", request, model)
return CompletionResponse.from_json(response)

def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.

Parameters:
request (ChatRequest, required):
Parameters for the requested chat.

model (string, required):
Name of the model to use. A model name refers to a model architecture (the number of parameters, among other properties).
The latest version of the model is always used.

Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = client.chat(request, model=model_name)
"""
response = self._post_request("chat/completions", request, model)
return ChatResponse.from_json(response)

def tokenize(
self,
request: TokenizationRequest,
@@ -797,7 +826,11 @@ async def _post_request_with_streaming(
f"Stream item did not start with `{self.SSE_DATA_PREFIX}`. Was `{stream_item_as_str}`"
)

yield json.loads(stream_item_as_str[len(self.SSE_DATA_PREFIX) :])
payload = stream_item_as_str[len(self.SSE_DATA_PREFIX) :]
if payload == "[DONE]":
continue

yield json.loads(payload)

def _build_query_parameters(self) -> Mapping[str, str]:
return {
@@ -864,6 +897,38 @@ async def complete(
)
return CompletionResponse.from_json(response)

async def chat(
self,
request: ChatRequest,
model: str,
) -> ChatResponse:
"""Chat with a model.

Parameters:
request (ChatRequest, required):
Parameters for the requested chat.

model (string, required):
Name of the model to use. A model name refers to a model architecture (the number of parameters, among other properties).
The latest version of the model is always used.

Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = await client.chat(request, model=model_name)
"""
response = await self._post_request(
"chat/completions",
request,
model,
)
return ChatResponse.from_json(response)

async def complete_with_streaming(
self,
request: CompletionRequest,
@@ -905,6 +970,46 @@ async def complete_with_streaming(
):
yield stream_item_from_json(stream_item_json)

async def chat_with_streaming(
self,
request: ChatRequest,
model: str,
) -> AsyncGenerator[Union[ChatStreamChunk, Usage], None]:
"""Generates streamed chat completions.

The first yielded chunk contains the role, while subsequent chunks only contain the content delta.

Parameters:
request (ChatRequest, required):
Parameters for the requested chat.

model (string, required):
Name of the model to use. A model name refers to a model architecture (the number of parameters, among other properties).
The latest version of the model is always used.

Examples:
>>> # create a chat request
>>> request = ChatRequest(
messages=[Message(role="user", content="Hello, how are you?")],
model=model_name,
)
>>>
>>> # chat with the model
>>> result = client.chat_with_streaming(request, model=model_name)
>>>
>>> # consume the chat stream
>>> async for stream_item in result:
>>> do_something_with(stream_item)
"""
async for stream_item_json in self._post_request_with_streaming(
"chat/completions",
request,
model,
):
chunk = stream_chat_item_from_json(stream_item_json)
if chunk is not None:
yield chunk

async def tokenize(
self,
request: TokenizationRequest,
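Taken together, the new sync and async entry points could be exercised roughly as in the sketch below. This is an illustrative example rather than part of the diff: the token is a placeholder, the model name mirrors the test fixture added in tests/common.py, and the Client/AsyncClient construction and async-context usage follow the existing completion examples in this library.

import asyncio

from aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.chat import (
    ChatRequest,
    ChatStreamChunk,
    Message,
    Role,
    StreamOptions,
    Usage,
)

model_name = "llama-3.1-70b-instruct"  # chat-capable model, see tests/common.py
messages = [Message(role=Role.User, content="Hello, how are you?")]

# Synchronous, non-streamed chat.
client = Client(token="AA_TOKEN")  # placeholder token
response = client.chat(ChatRequest(model=model_name, messages=messages), model=model_name)
print(response.message.content)


# Asynchronous, streamed chat; include_usage requests a final usage-only chunk.
async def stream_chat() -> None:
    request = ChatRequest(
        model=model_name,
        messages=messages,
        stream_options=StreamOptions(include_usage=True),
    )
    async with AsyncClient(token="AA_TOKEN") as async_client:  # placeholder token
        async for item in async_client.chat_with_streaming(request, model=model_name):
            if isinstance(item, ChatStreamChunk):
                print(item.content, end="")
            elif isinstance(item, Usage):
                print(f"\n(total tokens: {item.total_tokens})")


asyncio.run(stream_chat())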
153 changes: 153 additions & 0 deletions aleph_alpha_client/chat.py
@@ -0,0 +1,153 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Mapping, Any, Dict, Union
from enum import Enum


class Role(str, Enum):
"""A role used for a message in a chat."""
User = "user"
Assistant = "assistant"
System = "system"


@dataclass(frozen=True)
class Message:
"""
Describes a message in a chat.

Parameters:
role (Role, required):
The role of the message.

content (str, required):
The content of the message.
"""
role: Role
content: str

def to_json(self) -> Mapping[str, Any]:
return asdict(self)

@staticmethod
def from_json(json: Dict[str, Any]) -> "Message":
return Message(
role=Role(json["role"]),
content=json["content"],
)


@dataclass(frozen=True)
class StreamOptions:
"""
Additional options to affect the streaming behavior.
"""
# If set, an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire
# request, and the choices field will always be an empty array.
include_usage: bool


@dataclass(frozen=True)
class ChatRequest:
"""
Describes a chat request.

Only supports a subset of the parameters of `CompletionRequest` for simplicity.
See `CompletionRequest` for documentation on the parameters.
"""
model: str
messages: List[Message]
maximum_tokens: Optional[int] = None
temperature: float = 0.0
top_k: int = 0
top_p: float = 0.0
stream_options: Optional[StreamOptions] = None

def to_json(self) -> Mapping[str, Any]:
payload = {k: v for k, v in asdict(self).items() if v is not None}
payload["messages"] = [message.to_json() for message in self.messages]
return payload


@dataclass(frozen=True)
class ChatResponse:
"""
A simplified version of the chat response.

Because `ChatRequest` does not support the `n` parameter (which allows multiple completions to be returned),
`ChatResponse` assumes there is exactly one choice.
"""
finish_reason: str
message: Message

@staticmethod
def from_json(json: Dict[str, Any]) -> "ChatResponse":
first_choice = json["choices"][0]
return ChatResponse(
finish_reason=first_choice["finish_reason"],
message=Message.from_json(first_choice["message"]),
)



@dataclass(frozen=True)
class Usage:
"""
Usage statistics for the completion request.

When streaming is enabled, this field will be null by default.
To include an additional usage-only message in the response stream, set stream_options.include_usage to true.
"""
# Number of tokens in the generated completion.
completion_tokens: int

# Number of tokens in the prompt.
prompt_tokens: int

# Total number of tokens used in the request (prompt + completion).
total_tokens: int

@staticmethod
def from_json(json: Dict[str, Any]) -> "Usage":
return Usage(
completion_tokens=json["completion_tokens"],
prompt_tokens=json["prompt_tokens"],
total_tokens=json["total_tokens"]
)



@dataclass(frozen=True)
class ChatStreamChunk:
"""
A streamed chat completion chunk.

Parameters:
content (str, required):
The content of the current chat completion. Will be empty for the first chunk of every completion stream and non-empty for the remaining chunks.

role (Role, optional):
The role of the current chat completion. Will be assistant for the first chunk of every completion stream and missing for the remaining chunks.
"""
content: str
role: Optional[Role]

@staticmethod
def from_json(json: Dict[str, Any]) -> Optional["ChatStreamChunk"]:
"""
Returns a ChatStreamChunk if the chunk contains a message, otherwise None.
"""
if not (delta := json["choices"][0]["delta"]):
return None

return ChatStreamChunk(
content=delta["content"],
role=Role(delta.get("role")) if delta.get("role") else None,
)


def stream_chat_item_from_json(json: Dict[str, Any]) -> Union[Usage, ChatStreamChunk, None]:
if (usage := json.get("usage")) is not None:
return Usage.from_json(usage)

return ChatStreamChunk.from_json(json)
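To make the stream handling above concrete, here is a small hand-constructed sketch of how stream_chat_item_from_json classifies payloads; the JSON shapes follow the OpenAI-style chat/completions stream format this module assumes.

from aleph_alpha_client.chat import ChatStreamChunk, Role, Usage, stream_chat_item_from_json

# First chunk of a stream: the delta carries the role and an empty content string.
first = stream_chat_item_from_json({"choices": [{"delta": {"role": "assistant", "content": ""}}]})
assert isinstance(first, ChatStreamChunk) and first.role == Role.Assistant

# Later chunks: content only, no role.
middle = stream_chat_item_from_json({"choices": [{"delta": {"content": "Hello"}}]})
assert isinstance(middle, ChatStreamChunk) and middle.role is None

# An empty delta (e.g. a chunk that only carries a finish_reason) maps to None and is
# filtered out by AsyncClient.chat_with_streaming.
last = stream_chat_item_from_json({"choices": [{"delta": {}, "finish_reason": "stop"}]})
assert last is None

# Usage-only chunk, sent before [DONE] when StreamOptions(include_usage=True) was requested.
usage = stream_chat_item_from_json(
    {"choices": [], "usage": {"completion_tokens": 7, "prompt_tokens": 5, "total_tokens": 12}}
)
assert isinstance(usage, Usage) and usage.total_tokens == 12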
5 changes: 5 additions & 0 deletions tests/common.py
@@ -32,6 +32,11 @@ def model_name() -> str:
return "luminous-base"


@pytest.fixture(scope="session")
def chat_model_name() -> str:
return "llama-3.1-70b-instruct"


@pytest.fixture(scope="session")
def prompt_image() -> Image:
image_source_path = Path(__file__).parent / "dog-and-cat-cover.jpg"
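A test exercising the new fixture and endpoint could look roughly like the following. This is a hypothetical sketch, not part of the diff; it assumes a sync_client fixture of the kind the existing completion tests in this repository use.

from aleph_alpha_client import Client
from aleph_alpha_client.chat import ChatRequest, Message, Role


def test_can_chat(sync_client: Client, chat_model_name: str):
    # sync_client is assumed to be an existing session-scoped fixture returning a Client.
    request = ChatRequest(
        model=chat_model_name,
        messages=[Message(role=Role.User, content="Hello, how are you?")],
    )
    response = sync_client.chat(request, model=chat_model_name)
    assert response.message.role == Role.Assistant
    assert response.message.content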