Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding rbac sdk path #17

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions compass_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class PutDocumentsInput(BaseModel):
"""

docs: List[Document]
authorized_groups: Optional[List[str]] = None


class BatchPutDocumentsInput(BaseModel):
Expand All @@ -394,3 +395,14 @@ class ProcessFilesParameters(ValidatedModel):
class BatchProcessFilesParameters(ProcessFilesParameters):
uuid: str
file_name_to_doc_ids: Optional[Dict[str, str]] = None


class GroupAuthorizationActions(Enum):
ADD = "add"
REMOVE = "remove"


class GroupAuthorizationInput(BaseModel):
doc_ids: List[str]
authorized_groups: List[str]
action: GroupAuthorizationActions
23 changes: 22 additions & 1 deletion compass_sdk/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CompassDocumentStatus,
CompassSdkStage,
Document,
GroupAuthorizationInput,
LoggerLevel,
ParseableDocument,
PushDocumentsInput,
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
"add_context": self.session.post,
"refresh": self.session.post,
"push_documents": self.session.post,
"edit_group_authorization": self.session.post,
}
self.function_endpoint = {
"create_index": "/api/v1/indexes/{index_name}",
Expand All @@ -115,6 +117,7 @@ def __init__(
"add_context": "/api/v1/indexes/{index_name}/documents/add_context/{doc_id}",
"refresh": "/api/v1/indexes/{index_name}/refresh",
"push_documents": "/api/v2/indexes/{index_name}/documents",
"edit_group_authorization": "api/v1/indexes/{index_name}/group_authorization",
}
logger.setLevel(logger_level.value)

Expand Down Expand Up @@ -339,6 +342,7 @@ def insert_docs(
errors_sliding_window_size: Optional[int] = 10,
skip_first_n_docs: int = 0,
num_jobs: Optional[int] = None,
authorized_groups: Optional[List[str]] = None,
) -> Optional[List[CompassDocument]]:
"""
Insert multiple parsed documents into an index in Compass
Expand All @@ -351,6 +355,7 @@ def insert_docs(
:param sleep_retry_seconds: the number of seconds to wait before retrying an API request
:param errors_sliding_window_size: the size of the sliding window to keep track of errors
:param skip_first_n_docs: number of docs to skip indexing. Useful when insertion failed after N documents
:param authorized_groups: the groups that are authorized to access the documents. These groups should exist in RBAC. None passed will make the documents public
"""

def put_request(
Expand All @@ -361,7 +366,9 @@ def put_request(
nonlocal num_succeeded, errors
errors.extend(previous_errors)
compass_docs: List[CompassDocument] = [compass_doc for compass_doc, _ in request_data]
put_docs_input = PutDocumentsInput(docs=[input_doc for _, input_doc in request_data])
put_docs_input = PutDocumentsInput(
docs=[input_doc for _, input_doc in request_data], authorized_groups=authorized_groups
)

# It could be that all documents have errors, in which case we should not send a request
# to the Compass Server. This is a common case when the parsing of the documents fails.
Expand Down Expand Up @@ -473,6 +480,20 @@ def search(
sleep_retry_seconds=1,
)

def edit_group_authorization(self, *, index_name: str, group_auth_input: GroupAuthorizationInput):
ankush-cohere marked this conversation as resolved.
Show resolved Hide resolved
"""
Edit group authorization for an index
:param index_name: the name of the index
:param group_auth_input: the group authorization input
"""
return self._send_request(
function="edit_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
)

def _send_request(
self,
index_name: str,
Expand Down
158 changes: 158 additions & 0 deletions compass_sdk/rbac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import json
from typing import Dict, List, Type, TypeVar

import requests
from pydantic import BaseModel
from requests import HTTPError

from compass_sdk.types import (
GroupCreateRequest,
GroupCreateResponse,
GroupFetchResponse,
GroupUserDeleteResponse,
PolicyRequest,
RoleCreateRequest,
RoleCreateResponse,
RoleDeleteResponse,
RoleFetchResponse,
RoleMappingDeleteResponse,
RoleMappingRequest,
RoleMappingResponse,
UserCreateRequest,
UserCreateResponse,
UserDeleteResponse,
UserFetchResponse,
)


class CompassRootClient:
def __init__(self, compass_url: str, root_user_token: str):
self.base_url = compass_url + "/security/admin/rbac"
self.headers = {"Authorization": f"Bearer {root_user_token}", "Content-Type": "application/json"}

T = TypeVar("T", bound=BaseModel)
U = TypeVar("U", bound=BaseModel)
Headers = Dict[str, str]

@staticmethod
def _fetch_entities(url: str, headers: Headers, entity_type: Type[T]) -> List[T]:
response = requests.get(url, headers=headers)
CompassRootClient.raise_for_status(response)
return [entity_type.model_validate(entity) for entity in response.json()]

@staticmethod
def _create_entities(url: str, headers: Headers, entity_request: List[T], entity_response: Type[U]) -> List[U]:
response = requests.post(
url,
json=[json.loads(entity.model_dump_json()) for entity in entity_request],
headers=headers,
)
CompassRootClient.raise_for_status(response)
return [entity_response.model_validate(response) for response in response.json()]

@staticmethod
def _delete_entities(url: str, headers: Headers, names: List[str], entity_response: Type[U]) -> List[U]:
entities = ",".join(names)
response = requests.delete(f"{url}/{entities}", headers=headers)
CompassRootClient.raise_for_status(response)
return [entity_response.model_validate(entity) for entity in response.json()]

def fetch_users(self) -> List[UserFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/users", self.headers, UserFetchResponse)

def fetch_groups(self) -> List[GroupFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse)

def fetch_roles(self) -> List[RoleFetchResponse]:
return self._fetch_entities(f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse)

def fetch_role_mappings(self) -> List[RoleMappingResponse]:
return self._fetch_entities(f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse)

def create_users(self, *, users: List[UserCreateRequest]) -> List[UserCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/users",
headers=self.headers,
entity_request=users,
entity_response=UserCreateResponse,
)

def create_groups(self, *, groups: List[GroupCreateRequest]) -> List[GroupCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/groups",
headers=self.headers,
entity_request=groups,
entity_response=GroupCreateResponse,
)

def create_roles(self, *, roles: List[RoleCreateRequest]) -> List[RoleCreateResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/roles",
headers=self.headers,
entity_request=roles,
entity_response=RoleCreateResponse,
)

def create_role_mappings(self, *, role_mappings: List[RoleMappingRequest]) -> List[RoleMappingResponse]:
return self._create_entities(
url=f"{self.base_url}/v1/role-mappings",
headers=self.headers,
entity_request=role_mappings,
entity_response=RoleMappingResponse,
)

def delete_users(self, *, user_names: List[str]) -> List[UserDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse)

def delete_groups(self, *, group_names: List[str]) -> List[GroupUserDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/groups", self.headers, group_names, GroupUserDeleteResponse)

def delete_roles(self, *, role_ids: List[str]) -> List[RoleDeleteResponse]:
return self._delete_entities(f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse)

def delete_role_mappings(self, *, role_name: str, group_name: str) -> List[RoleMappingDeleteResponse]:
response = requests.delete(
f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", headers=self.headers
)
self.raise_for_status(response)
return [RoleMappingDeleteResponse.model_validate(role_mapping) for role_mapping in response.json()]

def delete_user_group(self, *, group_name: str, user_name: str) -> GroupUserDeleteResponse:
response = requests.delete(f"{self.base_url}/v1/group/{group_name}/user/{user_name}", headers=self.headers)
self.raise_for_status(response)
return GroupUserDeleteResponse.model_validate(response.json())

def update_role(self, *, role_name: str, policies: List[PolicyRequest]) -> RoleCreateResponse:
response = requests.put(
f"{self.base_url}/v1/roles/{role_name}",
json=[json.loads(policy.model_dump_json()) for policy in policies],
headers=self.headers,
)
self.raise_for_status(response)
return RoleCreateResponse.model_validate(response.json())

@staticmethod
def raise_for_status(response: requests.Response):
"""Raises :class:`HTTPError`, if one occurred."""

http_error_msg = ""
if isinstance(response.reason, bytes):
# We attempt to decode utf-8 first because some servers
# choose to localize their reason strings. If the string
# isn't utf-8, we fall back to iso-8859-1 for all other
# encodings. (See PR #3538)
try:
reason = response.reason.decode("utf-8")
except UnicodeDecodeError:
reason = response.reason.decode("iso-8859-1")
else:
reason = response.content

if 400 <= response.status_code < 500:
http_error_msg = f"{response.status_code} Client Error: {reason} for url: {response.url}"

elif 500 <= response.status_code < 600:
http_error_msg = f"{response.status_code} Server Error: {reason} for url: {response.url}"

if http_error_msg:
raise HTTPError(http_error_msg, response=response)
92 changes: 92 additions & 0 deletions compass_sdk/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from enum import Enum
from typing import List

from pydantic import BaseModel


class UserFetchResponse(BaseModel):
name: str


class UserCreateRequest(BaseModel):
name: str


class UserCreateResponse(BaseModel):
name: str
token: str


class UserDeleteResponse(BaseModel):
name: str


class GroupFetchResponse(BaseModel):
name: str
user_name: str


class GroupCreateRequest(BaseModel):
name: str
user_names: List[str]


class GroupCreateResponse(BaseModel):
name: str
user_name: str


class GroupUserDeleteResponse(BaseModel):
group_name: str
user_name: str


class Permission(Enum):
READ = "read"
WRITE = "write"
ADMIN = "admin"
ROOT = "root"


class PolicyRequest(BaseModel):
indexes: List[str]
permission: Permission


class PolicyResponse(BaseModel):
indexes: List[str]
permission: str
ankush-cohere marked this conversation as resolved.
Show resolved Hide resolved


class RoleFetchResponse(BaseModel):
name: str
policies: List[PolicyResponse]


class RoleCreateRequest(BaseModel):
name: str
policies: List[PolicyRequest]


class RoleCreateResponse(BaseModel):
name: str
policies: List[PolicyResponse]


class RoleDeleteResponse(BaseModel):
name: str


class RoleMappingRequest(BaseModel):
role_name: str
group_name: str


class RoleMappingResponse(BaseModel):
role_name: str
group_name: str


class RoleMappingDeleteResponse(BaseModel):
role_name: str
group_name: str
Loading