Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add push_documents method to Compass client for asynchronous document processing #15

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions compass_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import uuid
from enum import Enum
from os import getenv
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, Union

from pydantic import BaseModel
from typing_extensions import TypedDict
from pydantic import BaseModel, Field, PositiveInt, StringConstraints

from compass_sdk.constants import (
COHERE_API_ENV_VAR,
Expand Down Expand Up @@ -324,6 +324,26 @@ class Document(BaseModel):
index_fields: List[str] = []


class ParseableDocument(BaseModel):
"""
A document to be sent to Compass in bytes format for parsing on the Compass side
"""

id: uuid.UUID
filename: Annotated[
str,
StringConstraints(min_length=1),
] # Ensures the filename is a non-empty string
content_type: str
content_length_bytes: PositiveInt # File size must be a non-negative integer
bytes: str # Base64-encoded file contents
context: Dict[str, Any] = Field(default_factory=dict)


class PushDocumentsInput(BaseModel):
documents: List[ParseableDocument]


class SearchFilter(BaseModel):
class FilterType(str, Enum):
EQ = "$eq"
Expand Down
62 changes: 60 additions & 2 deletions compass_sdk/compass.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
import os
import threading
import uuid
from collections import deque
from dataclasses import dataclass
from statistics import mean
from typing import Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import requests
from joblib import Parallel, delayed
Expand All @@ -20,12 +22,15 @@
CompassSdkStage,
Document,
LoggerLevel,
ParseableDocument,
PushDocumentsInput,
PutDocumentsInput,
SearchFilter,
SearchInput,
logger,
)
from compass_sdk.constants import (
DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES,
DEFAULT_MAX_CHUNKS_PER_REQUEST,
DEFAULT_MAX_ERROR_RATE,
DEFAULT_MAX_RETRIES,
Expand Down Expand Up @@ -94,6 +99,7 @@ def __init__(
"search_documents": self.session.post,
"add_context": self.session.post,
"refresh": self.session.post,
"push_documents": self.session.post,
}
self.function_endpoint = {
"create_index": "/api/v1/indexes/{index_name}",
Expand All @@ -106,6 +112,7 @@ def __init__(
"search_documents": "/api/v1/indexes/{index_name}/documents/search",
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"refresh": "/api/v1/indexes/{index_name}/refresh",
"push_documents": "/api/v2/indexes/{index_name}/documents",
}
logger.setLevel(logger_level.value)

Expand Down Expand Up @@ -267,6 +274,57 @@ def batch_status(self, *, uuid: str):
else:
raise Exception(f"Failed to get batch status: {resp.status_code} {resp.text}")

def push_document(
self,
*,
index_name: str,
filename: str,
filebytes: bytes,
content_type: str,
document_id: uuid.UUID,
context: Dict[str, Any] = {},
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[str]:
"""
Parse and insert a document into an index in Compass
:param index_name: the name of the index
:param filename: the filename of the document
:param filebytes: the bytes of the document
:param content_type: the content type of the document
:param document_id: the id of the document (optional)
:param context: represents an additional information about the document
:param max_retries: the maximum number of times to retry a request if it fails
artemiyatcohere marked this conversation as resolved.
Show resolved Hide resolved
:param sleep_retry_seconds: the number of seconds to wait before retrying an API request
:return: an error message if the request failed, otherwise None
"""
if len(filebytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES:
err = f"File too large, supported file size is {DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000} mb"
logger.error(err)
return err

b64 = base64.b64encode(filebytes).decode("utf-8")
doc = ParseableDocument(
id=document_id,
filename=filename,
bytes=b64,
content_type=content_type,
content_length_bytes=len(filebytes),
context=context,
)

result = self._send_request(
function="push_documents",
index_name=index_name,
data=PushDocumentsInput(documents=[doc]),
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
)

if result.error:
return result.error
return None

def insert_docs(
self,
*,
Expand Down Expand Up @@ -442,7 +500,7 @@ def _send_request_with_retry():
try:
if data:
if isinstance(data, BaseModel):
data_dict = data.model_dump()
data_dict = data.model_dump(mode="json")
elif isinstance(data, Dict):
data_dict = data

Expand Down
2 changes: 2 additions & 0 deletions compass_sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@
"date",
"authors",
]

UUID_NAMESPACE = "00000000-0000-0000-0000-000000000000"
benrules3 marked this conversation as resolved.
Show resolved Hide resolved
9 changes: 9 additions & 0 deletions compass_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import glob
import os
import uuid
from concurrent import futures
from concurrent.futures import Executor
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar
Expand All @@ -8,6 +10,7 @@
from fsspec import AbstractFileSystem

from compass_sdk import CompassDocument, CompassDocumentMetadata, CompassSdkStage
from compass_sdk.constants import UUID_NAMESPACE

T = TypeVar("T")
U = TypeVar("U")
Expand Down Expand Up @@ -85,3 +88,9 @@ def scan_folder(folder_path: str, allowed_extensions: Optional[List[str]] = None
scanned_files = fs.glob(pattern, recursive=recursive)
all_files.extend([f"{path_prepend}{f}" for f in scanned_files])
return all_files


def generate_doc_id_from_bytes(filebytes: bytes) -> uuid.UUID:
b64_string = base64.b64encode(filebytes).decode("utf-8")
namespace = uuid.UUID(UUID_NAMESPACE)
return uuid.uuid5(namespace, b64_string)
Loading