diff --git a/aleph_alpha_client/document.py b/aleph_alpha_client/document.py index 2afe14a..434f64b 100644 --- a/aleph_alpha_client/document.py +++ b/aleph_alpha_client/document.py @@ -1,7 +1,7 @@ import base64 from typing import Any, Dict, Optional, Sequence, Union -from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, _to_json +from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens, _to_json class Document: @@ -12,7 +12,7 @@ class Document: def __init__( self, docx: Optional[str] = None, - prompt: Optional[Prompt] = None, + prompt: Optional[Sequence[Union[str, Text, Image, Tokens]]] = None, text: Optional[str] = None, ): # We use a base_64 representation for docx documents, because we want to embed the file @@ -43,11 +43,14 @@ def from_docx_file(cls, path: str): return cls.from_docx_bytes(docx_bytes) @classmethod - def from_prompt(cls, prompt: Prompt): + def from_prompt(cls, prompt: Union[Prompt, Sequence[Union[str, Image]]]): """ Pass a prompt that can contain multiple strings and Image prompts and prepare it to be used as a document """ - return cls(prompt=prompt) + if isinstance(prompt, Prompt): + return cls(prompt=prompt.items) + else: + return cls(prompt=prompt) @classmethod def from_text(cls, text: str): @@ -65,6 +68,11 @@ def _to_serializable_document(self) -> Dict[str, Any]: A dict if serialized to JSON is suitable as a document element """ + def to_prompt_item(item: Union[str, Image, Text, Tokens]) -> PromptItem: + # document still uses a plain piece of text for text-prompts + # -> convert to Text-instance + return Text.from_text(item) if isinstance(item, str) else item + if self.docx is not None: # Serialize docx to Document JSON format return { @@ -72,7 +80,9 @@ def _to_serializable_document(self) -> Dict[str, Any]: } elif self.prompt is not None: # Serialize prompt to Document JSON format - prompt_data = [_to_json(prompt_item) for prompt_item in self.prompt.items] + prompt_data = [ + _to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt + ] return {"prompt": prompt_data} elif self.text is not None: return { diff --git a/tests/test_summarize.py b/tests/test_summarize.py index eda9317..30ca19c 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -29,7 +29,7 @@ async def test_can_summarize_with_async_client(async_client: AsyncClient): def test_summarize(sync_client: Client): # when posting a Summarization request request = SummarizationRequest( - document=Document.from_prompt(Prompt.from_text("Andreas likes pizza.")), + document=Document.from_prompt(["Andreas likes pizza."]), ) response = sync_client.summarize(request)