Skip to content

Commit

Permalink
Adding rbac to sdk and edit auth
Browse files Browse the repository at this point in the history
  • Loading branch information
ankush-cohere committed Oct 9, 2024
1 parent 6f26265 commit 6c361de
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 1 deletion.
12 changes: 12 additions & 0 deletions compass_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class PutDocumentsInput(BaseModel):
"""

docs: List[Document]
authorized_groups: Optional[List[str]] = None


class BatchPutDocumentsInput(BaseModel):
Expand All @@ -394,3 +395,14 @@ class ProcessFilesParameters(ValidatedModel):
class BatchProcessFilesParameters(ProcessFilesParameters):
uuid: str
file_name_to_doc_ids: Optional[Dict[str, str]] = None


class GroupAuthorizationActions(Enum):
ADD = "add"
REMOVE = "remove"


class GroupAuthorizationInput(BaseModel):
doc_ids: List[str]
authorized_groups: List[str]
action: GroupAuthorizationActions
23 changes: 22 additions & 1 deletion compass_sdk/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CompassDocumentStatus,
CompassSdkStage,
Document,
GroupAuthorizationInput,
LoggerLevel,
ParseableDocument,
PushDocumentsInput,
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
"add_context": self.session.post,
"refresh": self.session.post,
"push_documents": self.session.post,
"edit_group_authorization": self.session.post,
}
self.function_endpoint = {
"create_index": "/api/v1/indexes/{index_name}",
Expand All @@ -115,6 +117,7 @@ def __init__(
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"refresh": "/api/v1/indexes/{index_name}/refresh",
"push_documents": "/api/v2/indexes/{index_name}/documents",
"edit_group_authorization": "api/v1/indexes/{index_name}/group_authorization",
}
logger.setLevel(logger_level.value)

Expand Down Expand Up @@ -339,6 +342,7 @@ def insert_docs(
errors_sliding_window_size: Optional[int] = 10,
skip_first_n_docs: int = 0,
num_jobs: Optional[int] = None,
authorized_groups: Optional[List[str]] = None,
) -> Optional[List[CompassDocument]]:
"""
Insert multiple parsed documents into an index in Compass
Expand All @@ -351,6 +355,7 @@ def insert_docs(
:param sleep_retry_seconds: the number of seconds to wait before retrying an API request
:param errors_sliding_window_size: the size of the sliding window to keep track of errors
:param skip_first_n_docs: number of docs to skip indexing. Useful when insertion failed after N documents
:param authorized_groups: the groups that are authorized to access the documents. These groups should exist in RBAC. None passed will make the documents public
"""

def put_request(
Expand All @@ -361,7 +366,9 @@ def put_request(
nonlocal num_succeeded, errors
errors.extend(previous_errors)
compass_docs: List[CompassDocument] = [compass_doc for compass_doc, _ in request_data]
put_docs_input = PutDocumentsInput(docs=[input_doc for _, input_doc in request_data])
put_docs_input = PutDocumentsInput(
docs=[input_doc for _, input_doc in request_data], authorized_groups=authorized_groups
)

# It could be that all documents have errors, in which case we should not send a request
# to the Compass Server. This is a common case when the parsing of the documents fails.
Expand Down Expand Up @@ -473,6 +480,20 @@ def search(
sleep_retry_seconds=1,
)

def edit_group_authorization(self, *, index_name: str, group_auth_input: GroupAuthorizationInput):
"""
Edit group authorization for an index
:param index_name: the name of the index
:param group_auth_input: the group authorization input
"""
return self._send_request(
function="edit_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
)

def _send_request(
self,
index_name: str,
Expand Down
158 changes: 158 additions & 0 deletions compass_sdk/rbac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import json
from typing import Dict, List, Type, TypeVar

import requests
from pydantic import BaseModel
from requests import HTTPError

from compass_sdk.types import (
GroupCreateRequest,
GroupCreateResponse,
GroupFetchResponse,
GroupUserDeleteResponse,
PolicyRequest,
RoleCreateRequest,
RoleCreateResponse,
RoleDeleteResponse,
RoleFetchResponse,
RoleMappingDeleteResponse,
RoleMappingRequest,
RoleMappingResponse,
UserCreateRequest,
UserCreateResponse,
UserDeleteResponse,
UserFetchResponse,
)


class CompassRootClient:
def __init__(self, compass_url: str, root_user_token: str):
self.base_url = compass_url + "/security/admin/rbac"
self.headers = {"Authorization": f"Bearer {root_user_token}", "Content-Type": "application/json"}

T = TypeVar("T", bound=BaseModel)
U = TypeVar("U", bound=BaseModel)
Headers = Dict[str, str]

@staticmethod
def _fetch_entities(url: str, headers: Headers, entity_type: Type[T]) -> List[T]:
response = requests.get(url, headers=headers)
CompassRootClient.raise_for_status(response)
return [entity_type.model_validate(entity) for entity in response.json()]

@staticmethod
def _create_entities(url: str, headers: Headers, entity_request: List[T], entity_response: Type[U]) -> List[U]:
response = requests.post(
url,
json=[json.loads(entity.model_dump_json()) for entity in entity_request],
headers=headers,
)
CompassRootClient.raise_for_status(response)
return [entity_response.model_validate(response) for response in response.json()]

@staticmethod
def _delete_entities(url: str, headers: Headers, names: List[str], entity_response: Type[U]) -> List[U]:
entities = ",".join(names)
response = requests.delete(f"{url}/{entities}", headers=headers)
CompassRootClient.raise_for_status(response)
return [entity_response.model_validate(entity) for entity in response.json()]

def fetch_users(self) -> List[UserFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/users", self.headers, UserFetchResponse)

def fetch_groups(self) -> List[GroupFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse)

def fetch_roles(self) -> List[RoleFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse)

def fetch_role_mappings(self) -> List[RoleMappingResponse]:
return self._fetch_entities(f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse)

def create_users(self, *, users: List[UserCreateRequest]) -> List[UserCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/users",
headers=self.headers,
entity_request=users,
entity_response=UserCreateResponse,
)

def create_groups(self, *, groups: List[GroupCreateRequest]) -> List[GroupCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/groups",
headers=self.headers,
entity_request=groups,
entity_response=GroupCreateResponse,
)

def create_roles(self, *, roles: List[RoleCreateRequest]) -> List[RoleCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/roles",
headers=self.headers,
entity_request=roles,
entity_response=RoleCreateResponse,
)

def create_role_mappings(self, *, role_mappings: List[RoleMappingRequest]) -> List[RoleMappingResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/role-mappings",
headers=self.headers,
entity_request=role_mappings,
entity_response=RoleMappingResponse,
)

def delete_users(self, *, user_names: List[str]) -> List[UserDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse)

def delete_groups(self, *, group_names: List[str]) -> List[GroupUserDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/groups", self.headers, group_names, GroupUserDeleteResponse)

def delete_roles(self, *, role_ids: List[str]) -> List[RoleDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse)

def delete_role_mappings(self, *, role_name: str, group_name: str) -> List[RoleMappingDeleteResponse]:
response = requests.delete(
f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", headers=self.headers
)
self.raise_for_status(response)
return [RoleMappingDeleteResponse.model_validate(role_mapping) for role_mapping in response.json()]

def delete_user_group(self, *, group_name: str, user_name: str) -> GroupUserDeleteResponse:
response = requests.delete(f"{self.base_url}/v1/group/{group_name}/user/{user_name}", headers=self.headers)
self.raise_for_status(response)
return GroupUserDeleteResponse.model_validate(response.json())

def update_role(self, *, role_name: str, policies: List[PolicyRequest]) -> RoleCreateResponse:
response = requests.put(
f"{self.base_url}/v1/roles/{role_name}",
json=[json.loads(policy.model_dump_json()) for policy in policies],
headers=self.headers,
)
self.raise_for_status(response)
return RoleCreateResponse.model_validate(response.json())

@staticmethod
def raise_for_status(response: requests.Response):
"""Raises :class:`HTTPError`, if one occurred."""

http_error_msg = ""
if isinstance(response.reason, bytes):
# We attempt to decode utf-8 first because some servers
# choose to localize their reason strings. If the string
# isn't utf-8, we fall back to iso-8859-1 for all other
# encodings. (See PR #3538)
try:
reason = response.reason.decode("utf-8")
except UnicodeDecodeError:
reason = response.reason.decode("iso-8859-1")
else:
reason = response.content

if 400 <= response.status_code < 500:
http_error_msg = f"{response.status_code} Client Error: {reason} for url: {response.url}"

elif 500 <= response.status_code < 600:
http_error_msg = f"{response.status_code} Server Error: {reason} for url: {response.url}"

if http_error_msg:
raise HTTPError(http_error_msg, response=response)
92 changes: 92 additions & 0 deletions compass_sdk/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from enum import Enum
from typing import List

from pydantic import BaseModel


class UserFetchResponse(BaseModel):
name: str


class UserCreateRequest(BaseModel):
name: str


class UserCreateResponse(BaseModel):
name: str
token: str


class UserDeleteResponse(BaseModel):
name: str


class GroupFetchResponse(BaseModel):
name: str
user_name: str


class GroupCreateRequest(BaseModel):
name: str
user_names: List[str]


class GroupCreateResponse(BaseModel):
name: str
user_name: str


class GroupUserDeleteResponse(BaseModel):
group_name: str
user_name: str


class Permission(Enum):
READ = "read"
WRITE = "write"
ADMIN = "admin"
ROOT = "root"


class PolicyRequest(BaseModel):
indexes: List[str]
permission: Permission


class PolicyResponse(BaseModel):
indexes: List[str]
permission: str


class RoleFetchResponse(BaseModel):
name: str
policies: List[PolicyResponse]


class RoleCreateRequest(BaseModel):
name: str
policies: List[PolicyRequest]


class RoleCreateResponse(BaseModel):
name: str
policies: List[PolicyResponse]


class RoleDeleteResponse(BaseModel):
name: str


class RoleMappingRequest(BaseModel):
role_name: str
group_name: str


class RoleMappingResponse(BaseModel):
role_name: str
group_name: str


class RoleMappingDeleteResponse(BaseModel):
role_name: str
group_name: str

0 comments on commit 6c361de

Please sign in to comment.