diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index 4f7451faf76..61ceae4e463 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -29,7 +29,7 @@ from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs from ee.danswer.db.external_perm import ExternalUserGroup from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair -from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD +from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP logger = setup_logger() @@ -66,9 +66,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: if last_ext_group_sync is None: return True - source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD + source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source) - # If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync. + # If no sync period is configured for this source in EXTERNAL_GROUP_SYNC_PERIODS, we always run the sync.
if not source_sync_period: return True diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 8d614c163c7..c9be6676fa7 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -3,15 +3,13 @@ from typing import Any from urllib.parse import quote -from atlassian import Confluence # type: ignore - from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.configs.constants import DocumentSource +from danswer.connectors.confluence.onyx_confluence import build_confluence_client from danswer.connectors.confluence.onyx_confluence import OnyxConfluence from danswer.connectors.confluence.utils import attachment_to_content -from danswer.connectors.confluence.utils import build_confluence_client from danswer.connectors.confluence.utils import build_confluence_document_id from danswer.connectors.confluence.utils import datetime_from_string from danswer.connectors.confluence.utils import extract_text_from_confluence_html @@ -114,25 +112,10 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py # for a list of other hidden constructor args self._confluence_client = build_confluence_client( - credentials_json=credentials, + credentials=credentials, is_cloud=self.is_cloud, wiki_base=self.wiki_base, ) - - client_without_retries = Confluence( - api_version="cloud" if self.is_cloud else "latest", - url=self.wiki_base.rstrip("/"), - username=credentials["confluence_username"] if self.is_cloud else None, - password=credentials["confluence_access_token"] if self.is_cloud else None, - token=credentials["confluence_access_token"] if not self.is_cloud else None, - ) - spaces = 
client_without_retries.get_all_spaces(limit=1) - if not spaces: - raise RuntimeError( - f"No spaces found at {self.wiki_base}! " - "Check your credentials and wiki_base and make sure " - "is_cloud is set correctly." - ) return None def _get_comment_string_for_page_id(self, page_id: str) -> str: diff --git a/backend/danswer/connectors/confluence/onyx_confluence.py b/backend/danswer/connectors/confluence/onyx_confluence.py index 4820429ba0f..739e4aef21d 100644 --- a/backend/danswer/connectors/confluence/onyx_confluence.py +++ b/backend/danswer/connectors/confluence/onyx_confluence.py @@ -232,7 +232,6 @@ def _traverse_and_update(data: dict | list) -> None: def paginated_cql_user_retrieval( self, - cql: str, expand: str | None = None, limit: int | None = None, ) -> Iterator[dict[str, Any]]: @@ -241,10 +240,28 @@ def paginated_cql_user_retrieval( It's a seperate endpoint from the content/search endpoint used only for users. Otherwise it's very similar to the content/search endpoint. """ + cql = "type=user" + url = "rest/api/search/user" if self.cloud else "rest/api/search" expand_string = f"&expand={expand}" if expand else "" - yield from self._paginate_url( - f"rest/api/search/user?cql={cql}{expand_string}", limit - ) + url += f"?cql={cql}{expand_string}" + yield from self._paginate_url(url, limit) + + def paginated_groups_by_user_retrieval( + self, + user: dict[str, Any], + limit: int | None = None, + ) -> Iterator[dict[str, Any]]: + """ + This is not an SQL like query. + It's a confluence specific endpoint that can be used to fetch groups. 
+ """ + user_field = "accountId" if self.cloud else "key" + user_value = user["accountId"] if self.cloud else user["userKey"] + # Server uses userKey (but calls it key during the API call), Cloud uses accountId + user_query = f"{user_field}={quote(user_value)}" + + url = f"rest/api/user/memberof?{user_query}" + yield from self._paginate_url(url, limit) def paginated_groups_retrieval( self, @@ -264,6 +281,55 @@ def paginated_group_members_retrieval( """ This is not an SQL like query. It's a confluence specific endpoint that can be used to fetch the members of a group. + THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name. + E.g. neither "test/group" nor "test%2Fgroup" works for confluence. """ group_name = quote(group_name) yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit) + + +def _validate_connector_configuration( + credentials: dict[str, Any], + is_cloud: bool, + wiki_base: str, +) -> None: + # test connection with direct client, no retries + confluence_client_without_retries = Confluence( + api_version="cloud" if is_cloud else "latest", + url=wiki_base.rstrip("/"), + username=credentials["confluence_username"] if is_cloud else None, + password=credentials["confluence_access_token"] if is_cloud else None, + token=credentials["confluence_access_token"] if not is_cloud else None, + ) + spaces = confluence_client_without_retries.get_all_spaces(limit=1) + + if not spaces: + raise RuntimeError( + f"No spaces found at {wiki_base}! " + "Check your credentials and wiki_base and make sure " + "is_cloud is set correctly." 
+ ) + + +def build_confluence_client( + credentials: dict[str, Any], + is_cloud: bool, + wiki_base: str, +) -> OnyxConfluence: + _validate_connector_configuration( + credentials=credentials, + is_cloud=is_cloud, + wiki_base=wiki_base, + ) + return OnyxConfluence( + api_version="cloud" if is_cloud else "latest", + # Remove trailing slash from wiki_base if present + url=wiki_base.rstrip("/"), + # passing in username causes issues for Confluence data center + username=credentials["confluence_username"] if is_cloud else None, + password=credentials["confluence_access_token"] if is_cloud else None, + token=credentials["confluence_access_token"] if not is_cloud else None, + backoff_and_retry=True, + max_backoff_retries=10, + max_backoff_seconds=60, + ) diff --git a/backend/danswer/connectors/confluence/utils.py b/backend/danswer/connectors/confluence/utils.py index cb5253f4c14..e6ac0308a3a 100644 --- a/backend/danswer/connectors/confluence/utils.py +++ b/backend/danswer/connectors/confluence/utils.py @@ -269,20 +269,3 @@ def datetime_from_string(datetime_string: str) -> datetime: datetime_object = datetime_object.astimezone(timezone.utc) return datetime_object - - -def build_confluence_client( - credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str -) -> OnyxConfluence: - return OnyxConfluence( - api_version="cloud" if is_cloud else "latest", - # Remove trailing slash from wiki_base if present - url=wiki_base.rstrip("/"), - # passing in username causes issues for Confluence data center - username=credentials_json["confluence_username"] if is_cloud else None, - password=credentials_json["confluence_access_token"] if is_cloud else None, - token=credentials_json["confluence_access_token"] if not is_cloud else None, - backoff_and_retry=True, - max_backoff_retries=10, - max_backoff_seconds=60, - ) diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py index 
17140b33f71..f2f53e589b1 100644 --- a/backend/ee/danswer/external_permissions/confluence/group_sync.py +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -1,7 +1,5 @@ -from atlassian import Confluence # type: ignore - +from danswer.connectors.confluence.onyx_confluence import build_confluence_client from danswer.connectors.confluence.onyx_confluence import OnyxConfluence -from danswer.connectors.confluence.utils import build_confluence_client from danswer.connectors.confluence.utils import get_user_email_from_username__server from danswer.db.models import ConnectorCredentialPair from danswer.utils.logger import setup_logger @@ -11,22 +9,30 @@ logger = setup_logger() -def _get_group_members_email_paginated( +def _build_group_member_email_map( confluence_client: OnyxConfluence, - group_name: str, -) -> set[str]: - group_member_emails: set[str] = set() - for member in confluence_client.paginated_group_members_retrieval(group_name): - email = member.get("email") +) -> dict[str, set[str]]: + group_member_emails: dict[str, set[str]] = {} + for user_result in confluence_client.paginated_cql_user_retrieval(): + user = user_result["user"] + email = user.get("email") if not email: - user_name = member.get("username") + # This field is only present in Confluence Server + user_name = user.get("username") + # If it is present, try to get the email using a Server-specific method if user_name: email = get_user_email_from_username__server( confluence_client=confluence_client, user_name=user_name, ) - if email: - group_member_emails.add(email) + if not email: + # If we still don't have an email, skip this user + continue + + for group in confluence_client.paginated_groups_by_user_retrieval(user): + # group name uniqueness is enforced by Confluence, so we can use it as a group ID + group_id = group["name"] + group_member_emails.setdefault(group_id, set()).add(email) return group_member_emails @@ -34,45 +40,20 @@ def _get_group_members_email_paginated( def 
confluence_group_sync( cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: - credentials = cc_pair.credential.credential_json - is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) - wiki_base = cc_pair.connector.connector_specific_config["wiki_base"] - - # test connection with direct client, no retries - confluence_client = Confluence( - api_version="cloud" if is_cloud else "latest", - url=wiki_base.rstrip("/"), - username=credentials["confluence_username"] if is_cloud else None, - password=credentials["confluence_access_token"] if is_cloud else None, - token=credentials["confluence_access_token"] if not is_cloud else None, - ) - spaces = confluence_client.get_all_spaces(limit=1) - if not spaces: - raise RuntimeError(f"No spaces found at {wiki_base}!") - confluence_client = build_confluence_client( - credentials_json=credentials, - is_cloud=is_cloud, - wiki_base=wiki_base, + credentials=cc_pair.credential.credential_json, + is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False), + wiki_base=cc_pair.connector.connector_specific_config["wiki_base"], ) - # Get all group names - group_names: list[str] = [] - for group in confluence_client.paginated_groups_retrieval(): - if group_name := group.get("name"): - group_names.append(group_name) - - # For each group name, get all members and create a danswer group + group_member_email_map = _build_group_member_email_map( + confluence_client=confluence_client, + ) danswer_groups: list[ExternalUserGroup] = [] - for group_name in group_names: - group_member_emails = _get_group_members_email_paginated( - confluence_client, group_name - ) - if not group_member_emails: - continue + for group_id, group_member_emails in group_member_email_map.items(): danswer_groups.append( ExternalUserGroup( - id=group_name, + id=group_id, user_emails=list(group_member_emails), ) ) diff --git a/backend/ee/danswer/external_permissions/sync_params.py 
b/backend/ee/danswer/external_permissions/sync_params.py index fb81ab35035..c00090d748d 100644 --- a/backend/ee/danswer/external_permissions/sync_params.py +++ b/backend/ee/danswer/external_permissions/sync_params.py @@ -55,7 +55,12 @@ DocumentSource.SLACK: 5 * 60, } -EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds +# If nothing is specified here, we run the external group sync every time the celery beat runs +EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = { + # Polling is not supported so we fetch all group permissions every 60 seconds + DocumentSource.GOOGLE_DRIVE: 60, + DocumentSource.CONFLUENCE: 60, +} def check_if_valid_sync_source(source_type: DocumentSource) -> bool: