diff --git a/Changelog.md b/Changelog.md index 7aa2889..473cc69 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,10 @@ # Changelog +## 3.5.1 + +- Fix failing serialization of Prompt-based Documents in QA requests. + Documents should also be constructible from actual Prompts and not only from sequences + ## 3.5.0 - Deprecation of `qa` and `summarization` methods on `Client` and `AsyncClient`. New methods of processing these tasks will be released before they are removed in the next major version. diff --git a/aleph_alpha_client/document.py b/aleph_alpha_client/document.py index 7544834..434f64b 100644 --- a/aleph_alpha_client/document.py +++ b/aleph_alpha_client/document.py @@ -1,7 +1,7 @@ import base64 from typing import Any, Dict, Optional, Sequence, Union -from aleph_alpha_client.prompt import Image, PromptItem, Text, _to_json +from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens, _to_json class Document: @@ -12,7 +12,7 @@ class Document: def __init__( self, docx: Optional[str] = None, - prompt: Optional[Sequence[Union[str, Image]]] = None, + prompt: Optional[Sequence[Union[str, Text, Image, Tokens]]] = None, text: Optional[str] = None, ): # We use a base_64 representation for docx documents, because we want to embed the file @@ -43,11 +43,14 @@ def from_docx_file(cls, path: str): return cls.from_docx_bytes(docx_bytes) @classmethod - def from_prompt(cls, prompt: Sequence[Union[str, Image]]): + def from_prompt(cls, prompt: Union[Prompt, Sequence[Union[str, Image]]]): """ Pass a prompt that can contain multiple strings and Image prompts and prepare it to be used as a document """ - return cls(prompt=prompt) + if isinstance(prompt, Prompt): + return cls(prompt=prompt.items) + else: + return cls(prompt=prompt) @classmethod def from_text(cls, text: str): @@ -65,7 +68,7 @@ def _to_serializable_document(self) -> Dict[str, Any]: A dict if serialized to JSON is suitable as a document element """ - def to_prompt_item(item: Union[str, Image]) -> PromptItem: + def to_prompt_item(item: Union[str, Image, Text, Tokens]) -> PromptItem: # document still uses a plain piece of text for text-prompts # -> convert to Text-instance return Text.from_text(item) if isinstance(item, str) else item diff --git a/aleph_alpha_client/version.py b/aleph_alpha_client/version.py index dcbfb52..0c11bab 100644 --- a/aleph_alpha_client/version.py +++ b/aleph_alpha_client/version.py @@ -1 +1 @@ -__version__ = "3.5.0" +__version__ = "3.5.1" diff --git a/tests/test_qa.py b/tests/test_qa.py index 76f151c..e09cb64 100644 --- a/tests/test_qa.py +++ b/tests/test_qa.py @@ -1,6 +1,7 @@ import pytest from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client from aleph_alpha_client.document import Document +from aleph_alpha_client.prompt import Prompt from aleph_alpha_client.qa import QaRequest from tests.common import ( @@ -26,17 +27,18 @@ async def test_can_qa_with_async_client(async_client: AsyncClient): # Client -def test_qa(sync_client: Client): +def test_qa_from_text(sync_client: Client): # when posting a QA request with a QaRequest object request = QaRequest( query="Who likes pizza?", - documents=[Document.from_prompt(["Andreas likes pizza."])], + documents=[Document.from_text("Andreas likes pizza.")], ) response = sync_client.qa(request) # the response should exist and be in the form of a named tuple class assert len(response.answers) == 1 + assert response.answers[0].score > 0.5 def test_qa_no_answer_found(sync_client: Client): @@ -52,16 +54,19 @@ def test_qa_no_answer_found(sync_client: Client): assert len(response.answers) == 0 -def test_text(sync_client: Client): +def test_prompt(sync_client: Client): # when posting an illegal request request = QaRequest( query="Who likes pizza?", - documents=[Document.from_text("Andreas likes pizza.")], + documents=[ + Document.from_prompt(Prompt.from_text("Andreas likes pizza.")) + for _ in range(20) + ], ) # then we expect an exception tue to a bad request response from the API response = sync_client.qa(request) # The response should exist in the form of a json dict - assert len(response.answers) == 1 + assert len(response.answers) == 20 assert response.answers[0].score > 0.5