From 6c361de254a7985b7946814ca16bf9f15fb4e6ae Mon Sep 17 00:00:00 2001 From: Ankush Khanna Date: Fri, 4 Oct 2024 18:46:15 +0200 Subject: [PATCH] Adding rbac to sdk and edit auth --- compass_sdk/__init__.py | 12 +++ compass_sdk/compass.py | 23 +++++- compass_sdk/rbac.py | 158 ++++++++++++++++++++++++++++++++++++++++ compass_sdk/types.py | 92 +++++++++++++++++++++++ 4 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 compass_sdk/rbac.py create mode 100644 compass_sdk/types.py diff --git a/compass_sdk/__init__.py b/compass_sdk/__init__.py index d8c571f..ed2e6f7 100644 --- a/compass_sdk/__init__.py +++ b/compass_sdk/__init__.py @@ -371,6 +371,7 @@ class PutDocumentsInput(BaseModel): """ docs: List[Document] + authorized_groups: Optional[List[str]] = None class BatchPutDocumentsInput(BaseModel): @@ -394,3 +395,14 @@ class ProcessFilesParameters(ValidatedModel): class BatchProcessFilesParameters(ProcessFilesParameters): uuid: str file_name_to_doc_ids: Optional[Dict[str, str]] = None + + +class GroupAuthorizationActions(Enum): + ADD = "add" + REMOVE = "remove" + + +class GroupAuthorizationInput(BaseModel): + doc_ids: List[str] + authorized_groups: List[str] + action: GroupAuthorizationActions diff --git a/compass_sdk/compass.py b/compass_sdk/compass.py index 0c4bcdb..902c963 100644 --- a/compass_sdk/compass.py +++ b/compass_sdk/compass.py @@ -21,6 +21,7 @@ CompassDocumentStatus, CompassSdkStage, Document, + GroupAuthorizationInput, LoggerLevel, ParseableDocument, PushDocumentsInput, @@ -102,6 +103,7 @@ def __init__( "add_context": self.session.post, "refresh": self.session.post, "push_documents": self.session.post, + "edit_group_authorization": self.session.post, } self.function_endpoint = { "create_index": "/api/v1/indexes/{index_name}", @@ -115,6 +117,7 @@ def __init__( "add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}", "refresh": "/api/v1/indexes/{index_name}/refresh", "push_documents": "/api/v2/indexes/{index_name}/documents", + "edit_group_authorization": "api/v1/indexes/{index_name}/group_authorization", } logger.setLevel(logger_level.value) @@ -339,6 +342,7 @@ def insert_docs( errors_sliding_window_size: Optional[int] = 10, skip_first_n_docs: int = 0, num_jobs: Optional[int] = None, + authorized_groups: Optional[List[str]] = None, ) -> Optional[List[CompassDocument]]: """ Insert multiple parsed documents into an index in Compass @@ -351,6 +355,7 @@ def insert_docs( :param sleep_retry_seconds: the number of seconds to wait before retrying an API request :param errors_sliding_window_size: the size of the sliding window to keep track of errors :param skip_first_n_docs: number of docs to skip indexing. Useful when insertion failed after N documents + :param authorized_groups: the groups that are authorized to access the documents. These groups should exist in RBAC. None passed will make the documents public """ def put_request( @@ -361,7 +366,9 @@ def put_request( nonlocal num_succeeded, errors errors.extend(previous_errors) compass_docs: List[CompassDocument] = [compass_doc for compass_doc, _ in request_data] - put_docs_input = PutDocumentsInput(docs=[input_doc for _, input_doc in request_data]) + put_docs_input = PutDocumentsInput( + docs=[input_doc for _, input_doc in request_data], authorized_groups=authorized_groups + ) # It could be that all documents have errors, in which case we should not send a request # to the Compass Server. This is a common case when the parsing of the documents fails. @@ -473,6 +480,20 @@ def search( sleep_retry_seconds=1, ) + def edit_group_authorization(self, *, index_name: str, group_auth_input: GroupAuthorizationInput): + """ + Edit group authorization for an index + :param index_name: the name of the index + :param group_auth_input: the group authorization input + """ + return self._send_request( + function="edit_group_authorization", + index_name=index_name, + data=group_auth_input, + max_retries=DEFAULT_MAX_RETRIES, + sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, + ) + def _send_request( self, index_name: str, diff --git a/compass_sdk/rbac.py b/compass_sdk/rbac.py new file mode 100644 index 0000000..1976dd3 --- /dev/null +++ b/compass_sdk/rbac.py @@ -0,0 +1,158 @@ +import json +from typing import Dict, List, Type, TypeVar + +import requests +from pydantic import BaseModel +from requests import HTTPError + +from compass_sdk.types import ( + GroupCreateRequest, + GroupCreateResponse, + GroupFetchResponse, + GroupUserDeleteResponse, + PolicyRequest, + RoleCreateRequest, + RoleCreateResponse, + RoleDeleteResponse, + RoleFetchResponse, + RoleMappingDeleteResponse, + RoleMappingRequest, + RoleMappingResponse, + UserCreateRequest, + UserCreateResponse, + UserDeleteResponse, + UserFetchResponse, +) + + +class CompassRootClient: + def __init__(self, compass_url: str, root_user_token: str): + self.base_url = compass_url + "/security/admin/rbac" + self.headers = {"Authorization": f"Bearer {root_user_token}", "Content-Type": "application/json"} + + T = TypeVar("T", bound=BaseModel) + U = TypeVar("U", bound=BaseModel) + Headers = Dict[str, str] + + @staticmethod + def _fetch_entities(url: str, headers: Headers, entity_type: Type[T]) -> List[T]: + response = requests.get(url, headers=headers) + CompassRootClient.raise_for_status(response) + return [entity_type.model_validate(entity) for entity in response.json()] + + @staticmethod + def _create_entities(url: str, headers: Headers, entity_request: List[T], entity_response: Type[U]) -> List[U]: + response = requests.post( + url, + json=[json.loads(entity.model_dump_json()) for entity in entity_request], + headers=headers, + ) + CompassRootClient.raise_for_status(response) + return [entity_response.model_validate(response) for response in response.json()] + + @staticmethod + def _delete_entities(url: str, headers: Headers, names: List[str], entity_response: Type[U]) -> List[U]: + entities = ",".join(names) + response = requests.delete(f"{url}/{entities}", headers=headers) + CompassRootClient.raise_for_status(response) + return [entity_response.model_validate(entity) for entity in response.json()] + + def fetch_users(self) -> List[UserFetchResponse]: + return self._fetch_entities(f"{self.base_url}/v1/users", self.headers, UserFetchResponse) + + def fetch_groups(self) -> List[GroupFetchResponse]: + return self._fetch_entities(f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse) + + def fetch_roles(self) -> List[RoleFetchResponse]: + return self._fetch_entities(f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse) + + def fetch_role_mappings(self) -> List[RoleMappingResponse]: + return self._fetch_entities(f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse) + + def create_users(self, *, users: List[UserCreateRequest]) -> List[UserCreateResponse]: + return self._create_entities( + url=f"{self.base_url}/v1/users", + headers=self.headers, + entity_request=users, + entity_response=UserCreateResponse, + ) + + def create_groups(self, *, groups: List[GroupCreateRequest]) -> List[GroupCreateResponse]: + return self._create_entities( + url=f"{self.base_url}/v1/groups", + headers=self.headers, + entity_request=groups, + entity_response=GroupCreateResponse, + ) + + def create_roles(self, *, roles: List[RoleCreateRequest]) -> List[RoleCreateResponse]: + return self._create_entities( + url=f"{self.base_url}/v1/roles", + headers=self.headers, + entity_request=roles, + entity_response=RoleCreateResponse, + ) + + def create_role_mappings(self, *, role_mappings: List[RoleMappingRequest]) -> List[RoleMappingResponse]: + return self._create_entities( + url=f"{self.base_url}/v1/role-mappings", + headers=self.headers, + entity_request=role_mappings, + entity_response=RoleMappingResponse, + ) + + def delete_users(self, *, user_names: List[str]) -> List[UserDeleteResponse]: + return self._delete_entities(f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse) + + def delete_groups(self, *, group_names: List[str]) -> List[GroupUserDeleteResponse]: + return self._delete_entities(f"{self.base_url}/v1/groups", self.headers, group_names, GroupUserDeleteResponse) + + def delete_roles(self, *, role_ids: List[str]) -> List[RoleDeleteResponse]: + return self._delete_entities(f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse) + + def delete_role_mappings(self, *, role_name: str, group_name: str) -> List[RoleMappingDeleteResponse]: + response = requests.delete( + f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", headers=self.headers + ) + self.raise_for_status(response) + return [RoleMappingDeleteResponse.model_validate(role_mapping) for role_mapping in response.json()] + + def delete_user_group(self, *, group_name: str, user_name: str) -> GroupUserDeleteResponse: + response = requests.delete(f"{self.base_url}/v1/group/{group_name}/user/{user_name}", headers=self.headers) + self.raise_for_status(response) + return GroupUserDeleteResponse.model_validate(response.json()) + + def update_role(self, *, role_name: str, policies: List[PolicyRequest]) -> RoleCreateResponse: + response = requests.put( + f"{self.base_url}/v1/roles/{role_name}", + json=[json.loads(policy.model_dump_json()) for policy in policies], + headers=self.headers, + ) + self.raise_for_status(response) + return RoleCreateResponse.model_validate(response.json()) + + @staticmethod + def raise_for_status(response: requests.Response): + """Raises :class:`HTTPError`, if one occurred.""" + + http_error_msg = "" + if isinstance(response.reason, bytes): + # We attempt to decode utf-8 first because some servers + # choose to localize their reason strings. If the string + # isn't utf-8, we fall back to iso-8859-1 for all other + # encodings. (See PR #3538) + try: + reason = response.reason.decode("utf-8") + except UnicodeDecodeError: + reason = response.reason.decode("iso-8859-1") + else: + reason = response.content + + if 400 <= response.status_code < 500: + http_error_msg = f"{response.status_code} Client Error: {reason} for url: {response.url}" + + elif 500 <= response.status_code < 600: + http_error_msg = f"{response.status_code} Server Error: {reason} for url: {response.url}" + + if http_error_msg: + raise HTTPError(http_error_msg, response=response) diff --git a/compass_sdk/types.py b/compass_sdk/types.py new file mode 100644 index 0000000..99a1abd --- /dev/null +++ b/compass_sdk/types.py @@ -0,0 +1,92 @@ +from enum import Enum +from typing import List + +from pydantic import BaseModel + + +class UserFetchResponse(BaseModel): + name: str + + +class UserCreateRequest(BaseModel): + name: str + + +class UserCreateResponse(BaseModel): + name: str + token: str + + +class UserDeleteResponse(BaseModel): + name: str + + +class GroupFetchResponse(BaseModel): + name: str + user_name: str + + +class GroupCreateRequest(BaseModel): + name: str + user_names: List[str] + + +class GroupCreateResponse(BaseModel): + name: str + user_name: str + + +class GroupUserDeleteResponse(BaseModel): + group_name: str + user_name: str + + +class Permission(Enum): + READ = "read" + WRITE = "write" + ADMIN = "admin" + ROOT = "root" + + +class PolicyRequest(BaseModel): + indexes: List[str] + permission: Permission + + +class PolicyResponse(BaseModel): + indexes: List[str] + permission: str + + +class RoleFetchResponse(BaseModel): + name: str + policies: List[PolicyResponse] + + +class RoleCreateRequest(BaseModel): + name: str + policies: List[PolicyRequest] + + +class RoleCreateResponse(BaseModel): + name: str + policies: List[PolicyResponse] + + +class RoleDeleteResponse(BaseModel): + name: str + + +class RoleMappingRequest(BaseModel): + role_name: str + group_name: str + + +class RoleMappingResponse(BaseModel): + role_name: str + group_name: str + + +class RoleMappingDeleteResponse(BaseModel): + role_name: str + group_name: str