Skip to content

Commit

Permalink
allow both types of prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartel committed Sep 27, 2023
1 parent e37027f commit 1d00dfe
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
20 changes: 15 additions & 5 deletions aleph_alpha_client/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from typing import Any, Dict, Optional, Sequence, Union

from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, _to_json
from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens, _to_json


class Document:
Expand All @@ -12,7 +12,7 @@ class Document:
def __init__(
self,
docx: Optional[str] = None,
prompt: Optional[Prompt] = None,
prompt: Optional[Sequence[Union[str, Text, Image, Tokens]]] = None,
text: Optional[str] = None,
):
# We use a base_64 representation for docx documents, because we want to embed the file
Expand Down Expand Up @@ -43,11 +43,14 @@ def from_docx_file(cls, path: str):
return cls.from_docx_bytes(docx_bytes)

@classmethod
def from_prompt(cls, prompt: Prompt):
def from_prompt(cls, prompt: Union[Prompt, Sequence[Union[str, Image]]]):
"""
Pass a prompt that can contain multiple strings and Image prompts and prepare it to be used as a document
"""
return cls(prompt=prompt)
if isinstance(prompt, Prompt):
return cls(prompt=prompt.items)
else:
return cls(prompt=prompt)

@classmethod
def from_text(cls, text: str):
Expand All @@ -65,14 +68,21 @@ def _to_serializable_document(self) -> Dict[str, Any]:
A dict if serialized to JSON is suitable as a document element
"""

def to_prompt_item(item: Union[str, Image, Text, Tokens]) -> PromptItem:
# document still uses a plain piece of text for text-prompts
# -> convert to Text-instance
return Text.from_text(item) if isinstance(item, str) else item

if self.docx is not None:
# Serialize docx to Document JSON format
return {
"docx": self.docx,
}
elif self.prompt is not None:
# Serialize prompt to Document JSON format
prompt_data = [_to_json(prompt_item) for prompt_item in self.prompt.items]
prompt_data = [
_to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt
]
return {"prompt": prompt_data}
elif self.text is not None:
return {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_can_summarize_with_async_client(async_client: AsyncClient):
def test_summarize(sync_client: Client):
# when posting a Summarization request
request = SummarizationRequest(
document=Document.from_prompt(Prompt.from_text("Andreas likes pizza.")),
document=Document.from_prompt(["Andreas likes pizza."]),
)

response = sync_client.summarize(request)
Expand Down

0 comments on commit 1d00dfe

Please sign in to comment.