From da38f1efea8da3afa0179a182039d88c1819fb59 Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Tue, 24 Sep 2024 10:55:38 -0600 Subject: [PATCH] [Confluence] Add support for OAuth 2.0 --- confluence/README.md | 64 ++++++++++++++++-- confluence/provider/app.py | 14 +++- confluence/provider/client.py | 114 ++++++++++++++++++++++++-------- confluence/provider/provider.py | 4 +- 4 files changed, 159 insertions(+), 37 deletions(-) diff --git a/confluence/README.md b/confluence/README.md index 785b996dd..3ce2c1885 100644 --- a/confluence/README.md +++ b/confluence/README.md @@ -5,13 +5,18 @@ This package is a utility for connecting Cohere to Confluence, featuring a simpl ## Limitations The Confluence connector will search within the space defined in your `.env`, and performs a case-insensitive full-text -search against all text fields Confluence indexes by default. +search against all text fields Confluence indexes by default. Note: The search uses Confluence's advanced search language called [CQL](https://developer.atlassian.com/cloud/confluence/advanced-searching-using-cql/). If you wish to customize this connector's search experience, please refer to the above linked documentation for more details. ## Configuration -This connector requires the following environment variables: +There are two authentication methods available with this connector. You can either set it up using the service auth +method, or with OAuth. + +### Service Auth + +When using the service auth method, you must set the following env vars: ``` CONFLUENCE_USER: User email address @@ -20,8 +25,59 @@ CONFLUENCE_PRODUCT_URL: URL to your Confluence instance, including https:// sche CONFLUENCE_SPACE_NAME: Name of a space within your Confluence wiki ``` -The API token can be generated by logging into Confluence and going -to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens). 
+The API token can be generated by logging into Confluence and going to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens). + +### OAuth + +When using OAuth for authentication, the connector does not require any additional environment variables. Instead, +the OAuth flow should occur outside of the Connector and Cohere's API will forward the user's access token to this +connector through the `Authorization` header. + +To use OAuth, you must first create an OAuth 2.0 app in Confluence. To do this, go to the +Atlassian [Developer Console](https://developer.atlassian.com/console/myapps/), and use the option to create a new +OAuth 2.0 integration. + +You must configure in the developer console the OAuth scopes that are allowed to be requested by the client. There are +two options in Confluence, classic scopes or granular scopes. Use the granular scopes option, and ensure that the +following are enabled: + +* read:content:confluence +* read:content-details:confluence +* read:page:confluence +* read:custom-content:confluence + +You must also configure the authorization settings. Go to the Authorization page, and configure the app to use the +authorization type OAuth 2.0 (3LO). On the configuration page for this authorization type, enter the callback URL as: + +https://api.cohere.com/v1/connectors/oauth/token + +Go to the settings option for the app, enter the app name and description under general settings, and then take +note of the OAuth client id and secret from this page. 
+ +Once your Confluence OAuth credentials are ready, you can register the connector in Cohere's API with the following +configuration: + +```bash +curl -X POST \ + 'https://api.cohere.ai/v1/connectors' \ + --header 'Accept: */*' \ + --header 'Authorization: Bearer {COHERE-API-KEY}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "name": "Confluence", + "url": "{YOUR_CONNECTOR-URL}", + "oauth": { + "client_id": "{CONFLUENCE-OAUTH-CLIENT-ID}", + "client_secret": "{CONFLUENCE-OAUTH-CLIENT-SECRET}", + "authorize_url": "https://auth.atlassian.com/authorize?audience=api.atlassian.com&response_type=code&prompt=consent", + "token_url": "https://auth.atlassian.com/oauth/token", + "scope": "read:content:confluence read:content-details:confluence read:page:confluence read:custom-content:confluence" + } +}' +``` + +With OAuth the connector will be able to search any Confluence pages that the user has access to. + ### Optional Configuration ``` diff --git a/confluence/provider/app.py b/confluence/provider/app.py index 05a4bb660..f00049cde 100644 --- a/confluence/provider/app.py +++ b/confluence/provider/app.py @@ -1,19 +1,20 @@ import logging from connexion.exceptions import Unauthorized -from flask import abort -from flask import current_app as app +from flask import abort, request, current_app as app from . 
import UpstreamProviderError, provider logger = logging.getLogger(__name__) +AUTHORIZATION_HEADER = "Authorization" +BEARER_PREFIX = "Bearer " def search(body): logger.debug(f'Search request: {body["query"]}') try: - data = provider.search(body["query"]) + data = provider.search(body["query"], get_access_token()) logger.info(f"Found {len(data)} results") except UpstreamProviderError as error: logger.error(f"Upstream search error: {error.message}") @@ -22,6 +23,13 @@ def search(body): return {"results": data}, 200, {"X-Connector-Id": app.config.get("APP_ID")} +def get_access_token() -> str | None: + authorization_header = request.headers.get(AUTHORIZATION_HEADER, "") + if authorization_header.startswith(BEARER_PREFIX): + return authorization_header.removeprefix(BEARER_PREFIX) + return None + + def apikey_auth(token): api_key = str(app.config.get("CONNECTOR_API_KEY", "")) if api_key != "" and token != api_key: diff --git a/confluence/provider/client.py b/confluence/provider/client.py index f0a924627..4e52c92c3 100644 --- a/confluence/provider/client.py +++ b/confluence/provider/client.py @@ -28,10 +28,19 @@ class ConfluenceClient: # Cache size limit to reduce memory over time CACHE_LIMIT_BYTES = 20 * 1024 * 1024 # 20 MB to bytes - def __init__(self, url, user, api_token, search_limit=10): - self.base_url = url - self.user = user - self.api_token = api_token + # Cache for token to organization cloud id mappings + org_ids = {} + + def __init__( + self, + service_base_url=None, + service_user=None, + service_api_token=None, + search_limit=10, + ): + self.service_base_url = service_base_url + self.service_user = service_user + self.service_api_token = service_api_token self.search_limit = search_limit # Manually cache because functools.lru_cache does not support async methods self.cache = OrderedDict() @@ -71,24 +80,30 @@ def _close_session_and_loop(self): self.loop.stop() self.loop.close() - async def _gather(self, pages): - tasks = [self._get_page(page["id"]) for page 
in pages if self.PAGE_TYPE in page] + async def _gather(self, pages, access_token=None): + tasks = [ + self._get_page(page["id"], access_token) + for page in pages + if self.PAGE_TYPE in page + ] return await asyncio.gather(*tasks) - async def _get_page(self, page_id): + async def _get_page(self, page_id, access_token=None): # Check cache if page_id in self.cache: return self._cache_get(page_id) - get_page_by_id_url = f"{self.base_url}/wiki/api/v2/pages/{page_id}" - credentials = f"{self.user}:{self.api_token}" - credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") + base_url = self._get_base_url(access_token) + get_page_by_id_url = f"{base_url}/wiki/api/v2/pages/{page_id}" + + headers = {} + self._add_auth_header(headers, access_token) params = {"body-format": self.PAGE_BODY_FORMAT} async with self.session.get( get_page_by_id_url, - headers={"Authorization": f"Basic {credentials_encoded}"}, + headers=headers, params=params, ) as response: if not response.ok: @@ -97,7 +112,8 @@ async def _get_page(self, page_id): content = await response.json() - page_url = f"{self.base_url}/wiki{content['_links']['webui']}" + base_url = self._get_base_url(access_token) + page_url = f"{base_url}/wiki{content['_links']['webui']}" serialized_page = { "title": content["title"], @@ -109,8 +125,9 @@ async def _get_page(self, page_id): self._cache_put(page_id, serialized_page) return self._cache_get(page_id) - def search_pages(self, query): - search_url = f"{self.base_url}/wiki/rest/api/content/search" + def search_pages(self, query, access_token=None): + base_url = self._get_base_url(access_token) + search_url = f"{base_url}/wiki/rest/api/content/search" # Substitutes any sequence of non-alphanumeric or whitespace characters with a whitespace formatted_query = re.sub("\W+", " ", query) @@ -120,9 +137,12 @@ def search_pages(self, query): "limit": self.search_limit, } + headers = {} + self._add_auth_header(headers, access_token) + response = requests.get( 
search_url, - auth=(self.user, self.api_token), + headers=headers, params=params, ) @@ -133,30 +153,68 @@ def search_pages(self, query): return response.json().get("results", []) - def fetch_pages(self, pages): + def fetch_pages(self, pages, access_token): self._start_session() - results = self.loop.run_until_complete(self._gather(pages)) + results = self.loop.run_until_complete(self._gather(pages, access_token)) self._close_session_and_loop() return results - def search(self, query): - pages = self.search_pages(query) + def search(self, query, access_token=None): + pages = self.search_pages(query, access_token) + + return [ + page for page in self.fetch_pages(pages, access_token) if page is not None + ] + + def _add_auth_header(self, headers, access_token): + if access_token: + headers["Authorization"] = f"Bearer {access_token}" + else: + credentials = f"{self.service_user}:{self.service_api_token}" + credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") + headers["Authorization"] = f"Basic {credentials_encoded}" + + def _get_base_url(self, access_token=None): + if not access_token: + return self.service_base_url + + if access_token in self.org_ids: + return ( + f"https://api.atlassian.com/ex/confluence/{self.org_ids[access_token]}" + ) + + headers = {} + self._add_auth_header(headers, access_token) + + response = requests.get( + "https://api.atlassian.com/oauth/token/accessible-resources", + headers=headers, + ) + + if response.status_code != 200: + logger.error("Error determining Confluence base URL") + return + + accessible_resources = response.json() + + if not accessible_resources: + logger.error("No resources available to user") + return + + org_id = accessible_resources[0]["id"] + self.org_ids[access_token] = org_id - return [page for page in self.fetch_pages(pages) if page is not None] + return f"https://api.atlassian.com/ex/confluence/{org_id}" def get_client(): global client if client is None: - assert ( - url := 
app.config.get("PRODUCT_URL") - ), "CONFLUENCE_PRODUCT_URL must be set" - assert (user := app.config.get("USER")), "CONFLUENCE_USER must be set" - assert ( - api_token := app.config.get("API_TOKEN") - ), "CONFLUENCE_API_TOKEN must be set" + product_url = app.config.get("PRODUCT_URL") + user = app.config.get("USER") + api_token = app.config.get("API_TOKEN") search_limit = app.config.get("SEARCH_LIMIT", 10) - client = ConfluenceClient(url, user, api_token, search_limit) + client = ConfluenceClient(product_url, user, api_token, search_limit) return client diff --git a/confluence/provider/provider.py b/confluence/provider/provider.py index 8026aab26..0669426aa 100644 --- a/confluence/provider/provider.py +++ b/confluence/provider/provider.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) -def search(query): +def search(query, access_token): client = get_client() - pages = client.search(query) + pages = client.search(query, access_token) return pages