Skip to content

Commit

Permalink
Add oauth to Sharepoint, simplify logic, remove listItems from search
Browse files Browse the repository at this point in the history
  • Loading branch information
tianjing-li committed Dec 11, 2023
1 parent 56a480c commit 480e088
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 143 deletions.
171 changes: 78 additions & 93 deletions sharepoint/provider/client.py
Original file line number Diff line number Diff line change
@@ -1,141 +1,126 @@
from functools import lru_cache
from azure.identity import ClientSecretCredential
from flask import current_app as app
from msgraph.core import GraphClient, APIVersion
from urllib.parse import urlparse
import requests

from msal import ConfidentialClientApplication
from flask import current_app as app, request

from . import UpstreamProviderError
from .consts import CACHE_SIZE

client = None
AUTHORIZATION_HEADER = "Authorization"
BEARER_PREFIX = "Bearer "


class SharepointClient:
    """Minimal Microsoft Graph client for searching SharePoint drive items.

    The client carries one bearer token, obtained either through the
    client-credentials flow (``set_app_access_token``) or supplied by the
    caller (``set_user_access_token``), and sends it with every request.
    """

    DEFAULT_SCOPES = ["https://graph.microsoft.com/.default"]
    DEFAULT_REGION = "NAM"
    BASE_URL = "https://graph.microsoft.com/v1.0"
    SEARCH_ENTITY_TYPES = ["driveItem"]
    APPLICATION_AUTH = "application"
    DELEGATED_AUTH = "user"
    # requests never times out by default; bound every Graph call (seconds).
    REQUEST_TIMEOUT = 30

    def __init__(self, auth_type, search_limit):
        """Create a client with no token yet.

        Args:
            auth_type: Either ``APPLICATION_AUTH`` or ``DELEGATED_AUTH``.
            search_limit: Maximum number of hits requested per search.
        """
        self.access_token = None
        self.headers = {}
        self.user = None
        self.auth_type = auth_type
        self.search_limit = search_limit

    def get_auth_type(self):
        """Return the auth type this client was created with."""
        return self.auth_type

    def set_app_access_token(self, tenant_id, client_id, client_secret):
        """Acquire and store an application token via client credentials.

        Raises:
            UpstreamProviderError: If MSAL fails or returns no access token.
        """
        try:
            msal_app = ConfidentialClientApplication(
                client_id=client_id,
                client_credential=client_secret,
                authority=f"https://login.microsoftonline.com/{tenant_id}",
            )
            token_response = msal_app.acquire_token_for_client(
                scopes=self.DEFAULT_SCOPES,
            )
        except Exception as e:
            # Boundary around the third-party auth library: surface any
            # failure as the provider's own error type, keeping the cause.
            raise UpstreamProviderError(
                f"Error while initializing Sharepoint client: {str(e)}"
            ) from e

        # Checked outside the try so this error is not re-wrapped by the
        # generic handler above.
        if "access_token" not in token_response:
            raise UpstreamProviderError(
                "Error while retrieving access token from Microsoft Graph API"
            )
        self._store_token(token_response["access_token"])

    def set_user_access_token(self, token):
        """Store a caller-supplied (delegated) bearer token."""
        self._store_token(token)

    def _store_token(self, token):
        """Record the token and keep ``self.headers`` in sync with it."""
        self.access_token = token
        self.headers = self._auth_headers(token)

    @staticmethod
    def _auth_headers(access_token):
        """Build the Authorization header for the given bearer token."""
        return {"Authorization": f"Bearer {access_token}"}

    def search(self, query):
        """Run a Graph search and return the response's hits containers.

        Raises:
            UpstreamProviderError: If the request fails or Graph responds
                with an error status.
        """
        # The token is part of the cache key so results fetched with one
        # token are never served after the token changes.
        return self._search(self.access_token, query)

    @lru_cache(CACHE_SIZE)
    def _search(self, access_token, query):
        """Cached search keyed on (client, token, query)."""
        try:
            response = requests.post(
                f"{self.BASE_URL}/search/query",
                headers=self._auth_headers(access_token),
                json={
                    "requests": [
                        {
                            "entityTypes": self.SEARCH_ENTITY_TYPES,
                            "query": {
                                "queryString": query,
                                "size": self.search_limit,
                            },
                            "region": self.DEFAULT_REGION,
                        }
                    ]
                },
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as e:
            raise UpstreamProviderError(
                f"Error while searching Sharepoint: {e}"
            ) from e

        if not response.ok:
            raise UpstreamProviderError(
                f"Error while searching Sharepoint: {response.text}"
            )

        return response.json()["value"][0]["hitsContainers"]

    def get_drive_item_content(self, parent_drive_id, resource_id):
        """Download a drive item's raw content.

        Returns:
            The response body as bytes, or ``{}`` when the download fails.
        """
        try:
            return self._fetch_drive_item_content(
                self.access_token, parent_drive_id, resource_id
            )
        except requests.RequestException:
            # Fail gracefully when retrieving content. Because the cached
            # helper raises on failure, failed downloads are not cached.
            return {}

    @lru_cache(CACHE_SIZE)
    def _fetch_drive_item_content(self, access_token, parent_drive_id, resource_id):
        """Cached download keyed on (client, token, drive, item)."""
        response = requests.get(
            f"{self.BASE_URL}/drives/{parent_drive_id}/items/{resource_id}/content",
            headers=self._auth_headers(access_token),
            timeout=self.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.content


def _require_config(key, env_name):
    """Return the Flask config value for *key*, raising if unset or empty."""
    value = app.config.get(key)
    if not value:
        raise UpstreamProviderError(f"{env_name} must be set")
    return value


def get_client():
    """Return a SharepointClient configured for the current request.

    With application auth, one client is created lazily and reused through
    the module-level ``client``. With delegated auth, a new client is built
    on every call from the bearer token of the current request, so a token
    captured from one request is never reused for another.

    Raises:
        UpstreamProviderError: If required configuration is missing, the
            auth type is unknown, or a delegated request carries no token.
    """
    global client

    auth_type = _require_config("AUTH_TYPE", "SHAREPOINT_AUTH_TYPE")
    search_limit = app.config.get("SEARCH_LIMIT", 5)

    if auth_type == SharepointClient.DELEGATED_AUTH:
        token = get_access_token()
        if token is None:
            raise UpstreamProviderError("No access token provided in request")
        user_client = SharepointClient(auth_type, search_limit)
        user_client.set_user_access_token(token)
        return user_client

    if auth_type != SharepointClient.APPLICATION_AUTH:
        raise UpstreamProviderError(f"Invalid auth type: {auth_type}")

    if client is None:
        app_client = SharepointClient(auth_type, search_limit)
        app_client.set_app_access_token(
            _require_config("TENANT_ID", "SHAREPOINT_TENANT_ID"),
            _require_config("CLIENT_ID", "SHAREPOINT_CLIENT_ID"),
            _require_config("CLIENT_SECRET", "SHAREPOINT_CLIENT_SECRET"),
        )
        # Publish only after the token was acquired, so a failed
        # initialisation is retried instead of caching a token-less client.
        client = app_client

    # NOTE(review): the application token is acquired once and never
    # refreshed here; presumably it expires eventually - confirm and add a
    # refresh path if so.
    return client


def get_access_token():
    """Extract the bearer token from the current request.

    Returns:
        The token string from the Authorization header, or ``None`` when the
        header is absent, uses another scheme, or carries an empty token.
    """
    authorization_header = request.headers.get(AUTHORIZATION_HEADER, "")
    scheme, _, token = authorization_header.partition(" ")

    # Authentication scheme names are case-insensitive, so "bearer" and
    # "Bearer" are treated alike.
    if scheme.lower() != BEARER_PREFIX.strip().lower():
        return None

    # An empty token must map to None: callers test `is None` to detect a
    # missing credential.
    token = token.strip()
    return token or None
1 change: 0 additions & 1 deletion sharepoint/provider/consts.py

This file was deleted.

6 changes: 0 additions & 6 deletions sharepoint/provider/enums.py

This file was deleted.

49 changes: 8 additions & 41 deletions sharepoint/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from .client import get_client
from .unstructured import get_unstructured_client
from .enums import MicrosoftDataType


load_dotenv()
Expand All @@ -23,7 +22,7 @@ def search(query):
for hit_container in search_response:
hits.extend(hit_container.get("hits", []))

drive_items, list_items = collect_items(sharepoint_client, hits)
drive_items = collect_items(sharepoint_client, hits)

# Build and request async Unstructured calls
files = [
Expand Down Expand Up @@ -51,48 +50,27 @@ def search(query):
if serialized_drive_item:
results.append(serialized_drive_item)

for page, list_item in list_items:
serialized_list_item = serialize_list_item(page, list_item)

if serialized_list_item is not None:
results.append(serialized_list_item)

return results


def collect_items(sharepoint_client, hits):
    """Download the content of every drive-item hit.

    Args:
        sharepoint_client: Object exposing
            ``get_drive_item_content(parent_drive_id, resource_id)``.
        hits: Iterable of search-hit dicts.

    Returns:
        A list of ``(hit, content)`` tuples, one per drive item whose
        content could be retrieved. Hits of another type, hits missing the
        drive or item id, and hits with empty content are skipped.
    """
    drive_item_type = "#microsoft.graph.driveItem"

    drive_items = []
    for hit in hits:
        resource = hit.get("resource") or {}
        if resource.get("@odata.type") != drive_item_type:
            continue

        # A hit without these ids cannot be downloaded; skip it rather than
        # failing the whole search with a KeyError.
        parent_drive_id = (resource.get("parentReference") or {}).get("driveId")
        resource_id = resource.get("id")
        if not parent_drive_id or not resource_id:
            continue

        drive_item = sharepoint_client.get_drive_item_content(
            parent_drive_id, resource_id
        )
        if drive_item:
            drive_items.append((hit, drive_item))

    return drive_items


def serialize_resource(resource):
def serialize_metadata(resource):
data = {}

# Only return primitive types, Coral cannot parse arrays/sub-dictionaries
Expand Down Expand Up @@ -136,20 +114,9 @@ def serialize_drive_item(hit, item, content):

data = {}
if (resource := hit.get("resource")) is not None:
data = serialize_resource(resource)
data = serialize_metadata(resource)

if text is not None:
data["text"] = text

return data


def serialize_list_item(page, item):
    """Combine a page's serialized fields with the HTML of its web parts.

    The ``innerHtml`` of every entry in ``item["value"]`` is concatenated in
    order and stored under ``"text"`` alongside the fields produced by
    ``serialize_resource(page)``.
    """
    combined_html = "".join(part["innerHtml"] for part in item["value"])
    serialized = serialize_resource(page)
    return {**serialized, "text": combined_html}
4 changes: 2 additions & 2 deletions sharepoint/provider/unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from collections import OrderedDict
from flask import current_app as app

from .consts import CACHE_SIZE

logger = logging.getLogger(__name__)

CACHE_SIZE = 256

unstructured = None


Expand Down

0 comments on commit 480e088

Please sign in to comment.