Skip to content

Commit

Permalink
[Confluence] Add support for OAuth 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
scottmx81 committed Sep 24, 2024
1 parent 6100dc2 commit da38f1e
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 37 deletions.
64 changes: 60 additions & 4 deletions confluence/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ This package is a utility for connecting Cohere to Confluence, featuring a simpl
## Limitations

The Confluence connector will search within the space defined in your `.env`, and performs a case-insensitive full-text
search against all text fields Confluence indexes by default.
search against all text fields Confluence indexes by default.

Note: The search uses Confluence's advanced search language called [CQL](https://developer.atlassian.com/cloud/confluence/advanced-searching-using-cql/). If you wish to customize this connector's search experience, please refer to the above linked documentation for more details.

## Configuration

This connector requires the following environment variables:
There are two authentication methods available with this connector. You can either set it up using the service auth
method, or with OAuth.

### Service Auth

When using the service auth method, you must set the following env vars:

```
CONFLUENCE_USER: User email address
Expand All @@ -20,8 +25,59 @@ CONFLUENCE_PRODUCT_URL: URL to your Confluence instance, including https:// sche
CONFLUENCE_SPACE_NAME: Name of a space within your Confluence wiki
```

The API token can be generated by logging into Confluence and going
to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens).
The API token can be generated by logging into Confluence and going to the [API tokens page](https://id.atlassian.com/manage-profile/security/api-tokens).

### OAuth

When using OAuth for authentication, the connector does not require any additional environment variables. Instead,
the OAuth flow should occur outside of the Connector and Cohere's API will forward the user's access token to this
connector through the `Authorization` header.

To use OAuth, you must first create an OAuth 2.0 app in Confluence. To do this, go to the
Atlassian [Developer Console](https://developer.atlassian.com/console/myapps/), and use the option to create a new
OAuth 2.0 integration.

You must configure in the developer console the OAuth scopes that are allowed to be requested by the client. There are
two options in Confluence, classic scopes or granular scopes. Use the granular scopes option, and ensure that the
following are enabled:

* read:content:confluence
* read:content-details:confluence
* read:page:confluence
* read:custom-content:confluence

You must also configure the authorization settings. Go to the Authorization page, and configure the app to use the
authorization type OAuth 2.0 (3L0). On the configuration page for the authorization page, enter the callback URL as:

https://api.cohere.com/v1/connectors/oauth/token

Go to the settings option for the app, enter the app name and description under general settings, and then take
note of the OAuth client id and secret from this page.

Once your Confluence OAuth credentials are ready, you can register the connector in Cohere's API with the following
configuration:

```bash
curl -X POST \
'https://api.cohere.ai/v1/connectors' \
--header 'Accept: */*' \
--header 'Authorization: Bearer {COHERE-API-KEY}' \
--header 'Content-Type: application/json' \
--data-raw '{
"name": "Confluence",
"url": "{YOUR_CONNECTOR-URL}",
"oauth": {
"client_id": "{CONFLUENCE-OAUTH-CLIENT-ID}",
"client_secret": "{CONFLUENCE-OAUTH-CLIENT-SECRET}",
"authorize_url": "https://auth.atlassian.com/authorize?audience=api.atlassian.com&response_type=code&prompt=consent",
"token_url": "https://auth.atlassian.com/oauth/token",
"scope": "read:content:confluence read:content-details:confluence read:page:confluence read:custom-content:confluence"
}
}'
```

With OAuth the connector will be able to search any Confluence pages that the user has access to.

### Optional Configuration

```
Expand Down
14 changes: 11 additions & 3 deletions confluence/provider/app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import logging

from connexion.exceptions import Unauthorized
from flask import abort
from flask import current_app as app
from flask import abort, request, current_app as app

from . import UpstreamProviderError, provider

logger = logging.getLogger(__name__)
AUTHORIZATION_HEADER = "Authorization"
BEARER_PREFIX = "Bearer "


def search(body):
logger.debug(f'Search request: {body["query"]}')

try:
data = provider.search(body["query"])
data = provider.search(body["query"], get_access_token())
logger.info(f"Found {len(data)} results")
except UpstreamProviderError as error:
logger.error(f"Upstream search error: {error.message}")
Expand All @@ -22,6 +23,13 @@ def search(body):
return {"results": data}, 200, {"X-Connector-Id": app.config.get("APP_ID")}


def get_access_token() -> str | None:
authorization_header = request.headers.get(AUTHORIZATION_HEADER, "")
if authorization_header.startswith(BEARER_PREFIX):
return authorization_header.removeprefix(BEARER_PREFIX)
return None


def apikey_auth(token):
api_key = str(app.config.get("CONNECTOR_API_KEY", ""))
if api_key != "" and token != api_key:
Expand Down
114 changes: 86 additions & 28 deletions confluence/provider/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,19 @@ class ConfluenceClient:
# Cache size limit to reduce memory over time
CACHE_LIMIT_BYTES = 20 * 1024 * 1024 # 20 MB to bytes

def __init__(self, url, user, api_token, search_limit=10):
self.base_url = url
self.user = user
self.api_token = api_token
# Cache for token to organization cloud id mappings
org_ids = {}

def __init__(
self,
service_base_url=None,
service_user=None,
service_api_token=None,
search_limit=10,
):
self.service_base_url = service_base_url
self.service_user = service_user
self.service_api_token = service_api_token
self.search_limit = search_limit
# Manually cache because functools.lru_cache does not support async methods
self.cache = OrderedDict()
Expand Down Expand Up @@ -71,24 +80,30 @@ def _close_session_and_loop(self):
self.loop.stop()
self.loop.close()

async def _gather(self, pages):
tasks = [self._get_page(page["id"]) for page in pages if self.PAGE_TYPE in page]
async def _gather(self, pages, access_token=None):
tasks = [
self._get_page(page["id"], access_token)
for page in pages
if self.PAGE_TYPE in page
]

return await asyncio.gather(*tasks)

async def _get_page(self, page_id):
async def _get_page(self, page_id, access_token=None):
# Check cache
if page_id in self.cache:
return self._cache_get(page_id)

get_page_by_id_url = f"{self.base_url}/wiki/api/v2/pages/{page_id}"
credentials = f"{self.user}:{self.api_token}"
credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii")
base_url = self._get_base_url(access_token)
get_page_by_id_url = f"{base_url}/wiki/api/v2/pages/{page_id}"

headers = {}
self._add_auth_header(headers, access_token)
params = {"body-format": self.PAGE_BODY_FORMAT}

async with self.session.get(
get_page_by_id_url,
headers={"Authorization": f"Basic {credentials_encoded}"},
headers=headers,
params=params,
) as response:
if not response.ok:
Expand All @@ -97,7 +112,8 @@ async def _get_page(self, page_id):

content = await response.json()

page_url = f"{self.base_url}/wiki{content['_links']['webui']}"
base_url = self._get_base_url(access_token)
page_url = f"{base_url}/wiki{content['_links']['webui']}"

serialized_page = {
"title": content["title"],
Expand All @@ -109,8 +125,9 @@ async def _get_page(self, page_id):
self._cache_put(page_id, serialized_page)
return self._cache_get(page_id)

def search_pages(self, query):
search_url = f"{self.base_url}/wiki/rest/api/content/search"
def search_pages(self, query, access_token=None):
base_url = self._get_base_url(access_token)
search_url = f"{base_url}/wiki/rest/api/content/search"

# Substitutes any sequence of non-alphanumeric or whitespace characters with a whitespace
formatted_query = re.sub("\W+", " ", query)
Expand All @@ -120,9 +137,12 @@ def search_pages(self, query):
"limit": self.search_limit,
}

headers = {}
self._add_auth_header(headers, access_token)

response = requests.get(
search_url,
auth=(self.user, self.api_token),
headers=headers,
params=params,
)

Expand All @@ -133,30 +153,68 @@ def search_pages(self, query):

return response.json().get("results", [])

def fetch_pages(self, pages):
def fetch_pages(self, pages, access_token):
self._start_session()
results = self.loop.run_until_complete(self._gather(pages))
results = self.loop.run_until_complete(self._gather(pages, access_token))
self._close_session_and_loop()

return results

def search(self, query):
pages = self.search_pages(query)
def search(self, query, access_token=None):
pages = self.search_pages(query, access_token)

return [
page for page in self.fetch_pages(pages, access_token) if page is not None
]

def _add_auth_header(self, headers, access_token):
if access_token:
headers["Authorization"] = f"Bearer {access_token}"
else:
credentials = f"{self.service_user}:{self.service_api_token}"
credentials_encoded = base64.b64encode(credentials.encode()).decode("ascii")
headers["Authorization"] = f"Basic {credentials_encoded}"

def _get_base_url(self, access_token=None):
if not access_token:
return self.service_base_url

if access_token in self.org_ids:
return (
f"https://api.atlassian.com/ex/confluence/{self.org_ids[access_token]}"
)

headers = {}
self._add_auth_header(headers, access_token)

response = requests.get(
"https://api.atlassian.com/oauth/token/accessible-resources",
headers=headers,
)

if response.status_code != 200:
logger.error("Error determining Confluence base URL")
return

accessible_resources = response.json()

if not accessible_resources:
logger.error("No resources available to user")
return

org_id = accessible_resources[0]["id"]
self.org_ids[access_token] = org_id

return [page for page in self.fetch_pages(pages) if page is not None]
return f"https://api.atlassian.com/ex/confluence/{org_id}"


def get_client():
global client
if client is None:
assert (
url := app.config.get("PRODUCT_URL")
), "CONFLUENCE_PRODUCT_URL must be set"
assert (user := app.config.get("USER")), "CONFLUENCE_USER must be set"
assert (
api_token := app.config.get("API_TOKEN")
), "CONFLUENCE_API_TOKEN must be set"
product_url = app.config.get("PRODUCT_URL")
user = app.config.get("USER")
api_token = app.config.get("API_TOKEN")
search_limit = app.config.get("SEARCH_LIMIT", 10)
client = ConfluenceClient(url, user, api_token, search_limit)
client = ConfluenceClient(product_url, user, api_token, search_limit)

return client
4 changes: 2 additions & 2 deletions confluence/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
logger = logging.getLogger(__name__)


def search(query):
def search(query, access_token):
client = get_client()
pages = client.search(query)
pages = client.search(query, access_token)
return pages

0 comments on commit da38f1e

Please sign in to comment.