-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add oauth to Sharepoint, simplify logic, remove listItems from search
- Loading branch information
1 parent
56a480c
commit 480e088
Showing
5 changed files
with
88 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,141 +1,126 @@ | ||
from functools import lru_cache | ||
from azure.identity import ClientSecretCredential | ||
from flask import current_app as app | ||
from msgraph.core import GraphClient, APIVersion | ||
from urllib.parse import urlparse | ||
import requests | ||
|
||
from msal import ConfidentialClientApplication | ||
from flask import current_app as app, request | ||
|
||
from . import UpstreamProviderError | ||
from .consts import CACHE_SIZE | ||
|
||
client = None | ||
AUTHORIZATION_HEADER = "Authorization" | ||
BEARER_PREFIX = "Bearer " | ||
|
||
|
||
class SharepointClient: | ||
DEFAULT_SCOPES = ["https://graph.microsoft.com/.default"] | ||
DEFAULT_REGION = "NAM" | ||
SEARCH_ENTITY_TYPES = ["driveItem", "listItem"] | ||
SEARCH_URL = "/search/query" | ||
SEARCH_LIMIT = 3 | ||
BASE_URL = "https://graph.microsoft.com/v1.0" | ||
SEARCH_ENTITY_TYPES = ["driveItem"] | ||
APPLICATION_AUTH = "application" | ||
DELEGATED_AUTH = "user" | ||
|
||
def __init__(self, auth_type, search_limit): | ||
self.access_token = None | ||
self.user = None | ||
self.auth_type = auth_type | ||
self.search_limit = search_limit | ||
|
||
graph_client = None | ||
def get_auth_type(self): | ||
return self.auth_type | ||
|
||
def __init__(self, tenant_id, client_id, client_secret, search_limit=5): | ||
def set_app_access_token(self, tenant_id, client_id, client_secret): | ||
try: | ||
credential = ClientSecretCredential( | ||
tenant_id, | ||
client_id, | ||
client_secret, | ||
credential = ConfidentialClientApplication( | ||
client_id=client_id, | ||
client_credential=client_secret, | ||
authority=f"https://login.microsoftonline.com/{tenant_id}", | ||
) | ||
|
||
self.graph_client = GraphClient( | ||
credential=credential, | ||
token_response = credential.acquire_token_for_client( | ||
scopes=self.DEFAULT_SCOPES, | ||
api_version=APIVersion.beta, | ||
) | ||
if "access_token" not in token_response: | ||
raise UpstreamProviderError( | ||
"Error while retrieving access token from Microsoft Graph API" | ||
) | ||
self.access_token = token_response["access_token"] | ||
except Exception as e: | ||
raise UpstreamProviderError( | ||
f"Error while initializing Sharepoint client: {str(e)}" | ||
f"Error while initializing Teams client: {str(e)}" | ||
) | ||
|
||
self.search_limit = search_limit | ||
def set_user_access_token(self, token): | ||
self.access_token = token | ||
self.headers = {"Authorization": f"Bearer {self.access_token}"} | ||
|
||
@lru_cache(CACHE_SIZE) | ||
def search(self, query): | ||
search_response = self.graph_client.post( | ||
self.SEARCH_URL, | ||
response = requests.post( | ||
f"{self.BASE_URL}/search/query", | ||
headers={"Authorization": f"Bearer {self.access_token}"}, | ||
json={ | ||
"requests": [ | ||
{ | ||
"entityTypes": self.SEARCH_ENTITY_TYPES, | ||
"region": self.DEFAULT_REGION, | ||
"query": { | ||
"queryString": query, | ||
"size": self.SEARCH_LIMIT, | ||
"size": self.search_limit, | ||
}, | ||
"region": self.DEFAULT_REGION, | ||
} | ||
] | ||
}, | ||
) | ||
|
||
if not search_response.ok: | ||
message = ( | ||
search_response.json() | ||
.get("error", {}) | ||
.get("message", "Error calling Microsoft Graph API") | ||
) | ||
raise UpstreamProviderError(message) | ||
|
||
return search_response.json()["value"][0]["hitsContainers"] | ||
|
||
@lru_cache(CACHE_SIZE) | ||
def get_pages(self, site_id): | ||
page_url = f"/sites/{site_id}/pages" | ||
response = self.graph_client.get(page_url) | ||
|
||
if not response.ok: | ||
return [] | ||
|
||
return response.json() | ||
|
||
@lru_cache(CACHE_SIZE) | ||
def fetch_page(self, url): | ||
parsed_url = urlparse(url) | ||
site_id = parsed_url.netloc | ||
pages = self.get_pages(site_id) | ||
|
||
# Find page by path | ||
matching_page = None | ||
for page in pages["value"]: | ||
normalized_page_path = f"/{page['webUrl']}" | ||
if normalized_page_path == parsed_url.path: | ||
matching_page = page | ||
break | ||
|
||
return matching_page | ||
|
||
@lru_cache(CACHE_SIZE) | ||
def get_drive_item(self, parent_drive_id, resource_id): | ||
drive_item_url = f"/drives/{parent_drive_id}/items/{resource_id}/content" | ||
|
||
get_response = self.graph_client.get(drive_item_url) | ||
|
||
# Fail gracefully when retrieving content | ||
if not get_response.ok: | ||
return {} | ||
raise UpstreamProviderError( | ||
f"Error while searching Sharepoint: {response.text}" | ||
) | ||
|
||
return get_response.content | ||
return response.json()["value"][0]["hitsContainers"] | ||
|
||
@lru_cache(CACHE_SIZE) | ||
def get_list_item(self, site_id, page_id): | ||
list_item_url = ( | ||
f"/sites/{site_id}/pages/{page_id}/microsoft.graph.sitePage/webParts" | ||
def get_drive_item_content(self, parent_drive_id, resource_id): | ||
response = requests.get( | ||
f"{self.BASE_URL}/drives/{parent_drive_id}/items/{resource_id}/content", | ||
headers={"Authorization": f"Bearer {self.access_token}"}, | ||
) | ||
get_response = self.graph_client.get(list_item_url) | ||
|
||
# Fail gracefully when retrieving content | ||
if not get_response.ok: | ||
if not response.ok: | ||
return {} | ||
|
||
return get_response.json() | ||
return response.content | ||
|
||
|
||
def get_client(): | ||
global client | ||
if client is not None: | ||
return client | ||
|
||
# Fetch environment variables | ||
assert ( | ||
tenant_id := app.config.get("TENANT_ID") | ||
), "SHAREPOINT_TENANT_ID must be set" | ||
assert ( | ||
client_id := app.config.get("CLIENT_ID") | ||
), "SHAREPOINT_CLIENT_ID must be set" | ||
assert ( | ||
client_secret := app.config.get("CLIENT_SECRET") | ||
), "SHAREPOINT_CLIENT_SECRET must be set" | ||
search_limit = app.config.get("SEARCH_LIMIT", 5) | ||
auth_type := app.config.get("AUTH_TYPE") | ||
), "SHAREPOINT_AUTH_TYPE must be set" | ||
|
||
client = SharepointClient(tenant_id, client_id, client_secret, search_limit) | ||
search_limit = app.config.get("SEARCH_LIMIT", 5) | ||
client = SharepointClient(auth_type, search_limit) | ||
|
||
if auth_type == client.APPLICATION_AUTH: | ||
assert ( | ||
tenant_id := app.config.get("TENANT_ID") | ||
), "SHAREPOINT_TENANT_ID must be set" | ||
assert ( | ||
client_id := app.config.get("CLIENT_ID") | ||
), "SHAREPOINT_CLIENT_ID must be set" | ||
assert ( | ||
client_secret := app.config.get("CLIENT_SECRET") | ||
), "SHAREPOINT_CLIENT_SECRET must be set" | ||
client.set_app_access_token(tenant_id, client_id, client_secret) | ||
elif auth_type == client.DELEGATED_AUTH: | ||
token = get_access_token() | ||
if token is None: | ||
raise UpstreamProviderError("No access token provided in request") | ||
client.set_user_access_token(token) | ||
else: | ||
raise UpstreamProviderError(f"Invalid auth type: {auth_type}") | ||
|
||
return client | ||
|
||
|
||
def get_access_token(): | ||
authorization_header = request.headers.get(AUTHORIZATION_HEADER, "") | ||
if authorization_header.startswith(BEARER_PREFIX): | ||
return authorization_header.removeprefix(BEARER_PREFIX) | ||
return None |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters