Skip to content

Commit

Permalink
Add push_documents method to Compass client for asynchronous document…
Browse files Browse the repository at this point in the history
… processing (#15)

Add push_documents method to Compass client for asynchronous document processing

Introduced the `push_documents` method, allowing the Compass client to push a list of documents for asynchronous processing. Documents will be available for search once the processing is completed.

Co-authored-by: Artemiy Ryabinkov <[email protected]>
  • Loading branch information
javier-cohere and artemiyatcohere authored Sep 23, 2024
1 parent 3960096 commit 21c1690
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 5 deletions.
26 changes: 23 additions & 3 deletions compass_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import uuid
from enum import Enum
from os import getenv
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, Union

from pydantic import BaseModel
from typing_extensions import TypedDict
from pydantic import BaseModel, Field, PositiveInt, StringConstraints

from compass_sdk.constants import (
COHERE_API_ENV_VAR,
Expand Down Expand Up @@ -324,6 +324,26 @@ class Document(BaseModel):
index_fields: List[str] = []


class ParseableDocument(BaseModel):
"""
A document to be sent to Compass in bytes format for parsing on the Compass side
"""

id: uuid.UUID
filename: Annotated[
str,
StringConstraints(min_length=1),
] # Ensures the filename is a non-empty string
content_type: str
content_length_bytes: PositiveInt # File size must be a non-negative integer
bytes: str # Base64-encoded file contents
context: Dict[str, Any] = Field(default_factory=dict)


class PushDocumentsInput(BaseModel):
documents: List[ParseableDocument]


class SearchFilter(BaseModel):
class FilterType(str, Enum):
EQ = "$eq"
Expand Down
62 changes: 60 additions & 2 deletions compass_sdk/compass.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import os
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from statistics import mean
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import requests
from joblib import Parallel, delayed
Expand All @@ -20,12 +22,15 @@
CompassSdkStage,
Document,
LoggerLevel,
ParseableDocument,
PushDocumentsInput,
PutDocumentsInput,
SearchFilter,
SearchInput,
logger,
)
from compass_sdk.constants import (
DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES,
DEFAULT_MAX_CHUNKS_PER_REQUEST,
DEFAULT_MAX_ERROR_RATE,
DEFAULT_MAX_RETRIES,
Expand Down Expand Up @@ -94,6 +99,7 @@ def __init__(
"search_documents": self.session.post,
"add_context": self.session.post,
"refresh": self.session.post,
"push_documents": self.session.post,
}
self.function_endpoint = {
"create_index": "/api/v1/indexes/{index_name}",
Expand All @@ -106,6 +112,7 @@ def __init__(
"search_documents": "/api/v1/indexes/{index_name}/documents/search",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"refresh": "/api/v1/indexes/{index_name}/refresh",
"push_documents": "/api/v2/indexes/{index_name}/documents",
}
logger.setLevel(logger_level.value)

Expand Down Expand Up @@ -267,6 +274,57 @@ def batch_status(self, *, uuid: str):
else:
raise Exception(f"Failed to get batch status: {resp.status_code} {resp.text}")

def push_document(
self,
*,
index_name: str,
filename: str,
filebytes: bytes,
content_type: str,
document_id: uuid.UUID,
context: Dict[str, Any] = {},
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[str]:
"""
Parse and insert a document into an index in Compass
:param index_name: the name of the index
:param filename: the filename of the document
:param filebytes: the bytes of the document
:param content_type: the content type of the document
:param document_id: the id of the document (optional)
:param context: represents an additional information about the document
:param max_retries: the maximum number of times to retry a request if it fails
:param sleep_retry_seconds: the number of seconds to wait before retrying an API request
:return: an error message if the request failed, otherwise None
"""
if len(filebytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES:
err = f"File too large, supported file size is {DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000} mb"
logger.error(err)
return err

b64 = base64.b64encode(filebytes).decode("utf-8")
doc = ParseableDocument(
id=document_id,
filename=filename,
bytes=b64,
content_type=content_type,
content_length_bytes=len(filebytes),
context=context,
)

result = self._send_request(
function="push_documents",
index_name=index_name,
data=PushDocumentsInput(documents=[doc]),
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
return result.error
return None

def insert_docs(
self,
*,
Expand Down Expand Up @@ -442,7 +500,7 @@ def _send_request_with_retry():
try:
if data:
if isinstance(data, BaseModel):
data_dict = data.model_dump()
data_dict = data.model_dump(mode="json")
elif isinstance(data, Dict):
data_dict = data

Expand Down
2 changes: 2 additions & 0 deletions compass_sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@
"date",
"authors",
]

UUID_NAMESPACE = "00000000-0000-0000-0000-000000000000"
9 changes: 9 additions & 0 deletions compass_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import glob
import os
import uuid
from concurrent import futures
from concurrent.futures import Executor
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar
Expand All @@ -8,6 +10,7 @@
from fsspec import AbstractFileSystem

from compass_sdk import CompassDocument, CompassDocumentMetadata, CompassSdkStage
from compass_sdk.constants import UUID_NAMESPACE

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -85,3 +88,9 @@ def scan_folder(folder_path: str, allowed_extensions: Optional[List[str]] = None
scanned_files = fs.glob(pattern, recursive=recursive)
all_files.extend([f"{path_prepend}{f}" for f in scanned_files])
return all_files


def generate_doc_id_from_bytes(filebytes: bytes) -> uuid.UUID:
b64_string = base64.b64encode(filebytes).decode("utf-8")
namespace = uuid.UUID(UUID_NAMESPACE)
return uuid.uuid5(namespace, b64_string)

0 comments on commit 21c1690

Please sign in to comment.