Skip to content

Commit

Permalink
Add streaming support
Browse files Browse the repository at this point in the history
  • Loading branch information
WoytenAA committed Sep 30, 2024
1 parent 8eecd0b commit a4b3466
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 3 deletions.
74 changes: 72 additions & 2 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import warnings

from packaging import version
from tokenizers import Tokenizer # type: ignore
from types import TracebackType
from typing import (
Any,
AsyncGenerator,
List,
Mapping,
Optional,
Expand All @@ -30,7 +32,13 @@
)
from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
from aleph_alpha_client.qa import QaRequest, QaResponse
from aleph_alpha_client.completion import CompletionRequest, CompletionResponse
from aleph_alpha_client.completion import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamItem,
StreamChunk,
stream_item_from_json,
)
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.detokenization import (
Expand Down Expand Up @@ -759,6 +767,31 @@ async def _post_request(
_raise_for_status(response.status, await response.text())
return await response.json()

async def _post_request_with_streaming(
    self,
    endpoint: str,
    request: AnyRequest,
    model: Optional[str] = None,
) -> AsyncGenerator[Dict[str, Any], None]:
    """POST ``request`` to ``endpoint`` and yield the decoded JSON payload of
    each server-sent event (SSE) ``data:`` line as it arrives.

    Parameters:
        endpoint (str, required):
            API path appended to ``self.host`` (e.g. "complete").
        request (AnyRequest, required):
            Request object serialized via ``_build_json_body``.
        model (Optional[str]):
            Model name merged into the JSON body, if given.

    Raises:
        ValueError: if a non-empty stream line does not start with "data: ".
    """
    json_body = self._build_json_body(request, model)
    # NOTE(review): the string "true" (not a JSON boolean) is sent here —
    # presumably what the API expects; confirm against the HTTP API docs
    # before changing.
    json_body["stream"] = "true"

    query_params = self._build_query_parameters()

    async with self.session.post(
        self.host + endpoint, json=json_body, params=query_params
    ) as response:
        if not response.ok:
            _raise_for_status(response.status, await response.text())

        async for stream_item in response.content:
            stream_item_as_str = stream_item.decode().strip()
            if not stream_item_as_str:
                # SSE frames are separated by blank lines; skip them instead
                # of treating them as malformed (previously raised ValueError,
                # aborting the stream on the first event separator).
                continue
            if stream_item_as_str.startswith("data: "):
                stream_item_json = stream_item_as_str[len("data: ") :]
                yield json.loads(stream_item_json)
            else:
                raise ValueError("Stream item did not start with 'data: '")

def _build_query_parameters(self) -> Mapping[str, str]:
return {
# cannot use str() here because we want lowercase true/false in query string
Expand All @@ -768,7 +801,7 @@ def _build_query_parameters(self) -> Mapping[str, str]:

def _build_json_body(
self, request: AnyRequest, model: Optional[str]
) -> Mapping[str, Any]:
) -> Dict[str, Any]:
json_body = dict(request.to_json())

if model is not None:
Expand Down Expand Up @@ -824,6 +857,43 @@ async def complete(
)
return CompletionResponse.from_json(response)

async def complete_with_streaming(
    self,
    request: CompletionRequest,
    model: str,
) -> AsyncGenerator[CompletionResponseStreamItem, None]:
    """Generate completions for a prompt, yielding stream items as the
    API produces them.

    Parameters:
        request (CompletionRequest, required):
            Parameters for the requested completion.
        model (string, required):
            Name of model to use. A model name refers to a model architecture
            (number of parameters among others). Always the latest version of
            model is used.

    Examples:
        >>> # create a prompt
        >>> prompt = Prompt.from_text("An apple a day, ")
        >>>
        >>> # create a completion request
        >>> request = CompletionRequest(
                prompt=prompt,
                maximum_tokens=32,
                stop_sequences=["###","\\n"],
                temperature=0.12
            )
        >>>
        >>> # complete the prompt
        >>> result = await client.complete_with_streaming(request, model=model_name)
    """
    raw_items = self._post_request_with_streaming(
        "complete",
        request,
        model,
    )
    async for raw_item in raw_items:
        yield stream_item_from_json(raw_item)

async def tokenize(
self,
request: TokenizationRequest,
Expand Down
35 changes: 34 additions & 1 deletion aleph_alpha_client/completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Mapping, Optional, Sequence
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from aleph_alpha_client.prompt import Prompt

Expand Down Expand Up @@ -301,3 +301,36 @@ def to_json(self) -> Mapping[str, Any]:

def _asdict(self) -> Mapping[str, Any]:
    # Recursive plain-dict view of this dataclass (dataclasses.asdict).
    return asdict(self)


CompletionResponseStreamItem = Union["StreamChunk"]


def stream_item_from_json(json: Dict[str, Any]) -> CompletionResponseStreamItem:
    """Deserialize one decoded stream payload into a stream item.

    Dispatches on the payload's "type" discriminator field; raises
    ValueError for any unrecognized type.
    """
    item_type = json["type"]
    if item_type == "stream_chunk":
        return StreamChunk.from_json(json)
    raise ValueError(f"Unknown stream item type: {item_type}")


@dataclass(frozen=True)
class StreamChunk:
    """One incremental piece of a streamed completion response.

    ``index`` identifies which requested completion the chunk belongs to;
    the remaining fields carry per-chunk data and are None when the
    corresponding key is absent from the payload.
    """

    index: int
    log_probs: Optional[Sequence[Mapping[str, Optional[float]]]]
    completion: Optional[str]
    raw_completion: Optional[str]
    completion_tokens: Optional[Sequence[str]]

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "StreamChunk":
        """Build a StreamChunk from a decoded JSON payload.

        Optional fields missing from the payload default to None.
        """
        optional_keys = (
            "log_probs",
            "completion",
            "raw_completion",
            "completion_tokens",
        )
        optional_values = {key: json.get(key) for key in optional_keys}
        return StreamChunk(index=json["index"], **optional_values)

    def to_json(self) -> Mapping[str, Any]:
        """Plain-dict representation containing every dataclass field."""
        return asdict(self)

0 comments on commit a4b3466

Please sign in to comment.