diff --git a/aleph_alpha_client/aleph_alpha_client.py b/aleph_alpha_client/aleph_alpha_client.py index a039e9b..884dd11 100644 --- a/aleph_alpha_client/aleph_alpha_client.py +++ b/aleph_alpha_client/aleph_alpha_client.py @@ -1,3 +1,4 @@ +import json import warnings from packaging import version @@ -5,6 +6,7 @@ from types import TracebackType from typing import ( Any, + AsyncGenerator, List, Mapping, Optional, @@ -30,7 +32,13 @@ ) from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse from aleph_alpha_client.qa import QaRequest, QaResponse -from aleph_alpha_client.completion import CompletionRequest, CompletionResponse +from aleph_alpha_client.completion import ( + CompletionRequest, + CompletionResponse, + CompletionResponseStreamItem, + StreamChunk, + stream_item_from_json, +) from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse from aleph_alpha_client.detokenization import ( @@ -759,6 +767,31 @@ async def _post_request( _raise_for_status(response.status, await response.text()) return await response.json() + async def _post_request_with_streaming( + self, + endpoint: str, + request: AnyRequest, + model: Optional[str] = None, + ) -> AsyncGenerator[Dict[str, Any], None]: + json_body = self._build_json_body(request, model) + json_body["stream"] = True + + query_params = self._build_query_parameters() + + async with self.session.post( + self.host + endpoint, json=json_body, params=query_params + ) as response: + if not response.ok: + _raise_for_status(response.status, await response.text()) + + async for stream_item in response.content: + stream_item_as_str = stream_item.decode() + if stream_item_as_str.startswith("data: "): + stream_item_json = stream_item_as_str[len("data: ") :] + yield json.loads(stream_item_json) + else: + raise ValueError("Stream item did not start with 'data: '") + def _build_query_parameters(self)
-> Mapping[str, str]: return { # cannot use str() here because we want lowercase true/false in query string @@ -768,7 +801,7 @@ def _build_query_parameters(self) -> Mapping[str, str]: def _build_json_body( self, request: AnyRequest, model: Optional[str] - ) -> Mapping[str, Any]: + ) -> Dict[str, Any]: json_body = dict(request.to_json()) if model is not None: @@ -824,6 +857,43 @@ async def complete( ) return CompletionResponse.from_json(response) + async def complete_with_streaming( + self, + request: CompletionRequest, + model: str, + ) -> AsyncGenerator[CompletionResponseStreamItem, None]: + """Generates streamed completions given a prompt. + + Parameters: + request (CompletionRequest, required): + Parameters for the requested completion. + + model (string, required): + Name of model to use. A model name refers to a model architecture (number of parameters among others). + Always the latest version of model is used. + + Examples: + >>> # create a prompt + >>> prompt = Prompt.from_text("An apple a day, ") + >>> + >>> # create a completion request + >>> request = CompletionRequest( + prompt=prompt, + maximum_tokens=32, + stop_sequences=["###","\\n"], + temperature=0.12 + ) + >>> + >>> # complete the prompt + >>> result = await client.complete_with_streaming(request, model=model_name) + """ + async for stream_item_json in self._post_request_with_streaming( + "complete", + request, + model, + ): + yield stream_item_from_json(stream_item_json) + async def tokenize( self, request: TokenizationRequest, diff --git a/aleph_alpha_client/completion.py b/aleph_alpha_client/completion.py index e49b92b..cd319b1 100644 --- a/aleph_alpha_client/completion.py +++ b/aleph_alpha_client/completion.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union from aleph_alpha_client.prompt import Prompt @@ -301,3 +301,36 @@ def to_json(self) -> 
Mapping[str, Any]: def _asdict(self) -> Mapping[str, Any]: return asdict(self) + + +CompletionResponseStreamItem = Union["StreamChunk"] + + +def stream_item_from_json(json: Dict[str, Any]) -> CompletionResponseStreamItem: + match json["type"]: + case "stream_chunk": + return StreamChunk.from_json(json) + case _: + raise ValueError(f"Unknown stream item type: {json['type']}") + + +@dataclass(frozen=True) +class StreamChunk: + index: int + log_probs: Optional[Sequence[Mapping[str, Optional[float]]]] + completion: Optional[str] + raw_completion: Optional[str] + completion_tokens: Optional[Sequence[str]] + + @staticmethod + def from_json(json: Dict[str, Any]) -> "StreamChunk": + return StreamChunk( + index=json["index"], + log_probs=json.get("log_probs"), + completion=json.get("completion"), + raw_completion=json.get("raw_completion"), + completion_tokens=json.get("completion_tokens"), + ) + + def to_json(self) -> Mapping[str, Any]: + return asdict(self)