From 44bb63282627c593b73cf9109bedec15ec18d125 Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Mon, 23 Sep 2024 13:44:37 +0000 Subject: [PATCH] Add push_documents method to Compass client for asynchronous document processing Introduced the `push_documents` method, allowing the Compass client to push a list of documents for asynchronous processing. Documents will be available for search once the processing is completed. --- compass_sdk/__init__.py | 26 +++++++++++++++-- compass_sdk/compass.py | 62 ++++++++++++++++++++++++++++++++++++++-- compass_sdk/constants.py | 2 ++ compass_sdk/utils.py | 9 ++++++ 4 files changed, 94 insertions(+), 5 deletions(-) diff --git a/compass_sdk/__init__.py b/compass_sdk/__init__.py index d25d664..d8c571f 100644 --- a/compass_sdk/__init__.py +++ b/compass_sdk/__init__.py @@ -1,10 +1,10 @@ import logging +import uuid from enum import Enum from os import getenv -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Annotated, Any, Dict, List, Optional, Union -from pydantic import BaseModel -from typing_extensions import TypedDict +from pydantic import BaseModel, Field, PositiveInt, StringConstraints from compass_sdk.constants import ( COHERE_API_ENV_VAR, @@ -324,6 +324,26 @@ class Document(BaseModel): index_fields: List[str] = [] +class ParseableDocument(BaseModel): + """ + A document to be sent to Compass in bytes format for parsing on the Compass side + """ + + id: uuid.UUID + filename: Annotated[ + str, + StringConstraints(min_length=1), + ] # Ensures the filename is a non-empty string + content_type: str + content_length_bytes: PositiveInt # File size must be a non-negative integer + bytes: str # Base64-encoded file contents + context: Dict[str, Any] = Field(default_factory=dict) + + +class PushDocumentsInput(BaseModel): + documents: List[ParseableDocument] + + class SearchFilter(BaseModel): class FilterType(str, Enum): EQ = "$eq" diff --git a/compass_sdk/compass.py b/compass_sdk/compass.py index b91e583..23c5307 100644 --- a/compass_sdk/compass.py +++ b/compass_sdk/compass.py @@ -1,9 +1,11 @@ +import base64 import os import threading +import uuid from collections import deque from dataclasses import dataclass from statistics import mean -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import requests from joblib import Parallel, delayed @@ -20,12 +22,15 @@ CompassSdkStage, Document, LoggerLevel, + ParseableDocument, + PushDocumentsInput, PutDocumentsInput, SearchFilter, SearchInput, logger, ) from compass_sdk.constants import ( + DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES, DEFAULT_MAX_CHUNKS_PER_REQUEST, DEFAULT_MAX_ERROR_RATE, DEFAULT_MAX_RETRIES, @@ -94,6 +99,7 @@ def __init__( "search_documents": self.session.post, "add_context": self.session.post, "refresh": self.session.post, + "push_documents": self.session.post, } self.function_endpoint = { "create_index": "/api/v1/indexes/{index_name}", @@ -106,6 +112,7 @@ def __init__( "search_documents": "/api/v1/indexes/{index_name}/documents/search", "add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}", "refresh": "/api/v1/indexes/{index_name}/refresh", + "push_documents": "/api/v2/indexes/{index_name}/documents", } logger.setLevel(logger_level.value) @@ -267,6 +274,57 @@ def batch_status(self, *, uuid: str): else: raise Exception(f"Failed to get batch status: {resp.status_code} {resp.text}") + def push_document( + self, + *, + index_name: str, + filename: str, + filebytes: bytes, + content_type: str, + document_id: uuid.UUID, + context: Dict[str, Any] = {}, + max_retries: int = DEFAULT_MAX_RETRIES, + sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, + ) -> Optional[str]: + """ + Parse and insert a document into an index in Compass + :param index_name: the name of the index + :param filename: the filename of the document + :param filebytes: the bytes of the document + :param content_type: the content type of the document + :param document_id: the id of the document (optional) + :param context: represents an additional information about the document + :param max_retries: the maximum number of times to retry a request if it fails + :param sleep_retry_seconds: the number of seconds to wait before retrying an API request + :return: an error message if the request failed, otherwise None + """ + if len(filebytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES: + err = f"File too large, supported file size is {DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000} mb" + logger.error(err) + return err + + b64 = base64.b64encode(filebytes).decode("utf-8") + doc = ParseableDocument( + id=document_id, + filename=filename, + bytes=b64, + content_type=content_type, + content_length_bytes=len(filebytes), + context=context, + ) + + result = self._send_request( + function="push_documents", + index_name=index_name, + data=PushDocumentsInput(documents=[doc]), + max_retries=max_retries, + sleep_retry_seconds=sleep_retry_seconds, + ) + + if result.error: + return result.error + return None + def insert_docs( self, *, @@ -442,7 +500,7 @@ def _send_request_with_retry(): try: if data: if isinstance(data, BaseModel): - data_dict = data.model_dump() + data_dict = data.model_dump(mode="json") elif isinstance(data, Dict): data_dict = data diff --git a/compass_sdk/constants.py b/compass_sdk/constants.py index 848e677..c683bfd 100644 --- a/compass_sdk/constants.py +++ b/compass_sdk/constants.py @@ -33,3 +33,5 @@ "date", "authors", ] + +UUID_NAMESPACE = "00000000-0000-0000-0000-000000000000" diff --git a/compass_sdk/utils.py b/compass_sdk/utils.py index 3010f6a..3d5e6ce 100644 --- a/compass_sdk/utils.py +++ b/compass_sdk/utils.py @@ -1,5 +1,7 @@ +import base64 import glob import os +import uuid from concurrent import futures from concurrent.futures import Executor from typing import Callable, Iterable, Iterator, List, Optional, TypeVar @@ -8,6 +10,7 @@ from fsspec import AbstractFileSystem from compass_sdk import CompassDocument, CompassDocumentMetadata, CompassSdkStage +from compass_sdk.constants import UUID_NAMESPACE T = TypeVar("T") U = TypeVar("U") @@ -85,3 +88,9 @@ def scan_folder(folder_path: str, allowed_extensions: Optional[List[str]] = None scanned_files = fs.glob(pattern, recursive=recursive) all_files.extend([f"{path_prepend}{f}" for f in scanned_files]) return all_files + + +def generate_doc_id_from_bytes(filebytes: bytes) -> uuid.UUID: + b64_string = base64.b64encode(filebytes).decode("utf-8") + namespace = uuid.UUID(UUID_NAMESPACE) + return uuid.uuid5(namespace, b64_string)