Skip to content

Commit

Permalink
Fix serialization of Prompt-based Documents in QA
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartel committed Sep 27, 2023
1 parent b6eecce commit 8f09631
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion aleph_alpha_client/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def to_prompt_item(item: Union[str, Image]) -> PromptItem:
elif self.prompt is not None:
# Serialize prompt to Document JSON format
prompt_data = [
_to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt
_to_json(to_prompt_item(prompt_item)) for prompt_item in self.prompt.items
]
return {"prompt": prompt_data}
elif self.text is not None:
Expand Down
11 changes: 8 additions & 3 deletions tests/test_qa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.document import Document
from aleph_alpha_client.prompt import Prompt
from aleph_alpha_client.qa import QaRequest

from tests.common import (
Expand Down Expand Up @@ -37,6 +38,7 @@ def test_qa(sync_client: Client):

# the response should exist and be in the form of a named tuple class
assert len(response.answers) == 1
assert response.answers[0].score > 0.5


def test_qa_no_answer_found(sync_client: Client):
Expand All @@ -52,16 +54,19 @@ def test_qa_no_answer_found(sync_client: Client):
assert len(response.answers) == 0


def test_text(sync_client: Client):
def test_prompt(sync_client: Client):
# when posting an illegal request
request = QaRequest(
query="Who likes pizza?",
documents=[Document.from_text("Andreas likes pizza.")],
documents=[
Document.from_prompt(Prompt.from_text("Andreas likes pizza."))
for _ in range(20)
],
)

# then we expect an exception tue to a bad request response from the API
response = sync_client.qa(request)

# The response should exist in the form of a json dict
assert len(response.answers) == 1
assert len(response.answers) == 20
assert response.answers[0].score > 0.5

0 comments on commit 8f09631

Please sign in to comment.