Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move modules under "cohere.compass" + More refactoring #56

Merged
merged 4 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
name: Formatting
name: Pre-commit Checks

on:
pull_request: {}
workflow_dispatch: {}

jobs:
build:
name: Formatting
name: Run pre-commit checks
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- 3.9
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"

steps:
- uses: actions/checkout@v4
Expand All @@ -21,10 +25,18 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Upgrade pip & install requirements
- name: Install poetry
run: |
pip install poetry

- name: Install dependencies
run: |
poetry install

- name: Install pre-commit
run: |
pip install pre-commit

- name: Formatting
- name: Run pre-commit
run: |
pre-commit run --all-files
20 changes: 11 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ jobs:
strategy:
matrix:
python-version:
- 3.11
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"

steps:
- uses: actions/checkout@v4
Expand All @@ -47,17 +51,15 @@ jobs:
cache-dependency-path: |
poetry.lock

- name: Install dependencies (tests)
- name: Install poetry
run: |
pip install pytest pytest-asyncio pytest-mock requests-mock
pip install poetry

- name: Install dependencies
working-directory: .
- name: Install dependencies
run: |
pip install -e .
poetry install

- name: Run tests
- name: Run tests
working-directory: .
run: |
echo $COHERE_API_KEY
pytest -sv tests/test_compass_client.py
poetry run pytest -sv
35 changes: 0 additions & 35 deletions .github/workflows/typecheck.yml

This file was deleted.

4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ repos:
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.390
hooks:
- id: pyright
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Cohere Compass SDK

[![Checked with pyright](https://microsoft.github.io/pyright/img/pyright_badge.svg)](https://microsoft.github.io/pyright/)

The Compass SDK is a Python library that allows you to parse documents and insert them
into a Compass index.

Expand Down
4 changes: 2 additions & 2 deletions compass_sdk/__init__.py → cohere/compass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from pydantic import BaseModel

# Local imports
from compass_sdk.models import (
from cohere.compass.models import (
MetadataConfig,
ParserConfig,
ValidatedModel,
)

__version__ = "0.7.0"
__version__ = "0.8.0"


class ProcessFileParameters(ValidatedModel):
Expand Down
3 changes: 3 additions & 0 deletions cohere/compass/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from cohere.compass.clients.compass import * # noqa: F403
from cohere.compass.clients.parser import * # noqa: F403
from cohere.compass.clients.rbac import * # noqa: F403
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import uuid

# 3rd party imports
from joblib import Parallel, delayed
# TODO find stubs for joblib and remove "type: ignore"
from joblib import Parallel, delayed # type: ignore
from pydantic import BaseModel
from requests.exceptions import InvalidSchema
from tenacity import (
Expand All @@ -23,22 +24,22 @@
import requests

# Local imports
from compass_sdk import (
from cohere.compass import (
GroupAuthorizationInput,
)
from compass_sdk.constants import (
from cohere.compass.constants import (
DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES,
DEFAULT_MAX_CHUNKS_PER_REQUEST,
DEFAULT_MAX_ERROR_RATE,
DEFAULT_MAX_RETRIES,
DEFAULT_SLEEP_RETRY_SECONDS,
)
from compass_sdk.exceptions import (
from cohere.compass.exceptions import (
CompassAuthError,
CompassClientError,
CompassMaxErrorRateExceeded,
)
from compass_sdk.models import (
from cohere.compass.models import (
Chunk,
CompassDocument,
CompassDocumentStatus,
Expand All @@ -60,7 +61,7 @@

@dataclass
class RetryResult:
result: Optional[dict] = None
result: Optional[dict[str, Any]] = None
error: Optional[str] = None


Expand All @@ -75,9 +76,9 @@ def __init__(self, timeout: int):
self._timeout = timeout
super().__init__()

def request(self, method, url, **kwargs):
def request(self, *args: Any, **kwargs: Any):
kwargs.setdefault("timeout", self._timeout)
return super().request(method, url, **kwargs)
return super().request(*args, **kwargs)


class CompassClient:
Expand Down Expand Up @@ -231,7 +232,7 @@ def add_context(
*,
index_name: str,
doc_id: str,
context: Dict,
context: dict[str, Any],
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[RetryResult]:
Expand Down Expand Up @@ -291,7 +292,7 @@ def upload_document(
context: Dict[str, Any] = {},
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[str | Dict]:
) -> Optional[Union[str, Dict[str, Any]]]:
"""
Parse and insert a document into an index in Compass
:param index_name: the name of the index
Expand Down Expand Up @@ -362,8 +363,8 @@ def insert_docs(
"""

def put_request(
request_data: List[Tuple[CompassDocument, Document]],
previous_errors: List[CompassDocument],
request_data: list[Tuple[CompassDocument, Document]],
previous_errors: list[dict[str, str]],
num_doc: int,
) -> None:
nonlocal num_succeeded, errors
Expand Down Expand Up @@ -420,11 +421,11 @@ def put_request(
f"in the last {errors_sliding_window_size} API calls. Stopping the insertion process."
)

error_window = deque(
error_window: deque[Optional[str]] = deque(
maxlen=errors_sliding_window_size
) # Keep track of the results of the last N API calls
num_succeeded = 0
errors = []
errors: list[dict[str, str]] = []
requests_iter = self._get_request_blocks(docs, max_chunks_per_request)

try:
Expand Down Expand Up @@ -556,17 +557,18 @@ def list_datasources_objects_states(
def _get_request_blocks(
docs: Iterator[CompassDocument],
max_chunks_per_request: int,
) -> Iterator:
):
"""
Create request blocks to send to the Compass API
:param docs: the documents to send
:param max_chunks_per_request: the maximum number of chunks to send in a single API request
:return: an iterator over the request blocks
"""

request_block, errors = [], []
request_block: list[tuple[CompassDocument, Document]] = []
errors: list[dict[str, str]] = []
num_chunks = 0
for num_doc, doc in enumerate(docs, 1):
for _, doc in enumerate(docs, 1):
if doc.status != CompassDocumentStatus.Success:
logger.error(f"Document {doc.metadata.doc_id} has errors: {doc.errors}")
for error in doc.errors:
Expand Down Expand Up @@ -679,7 +681,7 @@ def _send_request(
api_name: str,
max_retries: int,
sleep_retry_seconds: int,
data: Optional[Union[Dict, BaseModel]] = None,
data: Optional[Union[Dict[str, Any], BaseModel]] = None,
**url_params: str,
) -> RetryResult:
"""
Expand Down Expand Up @@ -710,11 +712,13 @@ def _send_request_with_retry():
if data:
if isinstance(data, BaseModel):
data_dict = data.model_dump(mode="json")
elif isinstance(data, Dict):
else:
data_dict = data

headers = None
auth = (self.username, self.password)
auth = None
if self.username and self.password:
auth = (self.username, self.password)
if self.bearer_token:
headers = {"Authorization": f"Bearer {self.bearer_token}"}
auth = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import requests

# Local imports
from compass_sdk import (
from cohere.compass import (
ProcessFileParameters,
)
from compass_sdk.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES
from compass_sdk.models import (
from cohere.compass.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES
from cohere.compass.models import (
CompassDocument,
MetadataConfig,
ParserConfig,
)
from compass_sdk.utils import imap_queued, open_document, scan_folder
from cohere.compass.utils import imap_queued, open_document, scan_folder

Fn_or_Dict = Union[Dict[str, Any], Callable[[CompassDocument], Dict[str, Any]]]

Expand Down Expand Up @@ -254,7 +254,7 @@ def process_file(
)

if res.ok:
docs = []
docs: list[CompassDocument] = []
for doc in res.json()["docs"]:
if not doc.get("errors", []):
compass_doc = CompassDocument(**doc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel
from requests import HTTPError

from compass_sdk.models import (
from cohere.compass.models import (
GroupCreateRequest,
GroupCreateResponse,
GroupDeleteResponse,
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions compass_sdk/exceptions.py → cohere/compass/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class CompassClientError(Exception):
"""Exception raised for all 4xx client errors in the Compass client."""

def __init__(self, message="Client error occurred."):
def __init__(self, message: str = "Client error occurred."):
self.message = message
super().__init__(self.message)

Expand All @@ -11,7 +11,7 @@ class CompassAuthError(CompassClientError):

def __init__(
self,
message=(
message: str = (
"CompassAuthError - check your bearer token or username and password."
),
):
Expand All @@ -25,7 +25,7 @@ class CompassMaxErrorRateExceeded(Exception):

def __init__(
self,
message="The maximum error rate was exceeded. Stopping the insertion process.",
message: str = "The maximum error rate was exceeded. Stopping the insertion process.",
):
self.message = message
super().__init__(self.message)
29 changes: 29 additions & 0 deletions cohere/compass/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any

# import models into model package
from pydantic import BaseModel


class ValidatedModel(BaseModel):
class Config:
arbitrary_types_allowed = True
use_enum_values = True

@classmethod
def attribute_in_model(cls, attr_name: str):
return attr_name in cls.model_fields

def __init__(self, **data: dict[str, Any]):
for name, _value in data.items():
if not self.attribute_in_model(name):
raise ValueError(
f"{name} is not a valid attribute for {self.__class__.__name__}"
)
super().__init__(**data)


from cohere.compass.models.config import * # noqa: E402, F403
from cohere.compass.models.datasources import * # noqa: E402, F403
from cohere.compass.models.documents import * # noqa: E402, F403
from cohere.compass.models.rbac import * # noqa: E402, F403
from cohere.compass.models.search import * # noqa: E402, F403
Loading
Loading