From 6d33a0108f7e3bd09027d429044e70ba811dc478 Mon Sep 17 00:00:00 2001 From: Rafid Date: Tue, 17 Dec 2024 19:34:51 +0000 Subject: [PATCH] Introduce more Ruff rules to improve code quality --- cohere/compass/__init__.py | 18 +- cohere/compass/clients/compass.py | 280 ++++++++++++++++++--------- cohere/compass/clients/parser.py | 112 +++++------ cohere/compass/clients/rbac.py | 156 ++++++++++++--- cohere/compass/constants.py | 2 +- cohere/compass/exceptions.py | 22 ++- cohere/compass/models/__init__.py | 12 +- cohere/compass/models/config.py | 82 +++++--- cohere/compass/models/datasources.py | 18 +- cohere/compass/models/documents.py | 108 ++++++----- cohere/compass/models/rbac.py | 51 ++++- cohere/compass/models/search.py | 45 +++-- cohere/compass/utils.py | 56 ++++-- pyproject.toml | 27 ++- 14 files changed, 691 insertions(+), 298 deletions(-) diff --git a/cohere/compass/__init__.py b/cohere/compass/__init__.py index 323dc6b..9f1de0e 100644 --- a/cohere/compass/__init__.py +++ b/cohere/compass/__init__.py @@ -1,6 +1,6 @@ # Python imports from enum import Enum -from typing import List, Optional +from typing import Optional # 3rd party imports from pydantic import BaseModel @@ -12,10 +12,12 @@ ValidatedModel, ) -__version__ = "0.8.0" +__version__ = "0.10.2" class ProcessFileParameters(ValidatedModel): + """Model for use with the process_file parser API.""" + parser_config: ParserConfig metadata_config: MetadataConfig doc_id: Optional[str] = None @@ -23,17 +25,23 @@ class ProcessFileParameters(ValidatedModel): class ProcessFilesParameters(ValidatedModel): - doc_ids: Optional[List[str]] = None + """Model for use with the process_files parser API.""" + + doc_ids: Optional[list[str]] = None parser_config: ParserConfig metadata_config: MetadataConfig class GroupAuthorizationActions(str, Enum): + """Enum for use with the edit_group_authorization API to specify the edit type.""" + ADD = "add" REMOVE = "remove" class GroupAuthorizationInput(BaseModel): - document_ids: List[str] - authorized_groups: List[str] + """Model for use with the edit_group_authorization API.""" + + document_ids: list[str] + authorized_groups: list[str] action: GroupAuthorizationActions diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index 3577cc8..b0891e7 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -1,13 +1,16 @@ # Python imports -from collections import deque -from dataclasses import dataclass -from statistics import mean -from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union import base64 import logging import os import threading import uuid +from collections import deque +from collections.abc import Iterator +from dataclasses import dataclass +from statistics import mean +from typing import Any, Literal, Optional, Union + +import requests # 3rd party imports # TODO find stubs for joblib and remove "type: ignore" @@ -21,7 +24,6 @@ stop_after_attempt, wait_fixed, ) -import requests # Local imports from cohere.compass import ( @@ -48,19 +50,27 @@ DataSource, Document, DocumentStatus, - PaginatedList, ParseableDocument, - PushDocumentsInput, PutDocumentsInput, SearchChunksResponse, SearchDocumentsResponse, SearchFilter, SearchInput, + UploadDocumentsInput, ) +from cohere.compass.models.datasources import PaginatedList @dataclass class RetryResult: + """ + A class to represent the result of a retryable operation. + + The class contains the following fields: + - result: The result of the operation if successful, otherwise None. + - error (Optional[str]): The error message if the operation failed, otherwise None. + """ + result: Optional[dict[str, Any]] = None error: Optional[str] = None @@ -69,6 +79,8 @@ class RetryResult: class CompassClient: + """A compass client to interact with the Compass API.""" + def __init__( self, *, @@ -79,10 +91,13 @@ def __init__( http_session: Optional[requests.Session] = None, ): """ - A compass client to interact with the Compass API - :param index_url: the url of the Compass instance - :param username: the username for the Compass instance - :param password: the password for the Compass instance + Initialize the Compass client. + + :param index_url: The base URL for the index API. + :param username (optional): The username for authentication. + :param password (optional): The password for authentication. + :param bearer_token (optional): The bearer token for authentication. + :param http_session (optional): An optional HTTP session to use for requests. """ self.index_url = index_url self.username = username or os.getenv("COHERE_COMPASS_USERNAME") @@ -120,24 +135,25 @@ def __init__( "put_documents": "/api/v1/indexes/{index_name}/documents", "search_documents": "/api/v1/indexes/{index_name}/documents/_search", "search_chunks": "/api/v1/indexes/{index_name}/documents/_search_chunks", - "add_attributes": "/api/v1/indexes/{index_name}/documents/{document_id}/_add_attributes", + "add_attributes": "/api/v1/indexes/{index_name}/documents/{document_id}/_add_attributes", # noqa: E501 "refresh": "/api/v1/indexes/{index_name}/_refresh", "upload_documents": "/api/v1/indexes/{index_name}/documents/_upload", - "edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", + "edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501 # Data Sources APIs "create_datasource": "/api/v1/datasources", "list_datasources": "/api/v1/datasources", "delete_datasources": "/api/v1/datasources/{datasource_id}", "get_datasource": "/api/v1/datasources/{datasource_id}", "sync_datasource": "/api/v1/datasources/{datasource_id}/_sync", - "list_datasources_objects_states": "/api/v1/datasources/{datasource_id}/documents?skip={skip}&limit={limit}", + "list_datasources_objects_states": "/api/v1/datasources/{datasource_id}/documents?skip={skip}&limit={limit}", # noqa: E501 } def create_index(self, *, index_name: str): """ - Create an index in Compass + Create an index in Compass. + :param index_name: the name of the index - :return: the response from the Compass API + :returns: the response from the Compass API """ return self._send_request( api_name="create_index", @@ -148,9 +164,10 @@ def create_index(self, *, index_name: str): def refresh(self, *, index_name: str): """ - Refresh index + Refresh index. + :param index_name: the name of the index - :return: the response from the Compass API + :returns: the response from the Compass API """ return self._send_request( api_name="refresh", @@ -161,9 +178,10 @@ def refresh(self, *, index_name: str): def delete_index(self, *, index_name: str): """ - Delete an index from Compass + Delete an index from Compass. + :param index_name: the name of the index - :return: the response from the Compass API + :returns: the response from the Compass API """ return self._send_request( api_name="delete_index", @@ -174,10 +192,12 @@ def delete_index(self, *, index_name: str): def delete_document(self, *, index_name: str, document_id: str): """ - Delete a document from Compass + Delete a document from Compass. + :param index_name: the name of the index - :document_id: the id of the document - :return: the response from the Compass API + :param document_id: the id of the document + + :returns: the response from the Compass API """ return self._send_request( api_name="delete_document", @@ -189,10 +209,12 @@ def delete_document(self, *, index_name: str, document_id: str): def get_document(self, *, index_name: str, document_id: str): """ - Get a document from Compass + Get a document from Compass. + :param index_name: the name of the index - :document_id: the id of the document - :return: the response from the Compass API + :param document_id: the id of the document + + :returns: the response from the Compass API """ return self._send_request( api_name="get_document", @@ -204,8 +226,9 @@ def get_document(self, *, index_name: str, document_id: str): def list_indexes(self): """ - List all indexes in Compass - :return: the response from the Compass API + List all indexes in Compass. + + :returns: the response from the Compass API """ return self._send_request( api_name="list_indexes", @@ -224,15 +247,16 @@ def add_attributes( sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Optional[RetryResult]: """ - Update the content field of an existing document with additional context + Update the content field of an existing document with additional context. :param index_name: the name of the index :param document_id: the document to modify - :param context: A dictionary of key:value pairs to insert into the content field of a document + :param context: A dictionary of key-value pairs to insert into the content field + of a document :param max_retries: the maximum number of times to retry a doc insertion - :param sleep_retry_seconds: number of seconds to go to sleep before retrying a doc insertion + :param sleep_retry_seconds: number of seconds to go to sleep before retrying a + doc insertion """ - return self._send_request( api_name="add_attributes", document_id=document_id, @@ -249,15 +273,16 @@ def insert_doc( doc: CompassDocument, max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, - authorized_groups: Optional[List[str]] = None, + authorized_groups: Optional[list[str]] = None, merge_groups_on_conflict: bool = False, - ) -> Optional[List[Dict[str, str]]]: + ) -> Optional[list[dict[str, str]]]: """ - Insert a parsed document into an index in Compass + Insert a parsed document into an index in Compass. + :param index_name: the name of the index :param doc: the parsed compass document :param max_retries: the maximum number of times to retry a doc insertion - :param sleep_retry_seconds: number of seconds to go to sleep before retrying a doc insertion + :param sleep_retry_seconds: interval between the document insertion retries. """ return self.insert_docs( index_name=index_name, @@ -276,12 +301,13 @@ def upload_document( filebytes: bytes, content_type: str, document_id: uuid.UUID, - attributes: Dict[str, Any] = {}, + attributes: dict[str, Any] = {}, max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, - ) -> Optional[Union[str, Dict[str, Any]]]: + ) -> Optional[Union[str, dict[str, Any]]]: """ - Parse and insert a document into an index in Compass + Parse and insert a document into an index in Compass. + :param index_name: the name of the index :param filename: the filename of the document :param filebytes: the bytes of the document @@ -289,11 +315,13 @@ def upload_document( :param document_id: the id of the document (optional) :param context: represents an additional information about the document :param max_retries: the maximum number of times to retry a request if it fails - :param sleep_retry_seconds: the number of seconds to wait before retrying an API request - :return: an error message if the request failed, otherwise None + :param sleep_retry_seconds: interval between API request retries + + :returns: an error message if the request failed, otherwise None """ if len(filebytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES: - err = f"File too large, supported file size is {DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000} mb" + max_file_size_mb = DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000 + err = f"File too large, supported file size is {max_file_size_mb} mb" logger.error(err) return err @@ -309,7 +337,7 @@ def upload_document( result = self._send_request( api_name="upload_documents", - data=PushDocumentsInput(documents=[doc]), + data=UploadDocumentsInput(documents=[doc]), max_retries=max_retries, sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, @@ -331,32 +359,40 @@ def insert_docs( errors_sliding_window_size: Optional[int] = 10, skip_first_n_docs: int = 0, num_jobs: Optional[int] = None, - authorized_groups: Optional[List[str]] = None, + authorized_groups: Optional[list[str]] = None, merge_groups_on_conflict: bool = False, - ) -> Optional[List[Dict[str, str]]]: + ) -> Optional[list[dict[str, str]]]: """ - Insert multiple parsed documents into an index in Compass + Insert multiple parsed documents into an index in Compass. + :param index_name: the name of the index :param docs: the parsed documents - :param max_chunks_per_request: the maximum number of chunks to send in a single API request + :param max_chunks_per_request: the maximum number of chunks to send in a single + API request :param num_jobs: the number of parallel jobs to use :param max_error_rate: the maximum error rate allowed :param max_retries: the maximum number of times to retry a request if it fails - :param sleep_retry_seconds: the number of seconds to wait before retrying an API request - :param errors_sliding_window_size: the size of the sliding window to keep track of errors - :param skip_first_n_docs: number of docs to skip indexing. Useful when insertion failed after N documents - :param authorized_groups: the groups that are authorized to access the documents. These groups should exist in RBAC. None passed will make the documents public - :param merge_groups_on_conflict: when doc level security enable, allow upserting documents with static groups + :param sleep_retry_seconds: the number of seconds to wait before retrying an API + request + :param errors_sliding_window_size: the size of the sliding window to keep track + of errors + :param skip_first_n_docs: number of docs to skip indexing. Useful when insertion + failed after N documents + :param authorized_groups: the groups that are authorized to access the + documents. These groups should exist in RBAC. None passed will make the + documents public + :param merge_groups_on_conflict: when doc level security enable, allow upserting + documents with static groups """ def put_request( - request_data: list[Tuple[CompassDocument, Document]], + request_data: list[tuple[CompassDocument, Document]], previous_errors: list[dict[str, str]], num_doc: int, ) -> None: nonlocal num_succeeded, errors errors.extend(previous_errors) - compass_docs: List[CompassDocument] = [ + compass_docs: list[CompassDocument] = [ compass_doc for compass_doc, _ in request_data ] put_docs_input = PutDocumentsInput( @@ -365,9 +401,10 @@ def put_request( merge_groups_on_conflict=merge_groups_on_conflict, ) - # It could be that all documents have errors, in which case we should not send a request - # to the Compass Server. This is a common case when the parsing of the documents fails. - # In this case, only errors will appear in the insertion_docs response + # It could be that all documents have errors, in which case we should not + # send a request to the Compass Server. This is a common case when the + # parsing of the documents fails. In this case, only errors will appear in + # the insertion_docs response if not request_data: return @@ -381,21 +418,18 @@ def put_request( if results.error: for doc in compass_docs: + filename = doc.metadata.filename + error = results.error doc.errors.append( - { - CompassSdkStage.Indexing: f"{doc.metadata.filename}: {results.error}" - } - ) - errors.append( - { - doc.metadata.document_id: f"{doc.metadata.filename}: {results.error}" - } + {CompassSdkStage.Indexing: f"{filename}: {error}"} ) + errors.append({doc.metadata.document_id: f"{filename}: {error}"}) else: num_succeeded += len(compass_docs) - # Keep track of the results of the last N API calls to calculate the error rate - # If the error rate is higher than the threshold, stop the insertion process + # Keep track of the results of the last N API calls to calculate the error + # rate If the error rate is higher than the threshold, stop the insertion + # process error_window.append(results.error) error_rate = ( mean([1 if x else 0 for x in error_window]) @@ -404,8 +438,9 @@ def put_request( ) if error_rate > max_error_rate: raise CompassMaxErrorRateExceeded( - f"[Thread {threading.get_native_id()}]{error_rate * 100}% of insertions failed " - f"in the last {errors_sliding_window_size} API calls. Stopping the insertion process." + f"[Thread {threading.get_native_id()}] {error_rate * 100}% of " + f"insertions failed in the last {errors_sliding_window_size} API " + "calls. Stopping the insertion process." ) error_window: deque[Optional[str]] = deque( @@ -436,7 +471,14 @@ def create_datasource( datasource: CreateDataSource, max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, - ): + ) -> Union[DataSource, str]: + """ + Create a new datasource in Compass. + + :param datasource: the datasource to create + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="create_datasource", max_retries=max_retries, @@ -454,6 +496,12 @@ def list_datasources( max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Union[PaginatedList[DataSource], str]: + """ + List all datasources in Compass. + + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="list_datasources", max_retries=max_retries, @@ -471,6 +519,13 @@ def get_datasource( max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ): + """ + Get a datasource in Compass. + + :param datasource_id: the id of the datasource + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="get_datasource", datasource_id=datasource_id, @@ -489,6 +544,13 @@ def delete_datasource( max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ): + """ + Delete a datasource in Compass. + + :param datasource_id: the id of the datasource + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="delete_datasources", datasource_id=datasource_id, @@ -507,6 +569,13 @@ def sync_datasource( max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ): + """ + Sync a datasource in Compass. + + :param datasource_id: the id of the datasource + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="sync_datasource", datasource_id=datasource_id, @@ -527,6 +596,15 @@ def list_datasources_objects_states( max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Union[PaginatedList[DocumentStatus], str]: + """ + List all objects states in a datasource in Compass. + + :param datasource_id: the id of the datasource + :param skip: the number of objects to skip + :param limit: the number of objects to return + :param max_retries: the maximum number of times to retry the request + :param sleep_retry_seconds: the number of seconds to sleep between retries + """ result = self._send_request( api_name="list_datasources_objects_states", datasource_id=datasource_id, @@ -546,12 +624,13 @@ def _get_request_blocks( max_chunks_per_request: int, ): """ - Create request blocks to send to the Compass API + Create request blocks to send to the Compass API. + :param docs: the documents to send - :param max_chunks_per_request: the maximum number of chunks to send in a single API request - :return: an iterator over the request blocks + :param max_chunks_per_request: the maximum number of chunks to send in a single + API request + :returns: an iterator over the request blocks """ - request_block: list[tuple[CompassDocument, Document]] = [] errors: list[dict[str, str]] = [] num_chunks = 0 @@ -561,7 +640,9 @@ def _get_request_blocks( f"Document {doc.metadata.document_id} has errors: {doc.errors}" ) for error in doc.errors: - errors.append({doc.metadata.document_id: list(error.values())[0]}) + errors.append( + {doc.metadata.document_id: next(iter(error.values()))} + ) else: num_chunks += ( len(doc.chunks) @@ -597,14 +678,8 @@ def _search( index_name: str, query: str, top_k: int = 10, - filters: Optional[List[SearchFilter]] = None, + filters: Optional[list[SearchFilter]] = None, ): - """ - Search your Compass index - :param index_name: the name of the index - :param query: query to search for - :param top_k: number of documents to return - """ return self._send_request( api_name=api_name, index_name=index_name, @@ -619,8 +694,18 @@ def search_documents( index_name: str, query: str, top_k: int = 10, - filters: Optional[List[SearchFilter]] = None, + filters: Optional[list[SearchFilter]] = None, ) -> SearchDocumentsResponse: + """ + Search documents in an index. + + :param index_name: the name of the index + :param query: the search query + :param top_k: the number of documents to return + :param filters: the search filters to apply + + :returns: the search results + """ result = self._search( api_name="search_documents", index_name=index_name, @@ -637,8 +722,18 @@ def search_chunks( index_name: str, query: str, top_k: int = 10, - filters: Optional[List[SearchFilter]] = None, + filters: Optional[list[SearchFilter]] = None, ) -> SearchChunksResponse: + """ + Search chunks in an index. + + :param index_name: the name of the index + :param query: the search query + :param top_k: the number of chunks to return + :param filters: the search filters to apply + + :returns: the search results + """ result = self._search( api_name="search_chunks", index_name=index_name, @@ -653,7 +748,8 @@ def edit_group_authorization( self, *, index_name: str, group_auth_input: GroupAuthorizationInput ): """ - Edit group authorization for an index + Edit group authorization for an index. + :param index_name: the name of the index :param group_auth_input: the group authorization input """ @@ -670,17 +766,18 @@ def _send_request( api_name: str, max_retries: int, sleep_retry_seconds: int, - data: Optional[Union[Dict[str, Any], BaseModel]] = None, + data: Optional[Union[dict[str, Any], BaseModel]] = None, **url_params: str, ) -> RetryResult: """ - Send a request to the Compass API + Send a request to the Compass API. + :param function: the function to call :param index_name: the name of the index :param max_retries: the number of times to retry the request :param sleep_retry_seconds: the number of seconds to sleep between retries :param data: the data to send - :return: An error message if the request failed, otherwise None + :returns: An error message if the request failed, otherwise None. """ @retry( @@ -733,7 +830,8 @@ def _send_request_with_retry(): else: error = str(e) + " " + e.response.text logger.error( - f"Failed to send request to {api_name} {target_path}: {type(e)} {error}. Going to sleep for " + f"Failed to send request to {api_name} {target_path}: " + f"{type(e)} {error}. Going to sleep for " f"{sleep_retry_seconds} seconds and retrying." ) raise e @@ -741,8 +839,8 @@ def _send_request_with_retry(): except Exception as e: error = str(e) logger.error( - f"Failed to send request to {api_name} {target_path}: {type(e)} {error}. Going to sleep for " - f"{sleep_retry_seconds} seconds and retrying." + f"Failed to send request to {api_name} {target_path}: {type(e)} " + f"{error}. Sleeping for {sleep_retry_seconds} before retrying..." ) raise e diff --git a/cohere/compass/clients/parser.py b/cohere/compass/clients/parser.py index 0ddd4c3..9fa3bc9 100644 --- a/cohere/compass/clients/parser.py +++ b/cohere/compass/clients/parser.py @@ -1,9 +1,10 @@ # Python imports -from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, Iterable, List, Optional, Union import json import logging import os +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Optional, Union # 3rd party imports import requests @@ -20,7 +21,7 @@ ) from cohere.compass.utils import imap_queued, open_document, scan_folder -Fn_or_Dict = Union[Dict[str, Any], Callable[[CompassDocument], Dict[str, Any]]] +Fn_or_Dict = Union[dict[str, Any], Callable[[CompassDocument], dict[str, Any]]] logger = logging.getLogger(__name__) @@ -28,20 +29,20 @@ class CompassParserClient: """ - Client to interact with the CompassParser API. It allows to process files using the - parser and metadata configurations specified in the parameters. The client is - stateful, that is, it can be initialized with parser and metadata configurations - that will be used for all subsequent files processed by the client. Also, - independently of the default configurations, the client allows to pass specific - configurations for each file when calling the process_file or process_files methods. - The client is responsible for opening the files and sending them to the - CompassParser API for processing. The resulting documents are returned as - CompassDocument objects. + Client to interact with the CompassParser API. + + It allows to process files using the parser and metadata configurations specified in + the parameters. The client is stateful, that is, it can be initialized with parser + and metadata configurations that will be used for all subsequent files processed by + the client. Also, independently of the default configurations, the client allows to + pass specific configurations for each file when calling the process_file or + process_files methods. The client is responsible for opening the files and sending + them to the CompassParser API for processing. The resulting documents are returned + as CompassDocument objects. :param parser_url: URL of the CompassParser API :param parser_config: Default parser configuration to use when processing files :param metadata_config: Default metadata configuration to use when processing files - """ def __init__( @@ -55,12 +56,13 @@ def __init__( num_workers: int = 4, ): """ - Initializes the CompassParserClient with the specified parser_url, - parser_config, and metadata_config. The parser_config and metadata_config are - optional, and if not provided, the default configurations will be used. If the - parser/metadata configs are provided, they will be used for all subsequent files - processed by the client unless specific configs are passed when calling the - process_file or process_files methods. + Initialize the CompassParserClient. + + The parser_config and metadata_config are optional, and if not provided, the + default configurations will be used. If the parser/metadata configs are + provided, they will be used for all subsequent files processed by the client + unless specific configs are passed when calling the process_file or + process_files methods. :param parser_url: the URL of the CompassParser API :param parser_config: the parser configuration to use when processing files if @@ -89,18 +91,19 @@ def process_folder( self, *, folder_path: str, - allowed_extensions: Optional[List[str]] = None, + allowed_extensions: Optional[list[str]] = None, recursive: bool = False, parser_config: Optional[ParserConfig] = None, metadata_config: Optional[MetadataConfig] = None, custom_context: Optional[Fn_or_Dict] = None, ): """ - Processes all the files in the specified folder using the default parser and - metadata configurations passed when creating the client. The method iterates - over all the files in the folder and processes them using the process_file - method. The resulting documents are returned as a list of CompassDocument - objects. + Process all the files in the specified folder. + + The files are processed using the default parser and metadata configurations + passed when creating the client. The method iterates over all the files in the + folder and processes them using the process_file method. The resulting documents + are returned as a list of CompassDocument objects. :param folder_path: the folder to process :param allowed_extensions: the list of allowed extensions to process @@ -115,7 +118,7 @@ def process_folder( be filterable but not semantically searchable. Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. - :return: the list of processed documents + :returns: the list of processed documents """ filenames = scan_folder( folder_path=folder_path, @@ -132,15 +135,14 @@ def process_folder( def process_files( self, *, - filenames: List[str], - file_ids: Optional[List[str]] = None, + filenames: list[str], + file_ids: Optional[list[str]] = None, parser_config: Optional[ParserConfig] = None, metadata_config: Optional[MetadataConfig] = None, custom_context: Optional[Fn_or_Dict] = None, ) -> Iterable[CompassDocument]: """ - Processes a list of files provided as filenames, using the specified parser and - metadata configurations. + Process a list of files. If the parser/metadata configs are not provided, then the default configs passed by parameter when creating the client will be used. This makes the @@ -151,7 +153,7 @@ def process_files( All the documents passed as filenames and opened to obtain their bytes. Then, they are packed into a ProcessFilesParameters object that contains a list of ProcessFileParameters, each contain a file, its id, and the parser/metadata - config + config. :param filenames: List of filenames to process :param file_ids: List of ids for the files @@ -162,10 +164,10 @@ def process_files( be filterable but not semantically searchable. Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. - :return: List of processed documents + :returns: List of processed documents """ - def process_file(i: int) -> List[CompassDocument]: + def process_file(i: int) -> list[CompassDocument]: return self.process_file( filename=filenames[i], file_id=file_ids[i] if file_ids else None, @@ -185,7 +187,7 @@ def process_file(i: int) -> List[CompassDocument]: @staticmethod def _get_metadata( doc: CompassDocument, custom_context: Optional[Fn_or_Dict] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if custom_context is None: return {} elif callable(custom_context): @@ -202,35 +204,39 @@ def process_file( parser_config: Optional[ParserConfig] = None, metadata_config: Optional[MetadataConfig] = None, custom_context: Optional[Fn_or_Dict] = None, - ) -> List[CompassDocument]: + ) -> list[CompassDocument]: """ - Takes in a file, its id, and the parser/metadata config. If the config is None, - then it uses the default configs passed by parameter when creating the client. - This makes the CompassParserClient stateful for convenience, that is, one can - pass in the parser/metadata config only once when creating the - CompassParserClient, and process files without having to pass the config every - time - - :param filename: Filename to process - :param file_id: Id for the file - :param content_type: Content type of the file - :param parser_config: ParserConfig object with the config to use for parsing the file + Process a file. + + The method takes in a file, its id, and the parser/metadata config. If the + config is None, then it uses the default configs passed by parameter when + creating the client. This makes the CompassParserClient stateful for + convenience, that is, one can pass in the parser/metadata config only once when + creating the CompassParserClient, and process files without having to pass the + config every time. + + :param filename: Filename to process. + :param file_id: Id for the file. + :param content_type: Content type of the file. + :param parser_config: ParserConfig object with the config to use for parsing the + file. :param metadata_config: MetadataConfig object with the config to use for - extracting metadata for each document + extracting metadata for each document. :param custom_context: Additional data to add to compass document. Fields will be filterable but not semantically searchable. Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. - :return: List of resulting documents + :returns: List of resulting documents """ doc = open_document(filename) if doc.errors: logger.error(f"Error opening document: {doc.errors}") return [] if len(doc.filebytes) > DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES: + max_size_mb = DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000 logger.error( - f"File too large, supported file size is {DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES / 1000_000} " - f"mb, filename {doc.metadata.filename}" + f"File too large, supported file size is {max_size_mb} mb, " + f"filename {doc.metadata.filename}" ) return [] @@ -270,11 +276,7 @@ def process_file( return docs @staticmethod - def _adapt_doc_id_compass_doc(doc: Dict[Any, Any]) -> CompassDocument: - """ - Adapt the doc_id to document_id - """ - + def _adapt_doc_id_compass_doc(doc: dict[Any, Any]) -> CompassDocument: metadata = doc["metadata"] if "document_id" not in metadata: metadata["document_id"] = metadata.pop("doc_id") diff --git a/cohere/compass/clients/rbac.py b/cohere/compass/clients/rbac.py index 1d7cfad..f09905b 100644 --- a/cohere/compass/clients/rbac.py +++ b/cohere/compass/clients/rbac.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Type, TypeVar +from typing import TypeVar import requests from pydantic import BaseModel @@ -27,7 +27,15 @@ class CompassRootClient: + """Client for interacting with Compass RBAC API as a root user.""" + def __init__(self, compass_url: str, root_user_token: str): + """ + Initialize a new CompassRootClient. + + :param compass_url: URL of the Compass instance. + :param root_user_token: Root user token for Compass instance. + """ self.base_url = compass_url + "/api/security/admin/rbac" self.headers = { "Authorization": f"Bearer {root_user_token}", @@ -36,18 +44,18 @@ def __init__(self, compass_url: str, root_user_token: str): T = TypeVar("T", bound=BaseModel) U = TypeVar("U", bound=BaseModel) - Headers = Dict[str, str] + Headers = dict[str, str] @staticmethod - def _fetch_entities(url: str, headers: Headers, entity_type: Type[T]) -> List[T]: + def _fetch_entities(url: str, headers: Headers, entity_type: type[T]) -> list[T]: response = requests.get(url, headers=headers) CompassRootClient.raise_for_status(response) return [entity_type.model_validate(entity) for entity in response.json()] @staticmethod def _create_entities( - url: str, headers: Headers, entity_request: List[T], entity_response: Type[U] - ) -> List[U]: + url: str, headers: Headers, entity_request: list[T], entity_response: type[U] + ) -> list[U]: response = requests.post( url, json=[json.loads(entity.model_dump_json()) for entity in entity_request], @@ -60,36 +68,63 @@ def _create_entities( @staticmethod def _delete_entities( - url: str, headers: Headers, names: List[str], entity_response: Type[U] - ) -> List[U]: + url: str, headers: Headers, names: list[str], entity_response: type[U] + ) -> list[U]: entities = ",".join(names) response = requests.delete(f"{url}/{entities}", headers=headers) CompassRootClient.raise_for_status(response) return [entity_response.model_validate(entity) for entity in response.json()] - def fetch_users(self) -> List[UserFetchResponse]: + def fetch_users(self) -> list[UserFetchResponse]: + """ + Fetch all users from Compass. + + :returns: A list containing the users. + """ return self._fetch_entities( f"{self.base_url}/v1/users", self.headers, UserFetchResponse ) - def fetch_groups(self) -> List[GroupFetchResponse]: + def fetch_groups(self) -> list[GroupFetchResponse]: + """ + Fetch all groups from Compass. + + :returns: A list containing the groups. + """ return self._fetch_entities( f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse ) - def fetch_roles(self) -> List[RoleFetchResponse]: + def fetch_roles(self) -> list[RoleFetchResponse]: + """ + Fetch all roles from Compass. + + :returns: A list containing the roles. + """ return self._fetch_entities( f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse ) - def fetch_role_mappings(self) -> List[RoleMappingResponse]: + def fetch_role_mappings(self) -> list[RoleMappingResponse]: + """ + Fetch all role mappings from Compass. + + :returns: A list containing the role mappings. + """ return self._fetch_entities( f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse ) def create_users( - self, *, users: List[UserCreateRequest] - ) -> List[UserCreateResponse]: + self, *, users: list[UserCreateRequest] + ) -> list[UserCreateResponse]: + """ + Create new users in Compass. + + :param users: List of users to be created. + + :returns: A list containing the created users. + """ return self._create_entities( url=f"{self.base_url}/v1/users", headers=self.headers, @@ -98,8 +133,15 @@ def create_users( ) def create_groups( - self, *, groups: List[GroupCreateRequest] - ) -> List[GroupCreateResponse]: + self, *, groups: list[GroupCreateRequest] + ) -> list[GroupCreateResponse]: + """ + Create new groups in Compass. + + :param groups: List of groups to be created. + + :returns: A list containing the created groups. + """ return self._create_entities( url=f"{self.base_url}/v1/groups", headers=self.headers, @@ -108,8 +150,15 @@ def create_groups( ) def create_roles( - self, *, roles: List[RoleCreateRequest] - ) -> List[RoleCreateResponse]: + self, *, roles: list[RoleCreateRequest] + ) -> list[RoleCreateResponse]: + """ + Create new roles in Compass. + + :param roles: List of roles to be created. + + :returns: A list containing the created roles. + """ return self._create_entities( url=f"{self.base_url}/v1/roles", headers=self.headers, @@ -118,8 +167,15 @@ def create_roles( ) def create_role_mappings( - self, *, role_mappings: List[RoleMappingRequest] - ) -> List[RoleMappingResponse]: + self, *, role_mappings: list[RoleMappingRequest] + ) -> list[RoleMappingResponse]: + """ + Create new role mappings in Compass. + + :param role_mappings: List of role mappings to be created. + + :returns: A list containing the created role mappings. + """ return self._create_entities( url=f"{self.base_url}/v1/role-mappings", headers=self.headers, @@ -127,24 +183,53 @@ def create_role_mappings( entity_response=RoleMappingResponse, ) - def delete_users(self, *, user_names: List[str]) -> List[UserDeleteResponse]: + def delete_users(self, *, user_names: list[str]) -> list[UserDeleteResponse]: + """ + Delete users from Compass. + + :param user_names: List of user names to be deleted. + + :returns: A list containing the deleted users. + """ return self._delete_entities( f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse ) - def delete_groups(self, *, group_names: List[str]) -> List[GroupDeleteResponse]: + def delete_groups(self, *, group_names: list[str]) -> list[GroupDeleteResponse]: + """ + Delete groups from Compass. + + :param group_names: List of group names to be deleted. + + :returns: A list containing the deleted groups. + """ return self._delete_entities( f"{self.base_url}/v1/groups", self.headers, group_names, GroupDeleteResponse ) - def delete_roles(self, *, role_ids: List[str]) -> List[RoleDeleteResponse]: + def delete_roles(self, *, role_ids: list[str]) -> list[RoleDeleteResponse]: + """ + Delete roles from Compass. + + :param role_ids: List of role IDs to be deleted. + + :returns: A list containing the deleted roles. + """ return self._delete_entities( f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse ) def delete_role_mappings( self, *, role_name: str, group_name: str - ) -> List[RoleMappingDeleteResponse]: + ) -> list[RoleMappingDeleteResponse]: + """ + Delete role mappings from Compass. + + :param role_name: Name of the role. + :param group_name: Name of the group. + + :returns: A list containing the deleted role mappings. + """ response = requests.delete( f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", headers=self.headers, @@ -158,6 +243,14 @@ def delete_role_mappings( def delete_user_group( self, *, group_name: str, user_name: str ) -> GroupUserDeleteResponse: + """ + Remove a user from a group. + + :param group_name: Name of the group. + :param user_name: Name of the user. + + :returns: Response containing the group name and user name. + """ response = requests.delete( f"{self.base_url}/v1/group/{group_name}/user/{user_name}", headers=self.headers, @@ -166,8 +259,16 @@ def delete_user_group( return GroupUserDeleteResponse.model_validate(response.json()) def update_role( - self, *, role_name: str, policies: List[PolicyRequest] + self, *, role_name: str, policies: list[PolicyRequest] ) -> RoleCreateResponse: + """ + Update the policies of a role. + + :param role_name: Name of the role. + :param policies: List of policies to be updated. + + :returns: Response containing the updated role and its new policies. + """ response = requests.put( f"{self.base_url}/v1/roles/{role_name}", json=[json.loads(policy.model_dump_json()) for policy in policies], @@ -178,8 +279,13 @@ def update_role( @staticmethod def raise_for_status(response: requests.Response): - """Raises :class:`HTTPError`, if one occurred.""" + """ + Raise an exception if the response status code is not in the 200 range. + + :param response: Response object from the request. + :raises HTTPError: If the response status code is not in the 200 range. + """ http_error_msg = "" if isinstance(response.reason, bytes): # We attempt to decode utf-8 first because some servers diff --git a/cohere/compass/constants.py b/cohere/compass/constants.py index 9de2cf7..fe1a106 100644 --- a/cohere/compass/constants.py +++ b/cohere/compass/constants.py @@ -26,7 +26,7 @@ Do not write the ```json (...) ``` tag. The output should be a valid JSON. If you cannot find the information, write "" for the corresponding field. Answer: - """ + """ # noqa: E501, W291 METADATA_HEURISTICS_ATTRIBUTES = [ "title", "name", diff --git a/cohere/compass/exceptions.py b/cohere/compass/exceptions.py index 351a73f..c074f1b 100644 --- a/cohere/compass/exceptions.py +++ b/cohere/compass/exceptions.py @@ -1,7 +1,11 @@ class CompassClientError(Exception): """Exception raised for all 4xx client errors in the Compass client.""" - def __init__(self, message: str = "Client error occurred.", code: int = 400): + def __init__( # noqa: D107 + self, + message: str = "Client error occurred.", + code: int = 400, + ): self.message = message self.code = code super().__init__(self.message) @@ -10,7 +14,7 @@ def __init__(self, message: str = "Client error occurred.", code: int = 400): class CompassAuthError(CompassClientError): """Exception raised for authentication errors in the Compass client.""" - def __init__( + def __init__( # noqa: D107 self, message: str = ( "CompassAuthError - check your bearer token or username and password." @@ -21,12 +25,18 @@ def __init__( class CompassMaxErrorRateExceeded(Exception): - """Exception raised when the error rate exceeds the maximum allowed error rate in - the Compass client.""" + """ + Exception raised if the error rate during document insertion exceeds the max. + + When the user calls the insert_docs() method, an optional max_error_rate parameter + can be passed to the method. If the error rate during the insertion process exceeds + this max_error_rate, the insertion process will be stopped and this exception will + be raised. + """ - def __init__( + def __init__( # noqa: D107 self, - message: str = "The maximum error rate was exceeded. Stopping the insertion process.", + message: str = "The maximum error rate was exceeded. Stopping the insertion.", ): self.message = message super().__init__(self.message) diff --git a/cohere/compass/models/__init__.py b/cohere/compass/models/__init__.py index 9ec4f09..453e574 100644 --- a/cohere/compass/models/__init__.py +++ b/cohere/compass/models/__init__.py @@ -5,15 +5,25 @@ class ValidatedModel(BaseModel): - class Config: + """A subclass of BaseModel providing additional validation during initialization.""" + + class Config: # noqa: D106 arbitrary_types_allowed = True use_enum_values = True @classmethod def attribute_in_model(cls, attr_name: str): + """Check if a given attribute name is present in the model fields.""" return attr_name in cls.model_fields def __init__(self, **data: dict[str, Any]): + """ + Initialize the model with the given data. + + :param data: A dictionary of attribute names and their values. + + :raises ValueError: If an attribute name in the data is not valid for the model. + """ for name, _value in data.items(): if not self.attribute_in_model(name): raise ValueError( diff --git a/cohere/compass/models/config.py b/cohere/compass/models/config.py index 4dbd427..747d593 100644 --- a/cohere/compass/models/config.py +++ b/cohere/compass/models/config.py @@ -1,8 +1,8 @@ # Python imports +import math from enum import Enum from os import getenv -from typing import Any, List, Optional -import math +from typing import Any, Optional # 3rd party imports from pydantic import BaseModel, ConfigDict @@ -24,6 +24,8 @@ class DocumentFormat(str, Enum): + """Enum for specifying the output format of the parsed document.""" + Markdown = "markdown" Text = "text" @@ -33,6 +35,8 @@ def _missing_(cls, value: Any): class PDFParsingStrategy(str, Enum): + """Enum for specifying the parsing strategy for PDF files.""" + QuickText = "QuickText" ImageToMarkdown = "ImageToMarkdown" @@ -42,6 +46,8 @@ def _missing_(cls, value: Any): class PresentationParsingStrategy(str, Enum): + """Enum for specifying the parsing strategy for presentation files.""" + Unstructured = "Unstructured" ImageToMarkdown = "ImageToMarkdown" @@ -51,6 +57,8 @@ def _missing_(cls, value: Any): class ParsingStrategy(str, Enum): + """Enum for specifying the parsing strategy to use.""" + Fast = "fast" Hi_Res = "hi_res" @@ -60,10 +68,13 @@ def _missing_(cls, value: Any): class ParsingModel(str, Enum): - Marker = "marker" # Default model, it is actually a combination of models used by the Marker PDF parser - YoloX_Quantized = ( - "yolox_quantized" # Only PDF parsing working option from Unstructured - ) + """Enum for specifying the parsing model to use.""" + + # Default model, which is actually a combination of models used by the "Marker" PDF + # parser + Marker = "marker" + # Only PDF parsing working option from Unstructured + YoloX_Quantized = "yolox_quantized" @classmethod def _missing_(cls, value: Any): @@ -72,22 +83,26 @@ def _missing_(cls, value: Any): class ParserConfig(BaseModel): """ - CompassParser configuration. Important parameters: + A model class for specifying parsing configuration. + + Important parameters: + :param parsing_strategy: the parsing strategy to use: - 'auto' (default): automatically determine the best strategy - 'fast': leverage traditional NLP extraction techniques to quickly pull all the - text elements. “Fast” strategy is not good for image based file types. - - 'hi_res': identifies the layout of the document using detectron2. The advantage of “hi_res” - is that it uses the document layout to gain additional information about document elements. - We recommend using this strategy if your use case is highly sensitive to correct - classifications for document elements. - - 'ocr_only': leverage Optical Character Recognition to extract text from the image based files. + text elements. “Fast” strategy is not good for image based file types. + - 'hi_res': identifies the layout of the document using detectron2. The + advantage of “hi_res” is that it uses the document layout to gain additional + information about document elements. We recommend using this strategy if your + use case is highly sensitive to correct classifications for document elements. + - 'ocr_only': leverage Optical Character Recognition to extract text from the + image based files. :param parsing_model: the parsing model to use. One of: - - yolox_quantized (default): single-stage object detection model, quantized. Runs faster than YoloX - See https://unstructured-io.github.io/unstructured/best_practices/models.html for more details. - We have temporarily removed the option to use other models because - of ongoing stability issues. - + - yolox_quantized (default): single-stage object detection model, quantized. + Runs faster than YoloX. See + https://unstructured-io.github.io/unstructured/best_practices/models.html for + more details. We have temporarily removed the option to use other models + because of ongoing stability issues. """ model_config = ConfigDict( @@ -99,9 +114,9 @@ class ParserConfig(BaseModel): parse_tables: bool = True parse_images: bool = True parsed_images_output_dir: Optional[str] = None - allowed_image_types: Optional[List[str]] = None + allowed_image_types: Optional[list[str]] = None min_chars_per_element: int = DEFAULT_MIN_CHARS_PER_ELEMENT - skip_infer_table_types: List[str] = SKIP_INFER_TABLE_TYPES + skip_infer_table_types: list[str] = SKIP_INFER_TABLE_TYPES parsing_strategy: ParsingStrategy = ParsingStrategy.Fast parsing_model: ParsingModel = ParsingModel.YoloX_Quantized @@ -128,6 +143,8 @@ class ParserConfig(BaseModel): class MetadataStrategy(str, Enum): + """Enum for specifying the strategy for metadata detection.""" + No_Metadata = "no_metadata" Naive_Title = "naive_title" KeywordSearch = "keyword_search" @@ -142,20 +159,27 @@ def _missing_(cls, value: Any): class MetadataConfig(ValidatedModel): """ - Configuration class for metadata detection. + A model class for specifying configuration related to document metadata detection. + :param metadata_strategy: the metadata detection strategy to use. One of: - No_Metadata: no metadata is inferred - Heuristics: metadata is inferred using heuristics - Bart: metadata is inferred using the BART summarization model - Command_R: metadata is inferred using the Command-R summarization model :param cohere_api_key: the Cohere API key to use for metadata detection - :param commandr_model_name: the name of the Command-R model to use for metadata detection + :param commandr_model_name: the name of the Command-R model to use for metadata + detection :param commandr_prompt: the prompt to use for the Command-R model - :param commandr_extractable_attributes: the extractable attributes for the Command-R model - :param commandr_max_tokens: the maximum number of tokens to use for the Command-R model - :param keyword_search_attributes: the attributes to search for in the document when using keyword search - :param keyword_search_separator: the separator to use for nested attributes when using keyword search - :param ignore_errors: if set to True, metadata detection errors will not be raised or stop the parsing process + :param commandr_extractable_attributes: the extractable attributes for the Command-R + model + :param commandr_max_tokens: the maximum number of tokens to use for the Command-R + model + :param keyword_search_attributes: the attributes to search for in the document when + using keyword search + :param keyword_search_separator: the separator to use for nested attributes when + using keyword search + :param ignore_errors: if set to True, metadata detection errors will not be raised + or stop the parsing process """ @@ -164,7 +188,7 @@ class MetadataConfig(ValidatedModel): commandr_model_name: str = "command-r" commandr_prompt: str = DEFAULT_COMMANDR_PROMPT commandr_max_tokens: int = 500 - commandr_extractable_attributes: List[str] = DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES - keyword_search_attributes: List[str] = METADATA_HEURISTICS_ATTRIBUTES + commandr_extractable_attributes: list[str] = DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES + keyword_search_attributes: list[str] = METADATA_HEURISTICS_ATTRIBUTES keyword_search_separator: str = "." ignore_errors: bool = True diff --git a/cohere/compass/models/datasources.py b/cohere/compass/models/datasources.py index 0b91693..72c64f6 100644 --- a/cohere/compass/models/datasources.py +++ b/cohere/compass/models/datasources.py @@ -11,16 +11,22 @@ class PaginatedList(pydantic.BaseModel, typing.Generic[T]): - value: typing.List[T] + """Model class for a paginated list of items.""" + + value: list[T] skip: typing.Optional[int] limit: typing.Optional[int] class OneDriveConfig(pydantic.BaseModel): + """Model class for OneDrive configuration.""" + type: typing.Literal["msft_onedrive"] class AzureBlobStorageConfig(pydantic.BaseModel): + """Model class for Azure Blob Storage configuration.""" + type: typing.Literal["msft_azure_blob_storage"] connection_string: str container_name: str @@ -34,25 +40,31 @@ class AzureBlobStorageConfig(pydantic.BaseModel): class DataSource(pydantic.BaseModel): + """Model class for a data source.""" + id: typing.Optional[pydantic.UUID4] = None name: str description: typing.Optional[str] = None config: DatasourceConfig - destinations: typing.List[str] + destinations: list[str] enabled: bool = True created_at: typing.Optional[datetime.datetime] = None updated_at: typing.Optional[datetime.datetime] = None class CreateDataSource(pydantic.BaseModel): + """Model class for the create_datasource API.""" + datasource: DataSource state_key: typing.Optional[str] = None class DocumentStatus(pydantic.BaseModel): + """Model class for the response of the list_datasources_objects_states API.""" + document_id: str source_id: typing.Optional[str] state: str - destinations: typing.List[str] + destinations: list[str] created_at: datetime.datetime updated_at: typing.Optional[datetime.datetime] = None diff --git a/cohere/compass/models/documents.py b/cohere/compass/models/documents.py index ecccc73..e9e7ace 100644 --- a/cohere/compass/models/documents.py +++ b/cohere/compass/models/documents.py @@ -1,8 +1,8 @@ # Python imports +import uuid from dataclasses import field from enum import Enum -from typing import Annotated, Any, Dict, List, Optional -import uuid +from typing import Annotated, Any, Optional # 3rd party imports from pydantic import BaseModel, Field, PositiveInt, StringConstraints @@ -12,9 +12,7 @@ class CompassDocumentMetadata(ValidatedModel): - """ - Compass document metadata - """ + """Compass document metadata.""" document_id: str = "" filename: str = "" @@ -23,7 +21,9 @@ class CompassDocumentMetadata(ValidatedModel): class AssetType(str, Enum): - def __str__(self) -> str: + """Enum specifying the different types of assets.""" + + def __str__(self) -> str: # noqa: D105 return self.value # A page that has been rendered as an image @@ -35,29 +35,37 @@ def __str__(self) -> str: class CompassDocumentChunkAsset(BaseModel): + """An asset associated with a Compass document chunk.""" + asset_type: AssetType content_type: str asset_data: str class CompassDocumentChunk(BaseModel): + """A chunk of a Compass document.""" + chunk_id: str sort_id: str document_id: str parent_document_id: str - content: Dict[str, Any] - origin: Optional[Dict[str, Any]] = None + content: dict[str, Any] + origin: Optional[dict[str, Any]] = None assets: Optional[list[CompassDocumentChunkAsset]] = None path: Optional[str] = "" def parent_doc_is_split(self): + """ + Check if the parent document is split. + + :returns: True if the document ID is different from the parent document ID, + indicating that the parent document is split; False otherwise. + """ return self.document_id != self.parent_document_id class CompassDocumentStatus(str, Enum): - """ - Compass document status - """ + """Compass document status.""" Success = "success" ParsingErrors = "parsing-errors" @@ -66,9 +74,7 @@ class CompassDocumentStatus(str, Enum): class CompassSdkStage(str, Enum): - """ - Compass SDK stages - """ + """Compass SDK stages.""" Parsing = "parsing" Metadata = "metadata" @@ -78,40 +84,47 @@ class CompassSdkStage(str, Enum): class CompassDocument(ValidatedModel): """ - A Compass document contains all the information required to process a document and - insert it into the index. It includes: + A model class for a Compass document. + + The model contains all the information required to process a document and insert it + into the index. It includes: + - metadata: the document metadata (e.g., filename, title, authors, date) - content: the document content in string format - - elements: the document's Unstructured elements (e.g., tables, images, text). Used - for chunking - - chunks: the document's chunks (e.g., paragraphs, tables, images). Used for indexing + - elements: the document's Unstructured elements (e.g., tables, images, text). + - chunks: the document's chunks (e.g., paragraphs, tables, images). - index_fields: the fields to be indexed. Used by the indexer """ filebytes: bytes = b"" metadata: CompassDocumentMetadata = CompassDocumentMetadata() - content: Dict[str, str] = field(default_factory=dict) + content: dict[str, str] = field(default_factory=dict) content_type: Optional[str] = None - elements: List[Any] = field(default_factory=list) - chunks: List[CompassDocumentChunk] = field(default_factory=list) - index_fields: List[str] = field(default_factory=list) - errors: List[Dict[CompassSdkStage, str]] = field(default_factory=list) + elements: list[Any] = field(default_factory=list) + chunks: list[CompassDocumentChunk] = field(default_factory=list) + index_fields: list[str] = field(default_factory=list) + errors: list[dict[CompassSdkStage, str]] = field(default_factory=list) ignore_metadata_errors: bool = True markdown: Optional[str] = None def has_data(self) -> bool: + """Check if the document has any data.""" return len(self.filebytes) > 0 def has_markdown(self) -> bool: + """Check if the document has a markdown representation.""" return self.markdown is not None def has_filename(self) -> bool: + """Check if the document has a filename.""" return len(self.metadata.filename) > 0 def has_metadata(self) -> bool: + """Check if the document has metadata.""" return len(self.metadata.meta) > 0 def has_parsing_errors(self) -> bool: + """Check if the document has parsing errors.""" return any( stage == CompassSdkStage.Parsing for error in self.errors @@ -119,6 +132,7 @@ def has_parsing_errors(self) -> bool: ) def has_metadata_errors(self) -> bool: + """Check if the document has metadata errors.""" return any( stage == CompassSdkStage.Metadata for error in self.errors @@ -126,6 +140,7 @@ def has_metadata_errors(self) -> bool: ) def has_indexing_errors(self) -> bool: + """Check if the document has indexing errors.""" return any( stage == CompassSdkStage.Indexing for error in self.errors @@ -134,6 +149,7 @@ def has_indexing_errors(self) -> bool: @property def status(self) -> CompassDocumentStatus: + """Get the document status.""" if self.has_parsing_errors(): return CompassDocumentStatus.ParsingErrors @@ -147,40 +163,40 @@ def status(self) -> CompassDocumentStatus: class DocumentChunkAsset(BaseModel): + """Model class for an asset associated with a document chunk.""" + asset_type: AssetType content_type: str asset_data: str class Chunk(BaseModel): + """Model class for a chunk of a document.""" + chunk_id: str sort_id: int parent_document_id: str path: str = "" - content: Dict[str, Any] - origin: Optional[Dict[str, Any]] = None - assets: Optional[List[DocumentChunkAsset]] = None - asset_ids: Optional[List[str]] = None + content: dict[str, Any] + origin: Optional[dict[str, Any]] = None + assets: Optional[list[DocumentChunkAsset]] = None + asset_ids: Optional[list[str]] = None class Document(BaseModel): - """ - A document that can be indexed in Compass (i.e., a list of indexable chunks) - """ + """Model class for a document.""" document_id: str path: str parent_document_id: str - content: Dict[str, Any] - chunks: List[Chunk] - index_fields: Optional[List[str]] = None - authorized_groups: Optional[List[str]] = None + content: dict[str, Any] + chunks: list[Chunk] + index_fields: Optional[list[str]] = None + authorized_groups: Optional[list[str]] = None class ParseableDocument(BaseModel): - """ - A document to be sent to Compass in bytes format for parsing on the Compass side - """ + """A document to be sent to Compass for parsing.""" id: uuid.UUID filename: Annotated[ @@ -189,18 +205,18 @@ class ParseableDocument(BaseModel): content_type: str content_length_bytes: PositiveInt # File size must be a non-negative integer content_encoded_bytes: str # Base64-encoded file contents - attributes: Dict[str, Any] = Field(default_factory=dict) + attributes: dict[str, Any] = Field(default_factory=dict) + +class UploadDocumentsInput(BaseModel): + """A model for the input of a call to upload_documents API.""" -class PushDocumentsInput(BaseModel): - documents: List[ParseableDocument] + documents: list[ParseableDocument] class PutDocumentsInput(BaseModel): - """ - A Compass request to put a list of Document - """ + """A model for the input of a call to put_documents API.""" - documents: List[Document] - authorized_groups: Optional[List[str]] = None + documents: list[Document] + authorized_groups: Optional[list[str]] = None merge_groups_on_conflict: bool = False diff --git a/cohere/compass/models/rbac.py b/cohere/compass/models/rbac.py index 4375f06..eb08e43 100644 --- a/cohere/compass/models/rbac.py +++ b/cohere/compass/models/rbac.py @@ -1,96 +1,133 @@ # Python imports from enum import Enum -from typing import List # 3rd party imports from pydantic import BaseModel class UserFetchResponse(BaseModel): + """Response model for fetching user details.""" + name: str class UserCreateRequest(BaseModel): + """Request model for creating a new user.""" + name: str class UserCreateResponse(BaseModel): + """Response model for creating a new user.""" + name: str token: str class UserDeleteResponse(BaseModel): + """Response model for deleting a user.""" + name: str class GroupFetchResponse(BaseModel): + """Response model for fetching group details.""" + name: str user_name: str class GroupCreateRequest(BaseModel): + """Request model for creating a new group.""" + name: str - user_names: List[str] + user_names: list[str] class GroupCreateResponse(BaseModel): + """Response model for creating a new group.""" + name: str user_name: str class GroupDeleteResponse(BaseModel): + """Response model for deleting a group.""" + name: str class GroupUserDeleteResponse(BaseModel): + """Response model for removing a user from a group.""" + group_name: str user_name: str class Permission(Enum): + """Enumeration for user permissions.""" + READ = "read" WRITE = "write" class PolicyRequest(BaseModel): - indexes: List[str] + """Request model for creating a policy.""" + + indexes: list[str] permission: Permission class PolicyResponse(BaseModel): - indexes: List[str] + """Response model for retrieving a policy.""" + + indexes: list[str] permission: str class RoleFetchResponse(BaseModel): + """Response model for fetching role details.""" + name: str - policies: List[PolicyResponse] + policies: list[PolicyResponse] class RoleCreateRequest(BaseModel): + """Request model for creating a new role.""" + name: str - policies: List[PolicyRequest] + policies: list[PolicyRequest] class RoleCreateResponse(BaseModel): + """Response model for creating a new role.""" + name: str - policies: List[PolicyResponse] + policies: list[PolicyResponse] class RoleDeleteResponse(BaseModel): + """Response model for deleting a role.""" + name: str class RoleMappingRequest(BaseModel): + """Request model for mapping a role to a group.""" + role_name: str group_name: str class RoleMappingResponse(BaseModel): + """Response model for retrieving role-to-group mapping details.""" + role_name: str group_name: str class RoleMappingDeleteResponse(BaseModel): + """Response model for deleting a role-to-group mapping.""" + role_name: str group_name: str diff --git a/cohere/compass/models/search.py b/cohere/compass/models/search.py index 8a3af34..9e77730 100644 --- a/cohere/compass/models/search.py +++ b/cohere/compass/models/search.py @@ -1,55 +1,72 @@ # Python imports from enum import Enum -from typing import Any, Dict, List, Optional -from cohere.compass.models.documents import AssetType +from typing import Any, Optional # 3rd party imports from pydantic import BaseModel +from cohere.compass.models.documents import AssetType + class AssetInfo(BaseModel): + """Information about an asset.""" + asset_type: AssetType content_type: str presigned_url: str class RetrievedChunk(BaseModel): + """Chunk of a document retrieved from search.""" + chunk_id: str sort_id: int parent_document_id: str - content: Dict[str, Any] - origin: Optional[Dict[str, Any]] = None + content: dict[str, Any] + origin: Optional[dict[str, Any]] = None assets_info: Optional[list[AssetInfo]] = None score: float class RetrievedDocument(BaseModel): + """Document retrieved from search.""" + document_id: str path: str parent_document_id: str - content: Dict[str, Any] - index_fields: Optional[List[str]] = None - authorized_groups: Optional[List[str]] = None - chunks: List[RetrievedChunk] + content: dict[str, Any] + index_fields: Optional[list[str]] = None + authorized_groups: Optional[list[str]] = None + chunks: list[RetrievedChunk] score: float class RetrievedChunkExtended(RetrievedChunk): + """Additional information about a chunk retrieved from search.""" + document_id: str path: str - index_fields: Optional[List[str]] = None + index_fields: Optional[list[str]] = None class SearchDocumentsResponse(BaseModel): - hits: List[RetrievedDocument] + """Response object for search_documents API.""" + + hits: list[RetrievedDocument] class SearchChunksResponse(BaseModel): - hits: List[RetrievedChunkExtended] + """Response object for search_chunks API.""" + + hits: list[RetrievedChunkExtended] class SearchFilter(BaseModel): + """Filter to apply on search results.""" + class FilterType(str, Enum): + """Types of filters supported.""" + EQ = "$eq" LT_EQ = "$lte" GT_EQ = "$gte" @@ -61,10 +78,8 @@ class FilterType(str, Enum): class SearchInput(BaseModel): - """ - Search query input - """ + """Input to search APIs.""" query: str top_k: int - filters: Optional[List[SearchFilter]] = None + filters: Optional[list[SearchFilter]] = None diff --git a/cohere/compass/utils.py b/cohere/compass/utils.py index 9d8d9d0..371b258 100644 --- a/cohere/compass/utils.py +++ b/cohere/compass/utils.py @@ -1,14 +1,18 @@ +# Python imports import base64 import glob import os import uuid +from collections.abc import Iterable, Iterator from concurrent import futures from concurrent.futures import Executor -from typing import Callable, Iterable, Iterator, List, Optional, TypeVar +from typing import Callable, Optional, TypeVar +# 3rd party imports import fsspec # type: ignore from fsspec import AbstractFileSystem # type: ignore +# Local imports from cohere.compass.constants import UUID_NAMESPACE from cohere.compass.models import ( CompassDocument, @@ -23,6 +27,16 @@ def imap_queued( executor: Executor, f: Callable[[T], U], it: Iterable[T], max_queued: int ) -> Iterator[U]: + """ + Similar to Python's `map`, but uses an executor to parallelize the calls. + + :param executor: the executor to use. + :param f: the function to call. + :param it: the iterable to map over. + :param max_queued: the maximum number of futures to keep in flight. + + :returns: an iterator over the results. + """ assert max_queued >= 1 futures_set: set[futures.Future[U]] = set() @@ -41,9 +55,11 @@ def imap_queued( def get_fs(document_path: str) -> AbstractFileSystem: """ - Get the filesystem object for the given document path + Get an fsspec's filesystem object for the given document path. + :param document_path: the path to the document - :return: the filesystem object + + :returns: the filesystem object. """ if document_path.find("://") >= 0: file_system = document_path.split("://")[0] @@ -55,9 +71,11 @@ def get_fs(document_path: str) -> AbstractFileSystem: def open_document(document_path: str) -> CompassDocument: """ - Opens a document regardless of the file system (local, GCS, S3, etc.) and returns a file-like object + Open the document at the given path and return a CompassDocument object. + :param document_path: the path to the document - :return: a file-like object + + :returns: a CompassDocument object. """ doc = CompassDocument(metadata=CompassDocumentMetadata(filename=document_path)) try: @@ -75,15 +93,19 @@ def open_document(document_path: str) -> CompassDocument: def scan_folder( folder_path: str, - allowed_extensions: Optional[List[str]] = None, + allowed_extensions: Optional[list[str]] = None, recursive: bool = False, -) -> List[str]: +) -> list[str]: """ - Scans a folder for files with the given extensions - :param folder_path: the path to the folder - :param allowed_extensions: the allowed extensions - :param recursive: whether to scan the folder recursively or to only scan the top level - :return: a list of file paths + Scan a folder for files with the given extensions. + + :param folder_path: the path of the folder to scan. + :param allowed_extensions: the extensions to look for. If None, all files will be + considered. + :param recursive: whether to scan the folder recursively or to stick to the top + level. + + :returns: A list of file paths. """ fs = get_fs(folder_path) all_files: list[str] = [] @@ -107,6 +129,16 @@ def scan_folder( def generate_doc_id_from_bytes(filebytes: bytes) -> uuid.UUID: + """ + Generate a UUID based on the provided file bytes. + + This function encodes the given file bytes into a base64 string and then generates a + UUID using the uuid5 method with a predefined namespace. + + :param filebytes: The bytes of the file to generate the UUID from. + + :returns: The generated UUID based on the file bytes. + """ b64_string = base64.b64encode(filebytes).decode("utf-8") namespace = uuid.UUID(UUID_NAMESPACE) return uuid.uuid5(namespace, b64_string) diff --git a/pyproject.toml b/pyproject.toml index 7800788..26b5158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [tool.poetry] name = "compass-sdk" -version = "0.10.1" +version = "0.10.2" authors = [] description = "Compass SDK" readme = "README.md" -packages = [{include = "cohere"}] +packages = [{ include = "cohere" }] [tool.poetry.dependencies] fsspec = ">=2024.9.0" @@ -34,3 +34,26 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] line-length = 88 +target-version = "py39" + +[tool.ruff.lint] +extend-select = [ + "C90", # mccabe (for code complexity) + "D", # pydocstyle + "E", # pycodestyle errors + "W", # pycodestyle warnings + "I", # isort + "Q", # flakes8-quotes + "RUF", # Ruff-specific + "UP", # pyupgrade +] +ignore = [ + "D100", # ignore missing docstring in module + "D104", # ignore missing docstring in public package + "D212", # ignore multi-line docstring summary should start at the first line +] +isort = { known-first-party = ["cohere.compass"] } +mccabe = { max-complexity = 15 } + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["D"]