From da38f1efea8da3afa0179a182039d88c1819fb59 Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Tue, 24 Sep 2024 10:55:38 -0600 Subject: [PATCH 1/5] [Confluence] Add support for OAuth 2.0 --- confluence/README.md | 64 ++++++++++++++++-- confluence/provider/app.py | 14 +++- confluence/provider/client.py | 114 ++++++++++++++++++++++++-------- confluence/provider/provider.py | 4 +- 4 files changed, 159 insertions(+), 37 deletions(-) diff --git a/confluence/README.md b/confluence/README.md index 785b996dd..3ce2c1885 100644 --- a/confluence/README.md +++ b/confluence/README.md @@ -5,13 +5,18 @@ This package is a utility for connecting Cohere to Confluence, featuring a simpl ## Limitations The Confluence connector will search within the space defined in your `.env`, and performs a case-insensitive full-text -search against all text fields Confluence indexes by default. +search against all text fields Confluence indexes by default. Note: The search uses Confluence's advanced search language called [CQL](https://developer.atlassian.com/cloud/confluence/advanced-searching-using-cql/). If you wish to customize this connector's search experience, please refer to the above linked documentation for more details. ## Configuration -This connector requires the following environment variables: +There are two authentication methods available with this connector. You can either set it up using the service auth +method, or with OAuth. + +### Service Auth + +When using the service auth method, you must set the following env vars: ``` CONFLUENCE_USER: User email address @@ -20,8 +25,59 @@ CONFLUENCE_PRODUCT_URL: URL to your Confluence instance, including https:// sche CONFLUENCE_SPACE_NAME: Name of a space within your Confluence wiki ``` -The API token can be generated by logging into Confluence and going -to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens). 
+The API token can be generated by logging into Confluence and going to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens). + +### OAuth + +When using OAuth for authentication, the connector does not require any additional environment variables. Instead, +the OAuth flow should occur outside of the Connector and Cohere's API will forward the user's access token to this +connector through the `Authorization` header. + +To use OAuth, you must first create an OAuth 2.0 app in Confluence. To do this, go to the +Atlassian [Developer Console](https://developer.atlassian.com/console/myapps/), and use the option to create a new +OAuth 2.0 integration. + +You must configure in the developer console the OAuth scopes that are allowed to be requested by the client. There are +two options in Confluence, classic scopes or granular scopes. Use the granular scopes option, and ensure that the +following are enabled: + +* read:content:confluence +* read:content-details:confluence +* read:page:confluence +* read:custom-content:confluence + +You must also configure the authorization settings. Go to the Authorization page, and configure the app to use the +authorization type OAuth 2.0 (3L0). On the configuration page for the authorization page, enter the callback URL as: + +https://api.cohere.com/v1/connectors/oauth/token + +Go to the settings option for the app, enter the app name and description under general settings, and then take +note of the OAuth client id and secret from this page. 
+ +Once your Confluence OAuth credentials are ready, you can register the connector in Cohere's API with the following +configuration: + +```bash +curl -X POST \ + 'https://api.cohere.ai/v1/connectors' \ + --header 'Accept: */*' \ + --header 'Authorization: Bearer {COHERE-API-KEY}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "name": "Confluence", + "url": "{YOUR_CONNECTOR-URL}", + "oauth": { + "client_id": "{CONFLUENCE-OAUTH-CLIENT-ID}", + "client_secret": "{CONFLUENCE-OAUTH-CLIENT-SECRET}", + "authorize_url": "https://auth.atlassian.com/authorize?audience=api.atlassian.com&response_type=code&prompt=consent", + "token_url": "https://auth.atlassian.com/oauth/token", + "scope": "read:content:confluence read:content-details:confluence read:page:confluence read:custom-content:confluence" + } +}' +``` + +With OAuth the connector will be able to search any Confluence pages that the user has access to. + ### Optional Configuration ``` diff --git a/confluence/provider/app.py b/confluence/provider/app.py index 05a4bb660..f00049cde 100644 --- a/confluence/provider/app.py +++ b/confluence/provider/app.py @@ -1,19 +1,20 @@ import logging from connexion.exceptions import Unauthorized -from flask import abort -from flask import current_app as app +from flask import abort, request, current_app as app from . 
import UpstreamProviderError, provider logger = logging.getLogger(__name__) +AUTHORIZATION_HEADER = "Authorization" +BEARER_PREFIX = "Bearer " def search(body): logger.debug(f'Search request: {body["query"]}') try: - data = provider.search(body["query"]) + data = provider.search(body["query"], get_access_token()) logger.info(f"Found {len(data)} results") except UpstreamProviderError as error: logger.error(f"Upstream search error: {error.message}") @@ -22,6 +23,13 @@ def search(body): return {"results": data}, 200, {"X-Connector-Id": app.config.get("APP_ID")} +def get_access_token() -> str | None: + authorization_header = request.headers.get(AUTHORIZATION_HEADER, "") + if authorization_header.startswith(BEARER_PREFIX): + return authorization_header.removeprefix(BEARER_PREFIX) + return None + + def apikey_auth(token): api_key = str(app.config.get("CONNECTOR_API_KEY", "")) if api_key != "" and token != api_key: diff --git a/confluence/provider/client.py b/confluence/provider/client.py index f0a924627..4e52c92c3 100644 --- a/confluence/provider/client.py +++ b/confluence/provider/client.py @@ -28,10 +28,19 @@ class ConfluenceClient: # Cache size limit to reduce memory over time CACHE_LIMIT_BYTES = 20 * 1024 * 1024 # 20 MB to bytes - def __init__(self, url, user, api_token, search_limit=10): - self.base_url = url - self.user = user - self.api_token = api_token + # Cache for token to organization cloud id mappings + org_ids = {} + + def __init__( + self, + service_base_url=None, + service_user=None, + service_api_token=None, + search_limit=10, + ): + self.service_base_url = service_base_url + self.service_user = service_user + self.service_api_token = service_api_token self.search_limit = search_limit # Manually cache because functools.lru_cache does not support async methods self.cache = OrderedDict() @@ -71,24 +80,30 @@ def _close_session_and_loop(self): self.loop.stop() self.loop.close() - async def _gather(self, pages): - tasks = [self._get_page(page["id"]) for page 
in pages if self.PAGE_TYPE in page] + async def _gather(self, pages, access_token=None): + tasks = [ + self._get_page(page["id"], access_token) + for page in pages + if self.PAGE_TYPE in page + ] return await asyncio.gather(*tasks) - async def _get_page(self, page_id): + async def _get_page(self, page_id, access_token=None): # Check cache if page_id in self.cache: return self._cache_get(page_id) - get_page_by_id_url = f"{self.base_url}/wiki/api/v2/pages/{page_id}" - credentials = f"{self.user}:{self.api_token}" - credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") + base_url = self._get_base_url(access_token) + get_page_by_id_url = f"{base_url}/wiki/api/v2/pages/{page_id}" + + headers = {} + self._add_auth_header(headers, access_token) params = {"body-format": self.PAGE_BODY_FORMAT} async with self.session.get( get_page_by_id_url, - headers={"Authorization": f"Basic {credentials_encoded}"}, + headers=headers, params=params, ) as response: if not response.ok: @@ -97,7 +112,8 @@ async def _get_page(self, page_id): content = await response.json() - page_url = f"{self.base_url}/wiki{content['_links']['webui']}" + base_url = self._get_base_url(access_token) + page_url = f"{base_url}/wiki{content['_links']['webui']}" serialized_page = { "title": content["title"], @@ -109,8 +125,9 @@ async def _get_page(self, page_id): self._cache_put(page_id, serialized_page) return self._cache_get(page_id) - def search_pages(self, query): - search_url = f"{self.base_url}/wiki/rest/api/content/search" + def search_pages(self, query, access_token=None): + base_url = self._get_base_url(access_token) + search_url = f"{base_url}/wiki/rest/api/content/search" # Substitutes any sequence of non-alphanumeric or whitespace characters with a whitespace formatted_query = re.sub("\W+", " ", query) @@ -120,9 +137,12 @@ def search_pages(self, query): "limit": self.search_limit, } + headers = {} + self._add_auth_header(headers, access_token) + response = requests.get( 
search_url, - auth=(self.user, self.api_token), + headers=headers, params=params, ) @@ -133,30 +153,68 @@ def search_pages(self, query): return response.json().get("results", []) - def fetch_pages(self, pages): + def fetch_pages(self, pages, access_token): self._start_session() - results = self.loop.run_until_complete(self._gather(pages)) + results = self.loop.run_until_complete(self._gather(pages, access_token)) self._close_session_and_loop() return results - def search(self, query): - pages = self.search_pages(query) + def search(self, query, access_token=None): + pages = self.search_pages(query, access_token) + + return [ + page for page in self.fetch_pages(pages, access_token) if page is not None + ] + + def _add_auth_header(self, headers, access_token): + if access_token: + headers["Authorization"] = f"Bearer {access_token}" + else: + credentials = f"{self.service_user}:{self.service_api_token}" + credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") + headers["Authorization"] = f"Basic {credentials_encoded}" + + def _get_base_url(self, access_token=None): + if not access_token: + return self.service_base_url + + if access_token in self.org_ids: + return ( + f"https://api.atlassian.com/ex/confluence/{self.org_ids[access_token]}" + ) + + headers = {} + self._add_auth_header(headers, access_token) + + response = requests.get( + "https://api.atlassian.com/oauth/token/accessible-resources", + headers=headers, + ) + + if response.status_code != 200: + logger.error("Error determining Confluence base URL") + return + + accessible_resources = response.json() + + if not accessible_resources: + logger.error("No resources available to user") + return + + org_id = accessible_resources[0]["id"] + self.org_ids[access_token] = org_id - return [page for page in self.fetch_pages(pages) if page is not None] + return f"https://api.atlassian.com/ex/confluence/{org_id}" def get_client(): global client if client is None: - assert ( - url := 
app.config.get("PRODUCT_URL") - ), "CONFLUENCE_PRODUCT_URL must be set" - assert (user := app.config.get("USER")), "CONFLUENCE_USER must be set" - assert ( - api_token := app.config.get("API_TOKEN") - ), "CONFLUENCE_API_TOKEN must be set" + product_url = app.config.get("PRODUCT_URL") + user = app.config.get("USER") + api_token = app.config.get("API_TOKEN") search_limit = app.config.get("SEARCH_LIMIT", 10) - client = ConfluenceClient(url, user, api_token, search_limit) + client = ConfluenceClient(product_url, user, api_token, search_limit) return client diff --git a/confluence/provider/provider.py b/confluence/provider/provider.py index 8026aab26..0669426aa 100644 --- a/confluence/provider/provider.py +++ b/confluence/provider/provider.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) -def search(query): +def search(query, access_token): client = get_client() - pages = client.search(query) + pages = client.search(query, access_token) return pages From 012cb427a8a1a15e2f9d0bc190370ac4c7e02ec6 Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Tue, 24 Sep 2024 12:36:02 -0600 Subject: [PATCH 2/5] [Confluence] Add type hints to Confluence OAuth changes --- confluence/poetry.lock | 85 ++++++++++++++++++++++++++++++----- confluence/provider/client.py | 27 +++++------ confluence/pyproject.toml | 2 + 3 files changed, 86 insertions(+), 28 deletions(-) diff --git a/confluence/poetry.lock b/confluence/poetry.lock index 17e8cdf76..f475899bc 100644 --- a/confluence/poetry.lock +++ b/confluence/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohttp" @@ -626,16 +626,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = 
"sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -767,6 +757,52 @@ files = [ {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] +[[package]] +name = "mypy" +version = "1.11.2" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.11.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d42a6dd818ffce7be66cce644f1dff482f1d97c53ca70908dff0b9ddc120b77a"}, + {file = "mypy-1.11.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:801780c56d1cdb896eacd5619a83e427ce436d86a3bdf9112527f24a66618fef"}, + {file = "mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:41ea707d036a5307ac674ea172875f40c9d55c5394f888b168033177fce47383"}, + {file = "mypy-1.11.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e658bd2d20565ea86da7d91331b0eed6d2eee22dc031579e6297f3e12c758c8"}, + {file = "mypy-1.11.2-cp310-cp310-win_amd64.whl", hash = "sha256:478db5f5036817fe45adb7332d927daa62417159d49783041338921dcf646fc7"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75746e06d5fa1e91bfd5432448d00d34593b52e7e91a187d981d08d1f33d4385"}, + {file = "mypy-1.11.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a976775ab2256aadc6add633d44f100a2517d2388906ec4f13231fafbb0eccca"}, + {file = 
"mypy-1.11.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd953f221ac1379050a8a646585a29574488974f79d8082cedef62744f0a0104"}, + {file = "mypy-1.11.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:57555a7715c0a34421013144a33d280e73c08df70f3a18a552938587ce9274f4"}, + {file = "mypy-1.11.2-cp311-cp311-win_amd64.whl", hash = "sha256:36383a4fcbad95f2657642a07ba22ff797de26277158f1cc7bd234821468b1b6"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e8960dbbbf36906c5c0b7f4fbf2f0c7ffb20f4898e6a879fcf56a41a08b0d318"}, + {file = "mypy-1.11.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:06d26c277962f3fb50e13044674aa10553981ae514288cb7d0a738f495550b36"}, + {file = "mypy-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e7184632d89d677973a14d00ae4d03214c8bc301ceefcdaf5c474866814c987"}, + {file = "mypy-1.11.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a66169b92452f72117e2da3a576087025449018afc2d8e9bfe5ffab865709ca"}, + {file = "mypy-1.11.2-cp312-cp312-win_amd64.whl", hash = "sha256:969ea3ef09617aff826885a22ece0ddef69d95852cdad2f60c8bb06bf1f71f70"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:37c7fa6121c1cdfcaac97ce3d3b5588e847aa79b580c1e922bb5d5d2902df19b"}, + {file = "mypy-1.11.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a8a53bc3ffbd161b5b2a4fff2f0f1e23a33b0168f1c0778ec70e1a3d66deb86"}, + {file = "mypy-1.11.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ff93107f01968ed834f4256bc1fc4475e2fecf6c661260066a985b52741ddce"}, + {file = "mypy-1.11.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:edb91dded4df17eae4537668b23f0ff6baf3707683734b6a818d5b9d0c0c31a1"}, + {file = "mypy-1.11.2-cp38-cp38-win_amd64.whl", hash = "sha256:ee23de8530d99b6db0573c4ef4bd8f39a2a6f9b60655bf7a1357e585a3486f2b"}, + {file = 
"mypy-1.11.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:801ca29f43d5acce85f8e999b1e431fb479cb02d0e11deb7d2abb56bdaf24fd6"}, + {file = "mypy-1.11.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:af8d155170fcf87a2afb55b35dc1a0ac21df4431e7d96717621962e4b9192e70"}, + {file = "mypy-1.11.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7821776e5c4286b6a13138cc935e2e9b6fde05e081bdebf5cdb2bb97c9df81d"}, + {file = "mypy-1.11.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:539c570477a96a4e6fb718b8d5c3e0c0eba1f485df13f86d2970c91f0673148d"}, + {file = "mypy-1.11.2-cp39-cp39-win_amd64.whl", hash = "sha256:3f14cd3d386ac4d05c5a39a51b84387403dadbd936e17cb35882134d4f8f0d24"}, + {file = "mypy-1.11.2-py3-none-any.whl", hash = "sha256:b499bc07dbdcd3de92b0a8b29fdf592c111276f6a12fe29c30f6c417dd546d12"}, + {file = "mypy-1.11.2.tar.gz", hash = "sha256:7f9993ad3e0ffdc95c2a14b66dee63729f021968bff8ad911867579c65d13a79"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1079,6 +1115,31 @@ files = [ [package.dependencies] Jinja2 = ">=2.0" +[[package]] +name = "types-requests" +version = "2.32.0.20240914" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20240914.tar.gz", hash = "sha256:2850e178db3919d9bf809e434eef65ba49d0e7e33ac92d588f4a5e295fffd405"}, + {file = "types_requests-2.32.0.20240914-py3-none-any.whl", hash = "sha256:59c2f673eb55f32a99b2894faf6020e1a9f4a402ad0f192bfee0b64469054310"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + 
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, +] + [[package]] name = "urllib3" version = "2.0.7" @@ -1303,4 +1364,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "b317122939a7e1fe670176098987cdfbc976fa71344b49f1074245c81208c1dd" +content-hash = "e56a194069245c5ef2fffaacb53899377d9ca125ff203ac81eddb63a346f4405" diff --git a/confluence/provider/client.py b/confluence/provider/client.py index 4e52c92c3..8f39b6c15 100644 --- a/confluence/provider/client.py +++ b/confluence/provider/client.py @@ -29,7 +29,7 @@ class ConfluenceClient: CACHE_LIMIT_BYTES = 20 * 1024 * 1024 # 20 MB to bytes # Cache for token to organization cloud id mappings - org_ids = {} + org_ids: dict[str, str] = {} def __init__( self, @@ -96,14 +96,11 @@ async def _get_page(self, page_id, access_token=None): base_url = self._get_base_url(access_token) get_page_by_id_url = f"{base_url}/wiki/api/v2/pages/{page_id}" - - headers = {} - self._add_auth_header(headers, access_token) params = {"body-format": self.PAGE_BODY_FORMAT} async with self.session.get( get_page_by_id_url, - headers=headers, + headers=self._get_headers(access_token), params=params, ) as response: if not response.ok: @@ -137,12 +134,9 @@ def search_pages(self, query, access_token=None): "limit": self.search_limit, } - headers = {} - self._add_auth_header(headers, access_token) - response = requests.get( search_url, - headers=headers, + headers=self._get_headers(access_token), params=params, ) @@ -153,7 +147,7 @@ def search_pages(self, query, access_token=None): return response.json().get("results", []) - def fetch_pages(self, pages, access_token): + def fetch_pages(self, pages, access_token: str | None = None): self._start_session() results = 
self.loop.run_until_complete(self._gather(pages, access_token)) self._close_session_and_loop() @@ -167,7 +161,9 @@ def search(self, query, access_token=None): page for page in self.fetch_pages(pages, access_token) if page is not None ] - def _add_auth_header(self, headers, access_token): + def _get_headers(self, access_token: str | None = None) -> dict[str, str]: + headers = {} + if access_token: headers["Authorization"] = f"Bearer {access_token}" else: @@ -175,7 +171,9 @@ def _add_auth_header(self, headers, access_token): credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") headers["Authorization"] = f"Basic {credentials_encoded}" - def _get_base_url(self, access_token=None): + return headers + + def _get_base_url(self, access_token: str | None = None): if not access_token: return self.service_base_url @@ -184,12 +182,9 @@ def _get_base_url(self, access_token=None): f"https://api.atlassian.com/ex/confluence/{self.org_ids[access_token]}" ) - headers = {} - self._add_auth_header(headers, access_token) - response = requests.get( "https://api.atlassian.com/oauth/token/accessible-resources", - headers=headers, + headers=self._get_headers(access_token), ) if response.status_code != 200: diff --git a/confluence/pyproject.toml b/confluence/pyproject.toml index 241510bd1..57101f64f 100644 --- a/confluence/pyproject.toml +++ b/confluence/pyproject.toml @@ -15,6 +15,8 @@ gunicorn = "^22.0.0" asyncio = "^3.4.3" black = "^24.3.0" aiohttp = "^3.9.4" +mypy = "^1.11.2" +types-requests = "^2.32.0.20240914" From b7d37ea02d3465f065d111138ca3eda42fe77bb9 Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Tue, 24 Sep 2024 13:59:33 -0600 Subject: [PATCH 3/5] [Confluence] Fix to ensure service auth continues working after oauth support added --- confluence/provider/app.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/confluence/provider/app.py b/confluence/provider/app.py index f00049cde..e98efde0a 100644 --- a/confluence/provider/app.py 
+++ b/confluence/provider/app.py @@ -12,9 +12,13 @@ def search(body): logger.debug(f'Search request: {body["query"]}') + access_token = get_access_token() + + if access_token == app.config.get("CONNECTOR_API_KEY", None): + access_token = None try: - data = provider.search(body["query"], get_access_token()) + data = provider.search(body["query"], access_token) logger.info(f"Found {len(data)} results") except UpstreamProviderError as error: logger.error(f"Upstream search error: {error.message}") From 45b45991762cc578e4c303e8c5f2d72a36cfb5fe Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Wed, 25 Sep 2024 12:35:37 -0600 Subject: [PATCH 4/5] [Confluence] Add requirement to request offline_access scope to documentation --- confluence/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/confluence/README.md b/confluence/README.md index 3ce2c1885..e825c5c77 100644 --- a/confluence/README.md +++ b/confluence/README.md @@ -46,6 +46,10 @@ following are enabled: * read:page:confluence * read:custom-content:confluence +The `offline_access` scope must also be requested for refresh tokens to work. This scope does not appear in the +list of scopes in the Atlassian OAuth permissions page, but it must be included in the scopes added to the connector +configuration in Cohere dashboard. + You must also configure the authorization settings. Go to the Authorization page, and configure the app to use the authorization type OAuth 2.0 (3L0). 
On the configuration page for the authorization page, enter the callback URL as: @@ -71,7 +75,7 @@ curl -X POST \ "client_secret": "{CONFLUENCE-OAUTH-CLIENT-SECRET}", "authorize_url": "https://auth.atlassian.com/authorize?audience=api.atlassian.com&response_type=code&prompt=consent", "token_url": "https://auth.atlassian.com/oauth/token", - "scope": "read:content:confluence read:content-details:confluence read:page:confluence read:custom-content:confluence" + "scope": "read:content:confluence read:content-details:confluence read:page:confluence read:custom-content:confluence offline_access" } }' ``` From 78ff9a515a48f97f67b78df6749714d8732f5209 Mon Sep 17 00:00:00 2001 From: Scot Mountenay Date: Wed, 25 Sep 2024 17:46:36 -0600 Subject: [PATCH 5/5] [Confluence] Separate client into separate oauth and service auth clients --- confluence/README.md | 10 +++- confluence/provider/app.py | 9 +++- confluence/provider/client.py | 89 ++++++++++++++++++++++++----------- 3 files changed, 78 insertions(+), 30 deletions(-) diff --git a/confluence/README.md b/confluence/README.md index e825c5c77..1eff2922c 100644 --- a/confluence/README.md +++ b/confluence/README.md @@ -19,6 +19,7 @@ method, or with OAuth. When using the service auth method, you must set the following env vars: ``` +CONFLUENCE_AUTH_METHOD: Set to "service_auth" CONFLUENCE_USER: User email address CONFLUENCE_API_TOKEN: API token CONFLUENCE_PRODUCT_URL: URL to your Confluence instance, including https:// schema @@ -30,7 +31,7 @@ The API token can be generated by logging into Confluence and going to the [API ### OAuth When using OAuth for authentication, the connector does not require any additional environment variables. Instead, -the OAuth flow should occur outside of the Connector and Cohere's API will forward the user's access token to this +the OAuth flow should occur outside the Connector and Cohere's API will forward the user's access token to this connector through the `Authorization` header. 
To use OAuth, you must first create an OAuth 2.0 app in Confluence. To do this, go to the @@ -97,6 +98,13 @@ CONFLUENCE_CONNECTOR_API_KEY This variable can be used to set an API key for the connector. +``` +CONFLUENCE_AUTH_METHOD +``` + +This variable is used to configure the connector to use service auth or OAuth authentication. The valid +values are `service_auth` and `oauth`. The default is to run the connector in OAuth mode. + These variables can optionally be put into a `.env` file for development. A `.env-template` file is provided with all the environment variables that are used by this demo. diff --git a/confluence/provider/app.py b/confluence/provider/app.py index e98efde0a..b8f0c91f8 100644 --- a/confluence/provider/app.py +++ b/confluence/provider/app.py @@ -14,7 +14,14 @@ def search(body): logger.debug(f'Search request: {body["query"]}') access_token = get_access_token() - if access_token == app.config.get("CONNECTOR_API_KEY", None): + auth_method = app.config.get("AUTH_METHOD") + connector_api_key = app.config.get("CONNECTOR_API_KEY", None) + + if auth_method == "service_auth" and access_token and not connector_api_key: + logger.error("Connector not configured to use API keys") + raise Unauthorized() + + if access_token == connector_api_key: access_token = None try: diff --git a/confluence/provider/client.py b/confluence/provider/client.py index 8f39b6c15..e53254869 100644 --- a/confluence/provider/client.py +++ b/confluence/provider/client.py @@ -17,7 +17,7 @@ client = None -class ConfluenceClient: +class BaseConfluenceClient: # Page consts PAGE_TYPE = "type" PAGE_BODY_FORMAT = "storage" @@ -28,19 +28,7 @@ class ConfluenceClient: # Cache size limit to reduce memory over time CACHE_LIMIT_BYTES = 20 * 1024 * 1024 # 20 MB to bytes - # Cache for token to organization cloud id mappings - org_ids: dict[str, str] = {} - - def __init__( - self, - service_base_url=None, - service_user=None, - service_api_token=None, - search_limit=10, - ): - 
self.service_base_url = service_base_url - self.service_user = service_user - self.service_api_token = service_api_token + def __init__(self, search_limit=10): self.search_limit = search_limit # Manually cache because functools.lru_cache does not support async methods self.cache = OrderedDict() @@ -162,20 +150,40 @@ def search(self, query, access_token=None): ] def _get_headers(self, access_token: str | None = None) -> dict[str, str]: - headers = {} + raise NotImplementedError() + + def _get_base_url(self, access_token: str | None = None): + raise NotImplementedError() - if access_token: - headers["Authorization"] = f"Bearer {access_token}" - else: - credentials = f"{self.service_user}:{self.service_api_token}" - credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") - headers["Authorization"] = f"Basic {credentials_encoded}" - return headers +class ServiceAuthConfluenceClient(BaseConfluenceClient): + def __init__(self, product_url, user, api_token, search_limit): + self.product_url = product_url + self.user = user + self.api_token = api_token + super().__init__(search_limit=search_limit) + + def _get_base_url(self, access_token: str | None = None): + return self.product_url + + def _get_headers(self, access_token: str | None = None) -> dict[str, str]: + credentials = f"{self.user}:{self.api_token}" + credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii") + + return { + "Authorization": f"Basic {credentials_encoded}", + } + + +class OAuthConfluenceClient(BaseConfluenceClient): + # Cache for token to organization cloud id mappings + org_ids: dict[str, str] = {} def _get_base_url(self, access_token: str | None = None): if not access_token: - return self.service_base_url + raise AssertionError( + "Access token required to construct Confluence cloud URLs" + ) if access_token in self.org_ids: return ( @@ -202,14 +210,39 @@ def _get_base_url(self, access_token: str | None = None): return 
f"https://api.atlassian.com/ex/confluence/{org_id}" + def _get_headers(self, access_token: str | None = None) -> dict[str, str]: + return { + "Authorization": f"Bearer {access_token}", + } + def get_client(): global client + if client is None: - product_url = app.config.get("PRODUCT_URL") - user = app.config.get("USER") - api_token = app.config.get("API_TOKEN") - search_limit = app.config.get("SEARCH_LIMIT", 10) - client = ConfluenceClient(product_url, user, api_token, search_limit) + auth_method = app.config.get("AUTH_METHOD", "oauth") + assert auth_method in [ + "oauth", + "service_auth", + ], 'CONFLUENCE_AUTH_METHOD must be "oauth" or "service_auth"' + + try: + search_limit = int(app.config.get("SEARCH_LIMIT", 10)) + except ValueError: + raise ValueError("SEARCH_LIMIT must be an integer") + + if auth_method == "oauth": + client = OAuthConfluenceClient() + elif auth_method == "service_auth": + assert ( + product_url := app.config.get("PRODUCT_URL") + ), "CONFLUENCE_PRODUCT_URL must be set" + assert (user := app.config.get("USER")), "CONFLUENCE_USER must be set" + assert ( + api_token := app.config.get("API_TOKEN") + ), "CONFLUENCE_API_TOKEN must be set" + client = ServiceAuthConfluenceClient( + product_url, user, api_token, search_limit + ) return client