Skip to content

Commit

Permalink
fix: bug with service user id vs lms user id
Browse files Browse the repository at this point in the history
- Fixed bug for successful JWTs where the JWT user id was still
using the service user id, rather than the LMS user id, so comparison
against the LMS user id would fail.
- As part of the bug fix, the custom attribute
``failed_jwt_cookie_user_id`` was renamed to
``jwt_cookie_lms_user_id``, and will be set for all JWT cookies.
Since this is only a breaking change for recently added monitoring,
this won't be versioned as a breaking change.

This is part of:
edx/edx-arch-experiments#429
  • Loading branch information
robrap committed Dec 5, 2023
1 parent d9aebba commit fab0a68
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 56 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ Change Log
Unreleased
----------

[9.0.1] - 2023-12-06
--------------------

Fixed
~~~~~

* Fixed bug for successful JWTs where the JWT user id was still using the service user id, rather than the LMS user id, so comparison against the LMS user id would fail.

Updated
~~~~~~~

* As part of the bug fix, the custom attribute ``failed_jwt_cookie_user_id`` was renamed to ``jwt_cookie_lms_user_id``, and will be set for all JWT cookies. Since this is only a breaking change for recently added monitoring, this won't be versioned as a breaking change.

[9.0.0] - 2023-11-27
--------------------

Expand Down
2 changes: 1 addition & 1 deletion edx_rest_framework_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
""" edx Django REST Framework extensions. """

__version__ = '9.0.0' # pragma: no cover
__version__ = '9.0.1' # pragma: no cover
69 changes: 29 additions & 40 deletions edx_rest_framework_extensions/auth/jwt/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,13 @@ def authenticate(self, request):
# CSRF passed validation with authenticated user

# adds additional monitoring for mismatches; and raises errors in certain cases
self._monitor_or_enforce_successful_jwt_cookie_and_session_user_mismatch(
request, jwt_user_id=user_and_auth[0].id
)
is_mismatch = self._is_jwt_cookie_and_session_user_mismatch(request)
if is_mismatch and get_setting(ENABLE_SET_REQUEST_USER_FOR_JWT_COOKIE):
raise JwtSessionUserMismatchError(
'Failing otherwise successful JWT authentication due to session user mismatch '
'with set request user.'
)

set_custom_attribute('jwt_auth_result', 'success-cookie')
return user_and_auth

Expand All @@ -149,8 +153,8 @@ def authenticate(self, request):
set_custom_attribute('jwt_auth_failed', 'Exception:{}'.format(repr(exception_to_report)))

if is_authenticating_with_jwt_cookie:
# This check also adds monitoring details for all failed JWT cookies
is_user_mismatch = self._is_failed_jwt_cookie_and_session_user_mismatch(request)
# This check also adds monitoring details
is_user_mismatch = self._is_jwt_cookie_and_session_user_mismatch(request)
if is_forgiving_jwt_cookies_enabled:
if is_user_mismatch:
set_custom_attribute('jwt_auth_result', 'user-mismatch-failure')
Expand Down Expand Up @@ -257,64 +261,49 @@ def is_authenticating_with_jwt_cookie(cls, request):
except Exception: # pylint: disable=broad-exception-caught
return False

def _is_failed_jwt_cookie_and_session_user_mismatch(self, request):
def _get_and_monitor_jwt_cookie_lms_user_id(self, request):
"""
Returns True if failed JWT cookie and session user do not match, False otherwise.
Returns the LMS user id from the JWT cookie, or None if not found
Notes:
- Must only be called in the case of a JWT cookie failure.
- Must only be called in the case of a JWT cookie.
- Also provides monitoring details for mismatches.
"""
try:
cookie_token = JSONWebTokenAuthentication.get_token_from_cookies(request.COOKIES)
invalid_decoded_jwt = unsafe_jwt_decode_handler(cookie_token)
jwt_user_id = invalid_decoded_jwt.get('user_id', None)
jwt_user_id_attribute_value = jwt_user_id if jwt_user_id else 'not-found' # pragma: no cover
jwt_lms_user_id = invalid_decoded_jwt.get('user_id', None)
jwt_lms_user_id_attribute_value = jwt_lms_user_id if jwt_lms_user_id else 'not-found' # pragma: no cover
except Exception: # pylint: disable=broad-exception-caught
jwt_user_id = None
jwt_user_id_attribute_value = 'decode-error'
jwt_lms_user_id = None
jwt_lms_user_id_attribute_value = 'decode-error'

# .. custom_attribute_name: failed_jwt_cookie_user_id
# .. custom_attribute_description: The user_id pulled from the failed
# .. custom_attribute_name: jwt_cookie_lms_user_id
# .. custom_attribute_description: The LMS user_id pulled from the
# JWT cookie. If the user_id claim is not found in the JWT, the attribute
# value will be 'not-found'. If the failed JWT simply can't be decoded,
# the attribute value will be 'decode-error'. Note: for successful JWTs,
# the user id will already be available in `enduser.id` or `request_user_id`.
set_custom_attribute('failed_jwt_cookie_user_id', jwt_user_id_attribute_value)
# value will be 'not-found'. If the JWT simply can't be decoded,
# the attribute value will be 'decode-error'. Note that the id will be
# set in the case of expired JWTs, or other failures that can still be
# decoded.
set_custom_attribute('jwt_cookie_lms_user_id', jwt_lms_user_id_attribute_value)

return self._is_jwt_cookie_and_session_user_mismatch(request, jwt_user_id)
return jwt_lms_user_id

def _monitor_or_enforce_successful_jwt_cookie_and_session_user_mismatch(self, request, jwt_user_id):
"""
Provides monitoring and possible enforcement when a successful JWT cookie and session user do not match.
Notes:
- Must only be called in the case of a successful JWT cookie.
- Also provides monitoring details for mismatches.
- In the case where ENABLE_SET_REQUEST_USER_FOR_JWT_COOKIE is being used, we trigger a failure for
what otherwise would have been a successful authentication.
"""
is_mismatch = self._is_jwt_cookie_and_session_user_mismatch(request, jwt_user_id)
if is_mismatch and get_setting(ENABLE_SET_REQUEST_USER_FOR_JWT_COOKIE):
# For failed_jwt_cookie_user_id docs, see custom attribute annotations elsewhere
set_custom_attribute('failed_jwt_cookie_user_id', jwt_user_id)
raise JwtSessionUserMismatchError(
'Failing otherwise successful JWT authentication due to session user mismatch with set request user.'
)

def _is_jwt_cookie_and_session_user_mismatch(self, request, jwt_user_id):
def _is_jwt_cookie_and_session_user_mismatch(self, request):
"""
Returns True if JWT cookie and session user do not match, False otherwise.
Arguments:
request: The request.
jwt_user_id (int): The user_id of the JWT, None if not found.
Other notes:
- If ENABLE_FORGIVING_JWT_COOKIES is toggled off, always return False.
- Also adds monitoring details for mismatches.
- Should only be called for JWT cookies.
"""
# adds early monitoring for the JWT LMS user_id
jwt_lms_user_id = self._get_and_monitor_jwt_cookie_lms_user_id(request)

is_forgiving_jwt_cookies_enabled = get_setting(ENABLE_FORGIVING_JWT_COOKIES)
# This toggle provides a temporary safety valve for rollout.
if not is_forgiving_jwt_cookies_enabled:
Expand Down Expand Up @@ -356,7 +345,7 @@ def _is_jwt_cookie_and_session_user_mismatch(self, request, jwt_user_id):
else:
session_lms_user_id = None

if not session_lms_user_id or session_lms_user_id == jwt_user_id:
if not session_lms_user_id or session_lms_user_id == jwt_lms_user_id:
return False

# .. custom_attribute_name: jwt_auth_mismatch_session_lms_user_id
Expand Down
39 changes: 24 additions & 15 deletions edx_rest_framework_extensions/auth/jwt/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def test_authenticate_jwt_and_session_mismatch_bad_signature_cookie(self, mock_s
mock_set_custom_attribute.assert_any_call(
'is_forgiving_jwt_cookies_enabled', is_forgiving_jwt_cookies_enabled
)
mock_set_custom_attribute.assert_any_call('failed_jwt_cookie_user_id', jwt_user_id)
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_user_id)
if is_forgiving_jwt_cookies_enabled:
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'user-mismatch-failure')
mock_set_custom_attribute.assert_any_call('jwt_auth_mismatch_session_lms_user_id', session_lms_user_id)
Expand Down Expand Up @@ -555,7 +555,7 @@ def test_authenticate_jwt_and_session_mismatch_invalid_cookie(self, mock_set_cus
mock_set_custom_attribute.assert_any_call(
'is_forgiving_jwt_cookies_enabled', is_forgiving_jwt_cookies_enabled
)
mock_set_custom_attribute.assert_any_call('failed_jwt_cookie_user_id', 'decode-error')
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', 'decode-error')
if is_forgiving_jwt_cookies_enabled:
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'user-mismatch-failure')
mock_set_custom_attribute.assert_any_call('jwt_auth_mismatch_session_lms_user_id', session_lms_user_id)
Expand Down Expand Up @@ -652,7 +652,8 @@ def test_authenticate_no_lms_user_id_property_and_set_request_user(self, mock_se
- This test is kept with the rest of the JWT vs session user tests.
"""
session_user = factories.UserFactory(id=111)
jwt_user = factories.UserFactory(id=222)
jwt_user_id = 222
jwt_user = factories.UserFactory(id=jwt_user_id)
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(user=jwt_user)
# Cookie parts will be recombined by JwtAuthCookieMiddleware
self.client.cookies = SimpleCookie({
Expand All @@ -669,9 +670,9 @@ def test_authenticate_no_lms_user_id_property_and_set_request_user(self, mock_se
mock_set_custom_attribute.assert_any_call('is_forgiving_jwt_cookies_enabled', True)
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'success-cookie')
mock_set_custom_attribute.assert_any_call('jwt_auth_get_lms_user_id_status', 'not-configured')
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_user_id)
set_custom_attribute_keys = [call.args[0] for call in mock_set_custom_attribute.call_args_list]
assert 'jwt_auth_mismatch_session_lms_user_id' not in set_custom_attribute_keys
assert 'failed_jwt_cookie_user_id' not in set_custom_attribute_keys
assert 'jwt_auth_failed' not in set_custom_attribute_keys
mock_logger.error.assert_not_called()

Expand Down Expand Up @@ -700,8 +701,10 @@ def test_authenticate_unknown_user_id_property_and_set_request_user(self, mock_s
- This test is kept with the rest of the JWT vs session user tests.
"""
session_user = factories.UserFactory(id=111)
jwt_user = factories.UserFactory(id=222)
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(user=jwt_user)
jwt_lms_user_id = 222
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(
user=session_user, lms_user_id=jwt_lms_user_id
)
# Cookie parts will be recombined by JwtAuthCookieMiddleware
self.client.cookies = SimpleCookie({
jwt_cookie_header_payload_name(): jwt_header_payload,
Expand All @@ -717,9 +720,9 @@ def test_authenticate_unknown_user_id_property_and_set_request_user(self, mock_s
mock_set_custom_attribute.assert_any_call('is_forgiving_jwt_cookies_enabled', True)
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'success-cookie')
mock_set_custom_attribute.assert_any_call('jwt_auth_get_lms_user_id_status', 'misconfigured')
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_lms_user_id)
set_custom_attribute_keys = [call.args[0] for call in mock_set_custom_attribute.call_args_list]
assert 'jwt_auth_mismatch_session_lms_user_id' not in set_custom_attribute_keys
assert 'failed_jwt_cookie_user_id' not in set_custom_attribute_keys
assert 'jwt_auth_failed' not in set_custom_attribute_keys

# assert for error log for misconfigured VERIFY_LMS_USER_ID_PROPERTY_NAME
Expand Down Expand Up @@ -769,7 +772,7 @@ def test_authenticate_user_id_property_and_set_request_user(self, mock_set_custo
mock_set_custom_attribute.assert_any_call('is_forgiving_jwt_cookies_enabled', True)
mock_set_custom_attribute.assert_any_call('jwt_auth_mismatch_session_lms_user_id', session_lms_user_id)
mock_set_custom_attribute.assert_any_call('jwt_auth_get_lms_user_id_status', 'id-found')
mock_set_custom_attribute.assert_any_call('failed_jwt_cookie_user_id', jwt_user_id)
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_user_id)
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'user-mismatch-enforced-failure')
mock_set_custom_attribute.assert_any_call('jwt_auth_failed', mock.ANY)

Expand Down Expand Up @@ -803,8 +806,7 @@ def test_authenticate_other_user_property_and_set_request_user(self, mock_set_cu
# In this test, the service's user id matches the JWT LMS user id, which ordinarily would never happen.
# However, for the purpose of this test, we want to ensure that this doesn't prevent the mismatch.
jwt_user_id = session_user_id
jwt_user = session_user
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(user=jwt_user)
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(user=session_user)
# Cookie parts will be recombined by JwtAuthCookieMiddleware
self.client.cookies = SimpleCookie({
jwt_cookie_header_payload_name(): jwt_header_payload,
Expand All @@ -821,7 +823,7 @@ def test_authenticate_other_user_property_and_set_request_user(self, mock_set_cu
mock_set_custom_attribute.assert_any_call('is_forgiving_jwt_cookies_enabled', True)
mock_set_custom_attribute.assert_any_call('jwt_auth_mismatch_session_lms_user_id', session_user_lms_id)
mock_set_custom_attribute.assert_any_call('jwt_auth_get_lms_user_id_status', 'id-found')
mock_set_custom_attribute.assert_any_call('failed_jwt_cookie_user_id', jwt_user_id)
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_user_id)
mock_set_custom_attribute.assert_any_call('jwt_auth_result', 'user-mismatch-enforced-failure')
mock_set_custom_attribute.assert_any_call('jwt_auth_failed', mock.ANY)

Expand All @@ -847,7 +849,10 @@ def test_authenticate_jwt_and_no_session_and_set_request_user(self, mock_set_cus
- This test is kept with the rest of the JWT vs session user tests.
"""
test_user = factories.UserFactory()
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(user=test_user)
jwt_lms_user_id = 222
jwt_header_payload, jwt_signature = self._get_test_jwt_token_payload_and_signature(
user=test_user, lms_user_id=jwt_lms_user_id
)
# Cookie parts will be recombined by JwtAuthCookieMiddleware
self.client.cookies = SimpleCookie({
jwt_cookie_header_payload_name(): jwt_header_payload,
Expand All @@ -860,24 +865,28 @@ def test_authenticate_jwt_and_no_session_and_set_request_user(self, mock_set_cus
# The case where forgiving JWTs is disabled is tested under other tests, including the middleware tests.
mock_set_custom_attribute.assert_any_call('is_forgiving_jwt_cookies_enabled', True)
mock_set_custom_attribute.assert_any_call('skip_jwt_vs_session_check', True)
mock_set_custom_attribute.assert_any_call('jwt_cookie_lms_user_id', jwt_lms_user_id)
set_custom_attribute_keys = [call.args[0] for call in mock_set_custom_attribute.call_args_list]
assert 'jwt_auth_mismatch_session_lms_user_id' not in set_custom_attribute_keys
assert 'jwt_auth_get_lms_user_id_status' not in set_custom_attribute_keys
assert response.status_code == 200

def _get_test_jwt_token(self, user=None, is_valid_signature=True):
def _get_test_jwt_token(self, user=None, is_valid_signature=True, lms_user_id=None):
""" Returns a test jwt token for the provided user """
test_user = factories.UserFactory() if user is None else user
payload = generate_latest_version_payload(test_user)
if lms_user_id:
# In other services, the LMS user id in the JWT would not be the user's id.
payload['user_id'] = lms_user_id
if is_valid_signature:
jwt_token = generate_jwt_token(payload)
else:
jwt_token = generate_jwt_token(payload, signing_key='invalid-key')
return jwt_token

def _get_test_jwt_token_payload_and_signature(self, user=None):
def _get_test_jwt_token_payload_and_signature(self, user=None, lms_user_id=None):
""" Returns a test jwt token split into payload and signature """
jwt_token = self._get_test_jwt_token(user=user)
jwt_token = self._get_test_jwt_token(user=user, lms_user_id=lms_user_id)
jwt_token_parts = jwt_token.split('.')
header_and_payload = '.'.join(jwt_token_parts[0:2])
signature = jwt_token_parts[2]
Expand Down

0 comments on commit fab0a68

Please sign in to comment.