Skip to content

Commit

Permalink
Added permissions syncing for slack
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Sep 30, 2024
1 parent 1cff2b8 commit fe5a667
Show file tree
Hide file tree
Showing 7 changed files with 399 additions and 51 deletions.
131 changes: 83 additions & 48 deletions backend/danswer/connectors/slack/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
Expand All @@ -23,9 +22,8 @@
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger

Expand All @@ -38,47 +36,18 @@
# list of messages in a thread
ThreadType = list[MessageType]

basic_retry_wrapper = retry_builder()


def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)


def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)


def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]


def _get_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])

Expand All @@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# try getting private channels as well at first
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _get_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _get_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)

return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels


def get_channel_messages(
Expand All @@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")

for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
Expand All @@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
Expand Down Expand Up @@ -266,7 +254,7 @@ def filter_channels(
]


def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
Expand Down Expand Up @@ -328,7 +316,44 @@ def get_all_docs(
)


class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> set[str]:
"""
Get all document ids in the workspace, channel by channel
This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
This makes it an order of magnitude faster than get_all_docs
"""

all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)

all_doc_ids = set()
for channel in filtered_channels:
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)

for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue

# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we dont have to
# fetch the thread for id retrieval, saving time and API calls
all_doc_ids.add(f"{channel['id']}__{message['ts']}")

return all_doc_ids


class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
Expand All @@ -349,14 +374,24 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
self.client = WebClient(token=bot_token)
return None

def retrieve_all_source_ids(self) -> set[str]:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)

def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,
Expand Down
24 changes: 22 additions & 2 deletions backend/danswer/connectors/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger

logger = setup_logger()

basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900

Expand All @@ -34,7 +36,7 @@ def get_message_link(
)


def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
Expand All @@ -47,7 +49,7 @@ def logged_call(**kwargs: Any) -> SlackResponse:
return logged_call


def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
Expand Down Expand Up @@ -116,6 +118,24 @@ def rate_limited_call(**kwargs: Any) -> SlackResponse:
return rate_limited_call


def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)


def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)


def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,
Expand Down
1 change: 1 addition & 0 deletions backend/ee/danswer/external_permissions/permission_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DocumentSource.GOOGLE_DRIVE: None,
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.SLACK: 60,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync

from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync

# Defining the input/output types for the sync functions
SyncFuncType = Callable[
Expand All @@ -27,6 +27,7 @@
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
}

# These functions update:
Expand Down
Loading

0 comments on commit fe5a667

Please sign in to comment.