Skip to content

Commit

Permalink
Merge pull request #140 from Aleph-Alpha/fix-qa-with-prompt-document
Browse files Browse the repository at this point in the history
Fix serialization of Prompt-based Documents in QA
  • Loading branch information
ahartel authored Sep 27, 2023
2 parents b6eecce + e26db4f commit 23ade75
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
5 changes: 5 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 3.5.1

- Fix failing serialization of Prompt-based Documents in QA requests.
Documents should also be constructible from actual Prompts and not only from sequences

## 3.5.0

- Deprecation of `qa` and `summarization` methods on `Client` and `AsyncClient`. New methods of processing these tasks will be released before they are removed in the next major version.
Expand Down
13 changes: 8 additions & 5 deletions aleph_alpha_client/document.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from typing import Any, Dict, Optional, Sequence, Union

from aleph_alpha_client.prompt import Image, PromptItem, Text, _to_json
from aleph_alpha_client.prompt import Image, Prompt, PromptItem, Text, Tokens, _to_json


class Document:
Expand All @@ -12,7 +12,7 @@ class Document:
def __init__(
self,
docx: Optional[str] = None,
prompt: Optional[Sequence[Union[str, Image]]] = None,
prompt: Optional[Sequence[Union[str, Text, Image, Tokens]]] = None,
text: Optional[str] = None,
):
# We use a base_64 representation for docx documents, because we want to embed the file
Expand Down Expand Up @@ -43,11 +43,14 @@ def from_docx_file(cls, path: str):
return cls.from_docx_bytes(docx_bytes)

@classmethod
def from_prompt(cls, prompt: Sequence[Union[str, Image]]):
def from_prompt(cls, prompt: Union[Prompt, Sequence[Union[str, Image]]]):
"""
Pass a prompt that can contain multiple strings and Image prompts and prepare it to be used as a document
"""
return cls(prompt=prompt)
if isinstance(prompt, Prompt):
return cls(prompt=prompt.items)
else:
return cls(prompt=prompt)

@classmethod
def from_text(cls, text: str):
Expand All @@ -65,7 +68,7 @@ def _to_serializable_document(self) -> Dict[str, Any]:
A dict if serialized to JSON is suitable as a document element
"""

def to_prompt_item(item: Union[str, Image]) -> PromptItem:
def to_prompt_item(item: Union[str, Image, Text, Tokens]) -> PromptItem:
# document still uses a plain piece of text for text-prompts
# -> convert to Text-instance
return Text.from_text(item) if isinstance(item, str) else item
Expand Down
2 changes: 1 addition & 1 deletion aleph_alpha_client/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.5.0"
__version__ = "3.5.1"
15 changes: 10 additions & 5 deletions tests/test_qa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.document import Document
from aleph_alpha_client.prompt import Prompt
from aleph_alpha_client.qa import QaRequest

from tests.common import (
Expand All @@ -26,17 +27,18 @@ async def test_can_qa_with_async_client(async_client: AsyncClient):
# Client


def test_qa(sync_client: Client):
def test_qa_from_text(sync_client: Client):
# when posting a QA request with a QaRequest object
request = QaRequest(
query="Who likes pizza?",
documents=[Document.from_prompt(["Andreas likes pizza."])],
documents=[Document.from_text("Andreas likes pizza.")],
)

response = sync_client.qa(request)

# the response should exist and be in the form of a named tuple class
assert len(response.answers) == 1
assert response.answers[0].score > 0.5


def test_qa_no_answer_found(sync_client: Client):
Expand All @@ -52,16 +54,19 @@ def test_qa_no_answer_found(sync_client: Client):
assert len(response.answers) == 0


def test_text(sync_client: Client):
def test_prompt(sync_client: Client):
# when posting an illegal request
request = QaRequest(
query="Who likes pizza?",
documents=[Document.from_text("Andreas likes pizza.")],
documents=[
Document.from_prompt(Prompt.from_text("Andreas likes pizza."))
for _ in range(20)
],
)

# then we expect an exception tue to a bad request response from the API
response = sync_client.qa(request)

# The response should exist in the form of a json dict
assert len(response.answers) == 1
assert len(response.answers) == 20
assert response.answers[0].score > 0.5

0 comments on commit 23ade75

Please sign in to comment.