Skip to content

Commit

Permalink
Add typing for update_group_authorization API
Browse files Browse the repository at this point in the history
- The API used to return the RetryResult object. Now it returns a strong
  type.
- It used to be called `edit_group_authorization`. I renamed to
  `update_group_authorization` to match the API name.
  • Loading branch information
corafid committed Dec 18, 2024
1 parent 6107663 commit f8e98c3
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
4 changes: 2 additions & 2 deletions cohere/compass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ class ProcessFilesParameters(ValidatedModel):


class GroupAuthorizationActions(str, Enum):
"""Enum for use with the edit_group_authorization API to specify the edit type."""
"""Enum for use with the update_group_authorization API to specify the edit type."""

ADD = "add"
REMOVE = "remove"


class GroupAuthorizationInput(BaseModel):
"""Model for use with the edit_group_authorization API."""
"""Model for use with the update_group_authorization API."""

document_ids: list[str]
authorized_groups: list[str]
Expand Down
20 changes: 12 additions & 8 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from cohere.compass.exceptions import (
CompassAuthError,
CompassClientError,
CompassError,
CompassMaxErrorRateExceeded,
)
from cohere.compass.models import (
Expand All @@ -59,7 +60,7 @@
UploadDocumentsInput,
)
from cohere.compass.models.datasources import PaginatedList
from cohere.compass.models.documents import DocumentAttributes
from cohere.compass.models.documents import DocumentAttributes, PutDocumentsResponse


@dataclass
Expand Down Expand Up @@ -118,7 +119,7 @@ def __init__(
"add_attributes": self.session.post,
"refresh": self.session.post,
"upload_documents": self.session.post,
"edit_group_authorization": self.session.post,
"update_group_authorization": self.session.post,
# Data Sources APIs
"create_datasource": self.session.post,
"list_datasources": self.session.get,
Expand All @@ -139,7 +140,7 @@ def __init__(
"add_attributes": "/api/v1/indexes/{index_name}/documents/{document_id}/_add_attributes", # noqa: E501
"refresh": "/api/v1/indexes/{index_name}/_refresh",
"upload_documents": "/api/v1/indexes/{index_name}/documents/_upload",
"edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501
"update_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501
# Data Sources APIs
"create_datasource": "/api/v1/datasources",
"list_datasources": "/api/v1/datasources",
Expand All @@ -163,7 +164,7 @@ def create_index(self, *, index_name: str):
index_name=index_name,
)

def refresh(self, *, index_name: str):
def refresh_index(self, *, index_name: str):
"""
Refresh index.
Expand Down Expand Up @@ -744,22 +745,25 @@ def search_chunks(

return SearchChunksResponse.model_validate(result.result)

def edit_group_authorization(
def update_group_authorization(
self, *, index_name: str, group_auth_input: GroupAuthorizationInput
):
) -> PutDocumentsResponse:
"""
Edit group authorization for an index.
:param index_name: the name of the index
:param group_auth_input: the group authorization input
"""
return self._send_request(
api_name="edit_group_authorization",
result = self._send_request(
api_name="update_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
)
if result.error:
raise CompassError(result.error)
return PutDocumentsResponse.model_validate(result.result)

def _send_request(
self,
Expand Down
8 changes: 7 additions & 1 deletion cohere/compass/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
class CompassClientError(Exception):
class CompassError(Exception):
"""Base class for all exceptions raised by the Compass client."""

pass


class CompassClientError(CompassError):
"""Exception raised for all 4xx client errors in the Compass client."""

def __init__( # noqa: D107
Expand Down
17 changes: 17 additions & 0 deletions cohere/compass/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,20 @@ class PutDocumentsInput(BaseModel):
documents: list[Document]
authorized_groups: Optional[list[str]] = None
merge_groups_on_conflict: bool = False


class PutDocumentResult(BaseModel):
"""
A model for the response of put_document.
This model is also used by the put_documents and edit_group_authorization APIs.
"""

document_id: str
error: Optional[str]


class PutDocumentsResponse(BaseModel):
"""A model for the response of put_documents and edit_group_authorization APIs."""

results: list[PutDocumentResult]
2 changes: 1 addition & 1 deletion tests/test_compass_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_get_documents_is_valid(requests_mock: Mocker):

def test_refresh_is_valid(requests_mock: Mocker):
compass = CompassClient(index_url="http://test.com")
compass.refresh(index_name="test_index")
compass.refresh_index(index_name="test_index")
assert requests_mock.request_history[0].method == "POST"
assert (
requests_mock.request_history[0].url
Expand Down

0 comments on commit f8e98c3

Please sign in to comment.