Skip to content

Commit

Permalink
Merge pull request #34754 from dimagi/jc/limit-batch-size-for-entra
Browse files Browse the repository at this point in the history
Limit batch size for entra api
  • Loading branch information
jingcheng16 authored Jun 12, 2024
2 parents adcaf40 + 4be01c5 commit 1941d57
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 51 deletions.
6 changes: 3 additions & 3 deletions corehq/apps/sso/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from corehq.apps.accounting.models import BillingAccount, Subscription
from corehq.apps.sso import certificates
from corehq.apps.sso.exceptions import ServiceProviderCertificateError
from corehq.apps.sso.utils.entra import get_all_members_of_the_idp_from_entra
from corehq.apps.sso.utils.entra import get_all_usernames_of_the_idp_from_entra
from corehq.apps.sso.utils.user_helpers import get_email_domain_from_username
from corehq.util.quickcache import quickcache

Expand Down Expand Up @@ -439,9 +439,9 @@ def get_required_identity_provider(cls, username):
return idp
return None

def get_all_members_of_the_idp(self):
def get_all_usernames_of_the_idp(self):
if self.idp_type == IdentityProviderType.ENTRA_ID:
return get_all_members_of_the_idp_from_entra(self)
return get_all_usernames_of_the_idp_from_entra(self)
else:
raise NotImplementedError("Not implemented")

Expand Down
6 changes: 3 additions & 3 deletions corehq/apps/sso/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def auto_deactivate_removed_sso_users():
idp_type=IdentityProviderType.ENTRA_ID
).all():
try:
idp_users = idp.get_all_members_of_the_idp()
usernames_in_idp = idp.get_all_usernames_of_the_idp()
except EntraVerificationFailed as e:
notify_exception(None, f"Failed to get members of the IdP. {str(e)}")
send_deactivation_skipped_email(idp=idp, failure_code=MSGraphIssue.VERIFICATION_ERROR,
Expand All @@ -152,7 +152,7 @@ def auto_deactivate_removed_sso_users():
continue

# if the Graph Users API returns an empty list of users we will skip auto deactivation
if len(idp_users) == 0:
if len(usernames_in_idp) == 0:
send_deactivation_skipped_email(idp=idp, failure_code=MSGraphIssue.EMPTY_ERROR)
continue

Expand All @@ -167,7 +167,7 @@ def auto_deactivate_removed_sso_users():
authenticated_email_domains = authenticated_domains.values_list('email_domain', flat=True)

for username in usernames_in_account:
if username not in idp_users and username not in exempt_usernames:
if username not in usernames_in_idp and username not in exempt_usernames:
email_domain = get_email_domain_from_username(username)
if email_domain in authenticated_email_domains:
usernames_to_deactivate.append(username)
Expand Down
12 changes: 6 additions & 6 deletions corehq/apps/sso/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def setUpClass(cls):
email_domain='vaultwax.com',
identity_provider=cls.idp,
)
idp_patcher = patch('corehq.apps.sso.models.IdentityProvider.get_all_members_of_the_idp')
cls.mock_get_all_members_of_the_idp = idp_patcher.start()
idp_patcher = patch('corehq.apps.sso.models.IdentityProvider.get_all_usernames_of_the_idp')
cls.mock_get_all_usernames_of_the_idp = idp_patcher.start()
cls.addClassCleanup(idp_patcher.stop)

def setUp(self):
Expand All @@ -345,7 +345,7 @@ def setUp(self):

def test_user_is_deactivated_if_not_member_of_idp(self):
self.assertTrue(self.web_user_c.is_active)
self.mock_get_all_members_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]
self.mock_get_all_usernames_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]

auto_deactivate_removed_sso_users()

Expand All @@ -359,7 +359,7 @@ def test_sso_exempt_users_are_not_deactivated(self):
username=sso_exempt.username,
email_domain=self.email_domain,
)
self.mock_get_all_members_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]
self.mock_get_all_usernames_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]

auto_deactivate_removed_sso_users()

Expand All @@ -369,7 +369,7 @@ def test_sso_exempt_users_are_not_deactivated(self):

@patch('corehq.apps.sso.tasks.send_html_email_async.delay')
def test_deactivation_skipped_if_entra_return_empty_sso_user(self, mock_send):
self.mock_get_all_members_of_the_idp.return_value = []
self.mock_get_all_usernames_of_the_idp.return_value = []

auto_deactivate_removed_sso_users()

Expand All @@ -384,7 +384,7 @@ def test_deactivation_skipped_if_entra_return_empty_sso_user(self, mock_send):

def test_deactivation_skip_members_of_the_domains_but_not_have_an_email_domain_controlled_by_the_idp(self):
dimagi_user = self._create_web_user('[email protected]')
self.mock_get_all_members_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]
self.mock_get_all_usernames_of_the_idp.return_value = [self.web_user_a.username, self.web_user_b.username]

auto_deactivate_removed_sso_users()

Expand Down
93 changes: 54 additions & 39 deletions corehq/apps/sso/utils/entra.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,33 @@ class MSOdataType:


ENDPOINT_BASE_URL = "https://graph.microsoft.com/v1.0"
MS_BATCH_LIMIT = 20


def get_all_members_of_the_idp_from_entra(idp):
def get_all_usernames_of_the_idp_from_entra(idp):
import msal
config = configure_idp(idp)
config = _configure_idp(idp)

# Create a preferably long-lived app instance which maintains a token cache.
app = msal.ConfidentialClientApplication(
config["client_id"], authority=config["authority"],
client_credential=config["secret"],
)

token = get_access_token(app, config)
token = _get_access_token(app, config)

# microsoft.graph.appRoleAssignment's property doesn't have userPrincipalName
user_principal_ids = get_all_user_ids_in_app(token, config["client_id"])
user_principal_ids = _get_all_user_ids_in_app(token, config["client_id"])

if len(user_principal_ids) == 0:
return []

user_principal_names = get_user_principal_names(user_principal_ids, token)
user_principal_names = _get_user_principal_names(user_principal_ids, token)

return user_principal_names


def configure_idp(idp):
def _configure_idp(idp):
authority_base_url = "https://login.microsoftonline.com/"
authority = f"{authority_base_url}{idp.api_host}"

Expand All @@ -55,39 +56,53 @@ def configure_idp(idp):
}


def get_user_principal_names(user_ids, token):
# Prepare batch request
batch_payload = {
"requests": [
{
"id": str(i),
"method": "GET",
"url": f"/users/{principal_id}?$select=userPrincipalName"
} for i, principal_id in enumerate(user_ids)
]
}
# Send batch request
batch_response = requests.post(
f'{ENDPOINT_BASE_URL}/$batch',
headers={'Authorization': 'Bearer ' + token, 'Content-Type': 'application/json'},
data=json.dumps(batch_payload)
)
batch_response.raise_for_status()
batch_result = batch_response.json()

for resp in batch_result['responses']:
if 'body' in resp and 'error' in resp['body']:
raise EntraVerificationFailed(resp['body']['error']['code'], resp['body']['message'])

# Extract userPrincipalName from batch response
user_principal_names = [
resp['body']['userPrincipalName'] for resp in batch_result['responses']
if 'body' in resp and 'userPrincipalName' in resp['body']
]
def _get_user_principal_names(user_ids, token):
# Convert set to list to make it subscriptable
user_ids = list(user_ids)
#JSON batch requests are currently limited to 20 individual requests.
user_id_chunks = [user_ids[i:i + MS_BATCH_LIMIT] for i in range(0, len(user_ids), MS_BATCH_LIMIT)]

user_principal_names = []

for chunk in user_id_chunks:
batch_payload = {
"requests": [
{
"id": str(i),
"method": "GET",
"url": f"/users/{principal_id}?$select=userPrincipalName"
} for i, principal_id in enumerate(chunk)
]
}

# Send batch request
batch_response = requests.post(
f'{ENDPOINT_BASE_URL}/$batch',
headers={'Authorization': 'Bearer ' + token, 'Content-Type': 'application/json'},
data=json.dumps(batch_payload)
)

try:
batch_response.raise_for_status()
except requests.exceptions.HTTPError as e:
# Append the response body to the HTTPError message
error_message = f"{e.response.status_code} {e.response.reason} - {batch_response.text}"
raise requests.exceptions.HTTPError(error_message, response=e.response)

batch_result = batch_response.json()
for resp in batch_result['responses']:
if 'body' in resp and 'error' in resp['body']:
raise EntraVerificationFailed(resp['body']['error']['code'], resp['body']['message'])

# Extract userPrincipalName from batch response
for resp in batch_result['responses']:
if 'body' in resp and 'userPrincipalName' in resp['body']:
user_principal_names.append(resp['body']['userPrincipalName'])

return user_principal_names


def get_access_token(app, config):
def _get_access_token(app, config):
# looks up a token from cache
result = app.acquire_token_silent(config["scope"], account=None)
if not result:
Expand All @@ -98,7 +113,7 @@ def get_access_token(app, config):
return result.get("access_token")


def get_all_user_ids_in_app(token, app_id):
def _get_all_user_ids_in_app(token, app_id):
endpoint = (f"{ENDPOINT_BASE_URL}/servicePrincipals(appId='{app_id}')/"
f"appRoleAssignedTo?$select=principalId, principalType")
# Calling graph using the access token
Expand All @@ -123,7 +138,7 @@ def get_all_user_ids_in_app(token, app_id):
"Please include only Users or Groups as members of this SSO application")

for group_id in group_queue:
members_data = get_group_members(group_id, token)
members_data = _get_group_members(group_id, token)
for member in members_data.get("value", []):
# Only direct user in the group will have access to the application
# Nested group won't have access to the application
Expand All @@ -133,7 +148,7 @@ def get_all_user_ids_in_app(token, app_id):
return user_ids


def get_group_members(group_id, token):
def _get_group_members(group_id, token):
endpoint = f"{ENDPOINT_BASE_URL}/groups/{group_id}/members?$select=id"
headers = {'Authorization': 'Bearer ' + token}
response = requests.get(endpoint, headers=headers)
Expand Down

0 comments on commit 1941d57

Please sign in to comment.