Skip to content

Commit

Permalink
fix: URL-encode document names in document index client requests (#847)
Browse files Browse the repository at this point in the history
Prior to this commit, we did not URL-encode document names when querying
the document index. This caused problems when names have special
characters, i.e., '/' or '?'. We do not need to URL-encode other parts
of the URL (e.g., namespace or collection) as these are already
constrained to URL-safe characters. In addition to this fix, this commit
includes the following changes:
- update `DocumentPath.from_slash_separated_str` to correctly delimit
  paths whose document name contains a forward slash. Ideally we'd
  remove both this and `DocumentPath.to_slash_separated_str` in the
  future.
- parametrize tests to include special characters and increase coverage
- migrate to new document index collection to avoid CI conflicts
  • Loading branch information
Michael-JB authored May 21, 2024
1 parent a5ab8e5 commit 3223dd8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

### Fixes
- `ExpandChunks`-task is now fast even for very large documents
- The document index client now correctly URL-encodes document names in its queries.

### Deprecations
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http import HTTPStatus
from json import dumps
from typing import Annotated, Any, Literal, Mapping, Optional, Sequence
from urllib.parse import quote

import requests
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -100,6 +101,9 @@ class DocumentPath(BaseModel, frozen=True):
collection_path: CollectionPath
document_name: str

def encoded_document_name(self) -> str:
return quote(self.document_name, safe="")

@classmethod
def from_json(cls, document_path_json: Mapping[str, str]) -> "DocumentPath":
return cls(
Expand All @@ -115,7 +119,7 @@ def to_slash_separated_str(self) -> str:

@classmethod
def from_slash_separated_str(cls, path: str) -> "DocumentPath":
split = path.split("/")
split = path.split("/", 2)
assert len(split) == 3
return cls(
collection_path=CollectionPath(
Expand Down Expand Up @@ -492,7 +496,7 @@ def add_document(
Currently only supports text.
"""

url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.encoded_document_name()}"
response = requests.put(
url, data=dumps(contents._to_modalities_json()), headers=self.headers
)
Expand All @@ -505,7 +509,7 @@ def delete_document(self, document_path: DocumentPath) -> None:
document_path: Consists of `collection_path` and name of document to be deleted.
"""

url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.encoded_document_name()}"
response = requests.delete(url, headers=self.headers)
self._raise_for_status(response)

Expand All @@ -519,7 +523,7 @@ def document(self, document_path: DocumentPath) -> DocumentContents:
Content of the retrieved document.
"""

url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.document_name}"
url = f"{self._base_document_index_url}/collections/{document_path.collection_path.namespace}/{document_path.collection_path.collection}/docs/{document_path.encoded_document_name()}"
response = requests.get(url, headers=self.headers)
self._raise_for_status(response)
return DocumentContents._from_modalities_json(response.json())
Expand Down
75 changes: 56 additions & 19 deletions tests/connectors/document_index/test_document_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,8 @@ def aleph_alpha_namespace() -> str:

@fixture
def collection_path(aleph_alpha_namespace: str) -> CollectionPath:
return CollectionPath(namespace=aleph_alpha_namespace, collection="ci-collection")


@fixture
def document_path(
document_index: DocumentIndexClient, collection_path: CollectionPath
) -> DocumentPath:
document_index.create_collection(collection_path)
return DocumentPath(
collection_path=collection_path, document_name="Example Document"
return CollectionPath(
namespace=aleph_alpha_namespace, collection="intelligence-layer-sdk-ci"
)


Expand Down Expand Up @@ -89,12 +81,29 @@ def test_document_index_creates_collection(


@pytest.mark.internal
@pytest.mark.parametrize(
"document_name",
[
"Example Document",
"!@#$%^&*()-_+={}[]\\|;:'\"<>,.?/~`",
],
)
def test_document_index_adds_document(
document_index: DocumentIndexClient,
document_path: DocumentPath,
collection_path: CollectionPath,
document_contents: DocumentContents,
document_name: str,
) -> None:
document_path = DocumentPath(
collection_path=collection_path,
document_name=document_name,
)
document_index.add_document(document_path, document_contents)

assert any(
d.document_path == document_path
for d in document_index.documents(collection_path)
)
assert document_contents == document_index.document(document_path)


Expand All @@ -115,11 +124,20 @@ def test_document_index_searches_asymmetrically(


@pytest.mark.internal
@pytest.mark.parametrize(
"document_name",
[
"Document to be deleted",
"Document to be deleted !@#$%^&*()-_+={}[]\\|;:'\"<>,.?/~`",
],
)
def test_document_index_deletes_document(
document_index: DocumentIndexClient, collection_path: CollectionPath
document_index: DocumentIndexClient,
collection_path: CollectionPath,
document_name: str,
) -> None:
document_path = DocumentPath(
collection_path=collection_path, document_name="Document to be deleted"
collection_path=collection_path, document_name=document_name
)
document_contents = DocumentContents.from_text("Some text...")

Expand All @@ -145,11 +163,30 @@ def test_document_index_raises_on_getting_non_existing_document(
)


def test_document_path_from_string() -> None:
abc = DocumentPath.from_slash_separated_str("a/b/c")
assert abc == DocumentPath(
collection_path=CollectionPath(namespace="a", collection="b"), document_name="c"
)
@pytest.mark.parametrize(
"slash_separated_str,expected_document_path",
[
(
"a/b/c",
DocumentPath(
collection_path=CollectionPath(namespace="a", collection="b"),
document_name="c",
),
),
(
"a/b/c/d",
DocumentPath(
collection_path=CollectionPath(namespace="a", collection="b"),
document_name="c/d",
),
),
],
)
def test_document_path_from_string(
slash_separated_str: str, expected_document_path: DocumentPath
) -> None:
actual_document_path = DocumentPath.from_slash_separated_str(slash_separated_str)
assert actual_document_path == expected_document_path
with raises(AssertionError):
DocumentPath.from_slash_separated_str("a/c")

Expand All @@ -159,7 +196,7 @@ def test_document_list_all_documents(
) -> None:
filter_result = document_index.documents(collection_path)

assert len(filter_result) == 2
assert len(filter_result) == 3


def test_document_list_max_n_documents(
Expand Down

0 comments on commit 3223dd8

Please sign in to comment.