diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..e3a1cc485 --- /dev/null +++ b/conftest.py @@ -0,0 +1,11 @@ +""" +Pytest fixtures for Enterprise Access tests. +""" +import pytest + +@pytest.fixture(scope='session') +def celery_config(): + return { + 'broker_url': 'amqp://', + 'result_backend': 'redis://' + } diff --git a/enterprise_access/apps/api_client/lms_client.py b/enterprise_access/apps/api_client/lms_client.py index 7df58076f..b0ba2b408 100644 --- a/enterprise_access/apps/api_client/lms_client.py +++ b/enterprise_access/apps/api_client/lms_client.py @@ -5,6 +5,7 @@ import requests from django.conf import settings +from rest_framework import status from enterprise_access.apps.api_client.base_oauth import BaseOAuthClient @@ -146,18 +147,20 @@ def enterprise_contains_learner(self, enterprise_customer_uuid, learner_id): def create_pending_enterprise_users(self, enterprise_customer_uuid, user_emails): """ - Creates a pending enterprise user in the given ``enterprise_customer_uuid`` for each of the - specified ``user_emails`` provided. + Creates pending enterprise users in the given enterprise customer for the provided emails. Args: - enterprise_customer_uuid (UUID): UUID of the enterprise customer in which pending user records are created. + enterprise_customer_uuid (UUID): UUID of the enterprise customer in which pending user record is created. user_emails (list(str)): The emails for which pending enterprise users will be created. Returns: - A ``requests.Response`` object representing the pending-enterprise-learner endpoint response. + A ``requests.Response`` object representing the pending-enterprise-learner endpoint response. HTTP status + codes include: + * 201 CREATED: Any pending enterprise users were created. + * 204 NO CONTENT: No pending enterprise users were created (they ALL existed already). Raises: - ``requests.exceptions.HTTPError`` on any response with an unsuccessful status code. + ``requests.exceptions.HTTPError`` on any endpoint response with an unsuccessful status code. """ data = [ { @@ -169,16 +172,20 @@ def create_pending_enterprise_users(self, enterprise_customer_uuid, user_emails) response = self.client.post(self.pending_enterprise_learner_endpoint, json=data) try: response.raise_for_status() - logger.info( - 'Successfully created %r PendingEnterpriseCustomerUser records for customer %r', - len(data), - enterprise_customer_uuid, - ) + if response.status_code == status.HTTP_201_CREATED: + logger.info( + 'Successfully created PendingEnterpriseCustomerUser records for customer %r', + enterprise_customer_uuid, + ) + else: + logger.info( + 'Found existing PendingEnterpriseCustomerUser records for customer %r', + enterprise_customer_uuid, + ) return response except requests.exceptions.HTTPError as exc: logger.error( - 'Failed to create %r PendingEnterpriseCustomerUser records for customer %r because %r', - len(data), + 'Failed to create PendingEnterpriseCustomerUser records for customer %r because %r', enterprise_customer_uuid, response.text, ) diff --git a/enterprise_access/apps/api_client/tests/test_lms_client.py b/enterprise_access/apps/api_client/tests/test_lms_client.py index 907f2ae88..6fbf5aee1 100644 --- a/enterprise_access/apps/api_client/tests/test_lms_client.py +++ b/enterprise_access/apps/api_client/tests/test_lms_client.py @@ -9,10 +9,18 @@ import requests from django.conf import settings from django.test import TestCase +from rest_framework import status from enterprise_access.apps.api_client.lms_client import LmsApiClient from enterprise_access.apps.api_client.tests.test_utils import MockResponse +TEST_ENTERPRISE_UUID = uuid4() +TEST_USER_EMAILS = [ + 'larry@stooges.com', + 'moe@stooges.com', + 'curly@stooges.com', +] + @ddt.ddt class TestLmsApiClient(TestCase): @@ -194,8 +202,22 @@ def test_enterprise_contains_learner(self, mock_oauth_client, mock_json): ) @ddt.data( - {'mock_response_status': 201, 'mock_response_json': {'detail': 'Good Request'}}, - {'mock_response_status': 400, 'mock_response_json': {'detail': 'Bad Request'}}, + { + 'mock_response_status': status.HTTP_204_NO_CONTENT, + 'mock_response_json': [], + }, + { + 'mock_response_status': status.HTTP_201_CREATED, + 'mock_response_json': [ + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[0]}, + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[1]}, + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[2]}, + ], + }, + { + 'mock_response_status': status.HTTP_400_BAD_REQUEST, + 'mock_response_json': {'detail': 'Bad Request'}, + }, ) @ddt.unpack @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') @@ -209,27 +231,21 @@ def test_create_pending_enterprise_users(self, mock_oauth_client, mock_response_ mock_response_status, ) - user_emails = [ - 'larry@stooges.com', - 'moe@stooges.com', - 'curly@stooges.com', - ] - mock_enterprise_uuid = str(uuid4()) - client = LmsApiClient() if mock_response_status >= 400: with self.assertRaises(requests.exceptions.HTTPError): - response = client.create_pending_enterprise_users(mock_enterprise_uuid, user_emails) + response = client.create_pending_enterprise_users(str(TEST_ENTERPRISE_UUID), TEST_USER_EMAILS) else: - response = client.create_pending_enterprise_users(mock_enterprise_uuid, user_emails) - assert response.status_code == 201 - assert response.json() == {'detail': 'Good Request'} + response = client.create_pending_enterprise_users(str(TEST_ENTERPRISE_UUID), TEST_USER_EMAILS) + assert response.status_code == mock_response_status + assert response.json() == mock_response_json mock_oauth_client.return_value.post.assert_called_once_with( client.pending_enterprise_learner_endpoint, json=[ - {'enterprise_customer': str(mock_enterprise_uuid), 'user_email': user_email} - for user_email in user_emails + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[0]}, + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[1]}, + {'enterprise_customer': str(TEST_ENTERPRISE_UUID), 'user_email': TEST_USER_EMAILS[2]}, ], ) diff --git a/enterprise_access/apps/content_assignments/tasks.py b/enterprise_access/apps/content_assignments/tasks.py new file mode 100644 index 000000000..c8e6dbe37 --- /dev/null +++ b/enterprise_access/apps/content_assignments/tasks.py @@ -0,0 +1,77 @@ +""" +Tasks for content_assignments app. +""" + +import logging + +from celery import shared_task +from django.apps import apps +from django.conf import settings + +from enterprise_access.apps.api_client.lms_client import LmsApiClient +from enterprise_access.tasks import LoggedTaskWithRetry + +from .constants import LearnerContentAssignmentStateChoices + +logger = logging.getLogger(__name__) + + +class CreatePendingEnterpriseLearnerForAssignmentTaskBase(LoggedTaskWithRetry): # pylint: disable=abstract-method + """ + Base class for the create_pending_enterprise_learner_for_assignment task. + + Provides a place to define retry failure handling logic. + """ + + def on_failure(self, exc, task_id, args, kwargs, einfo): + """ + If the task fails for any reason (whether or not retries were involved), set the assignment state to errored. + + Function signature documented at: https://docs.celeryq.dev/en/stable/userguide/tasks.html#on_failure + """ + logger.error(f'"{task_id}" failed: "{exc}"') + learner_content_assignment_uuid = args[0] + learner_content_assignment_model = apps.get_model('content_assignments.LearnerContentAssignment') + + try: + assignment = learner_content_assignment_model.objects.get(uuid=learner_content_assignment_uuid) + assignment.state = LearnerContentAssignmentStateChoices.ERRORED + assignment.save() + if self.request.retries == settings.TASK_MAX_RETRIES: + # The failure resulted from too many retries. This fact would be a useful thing to record in a "reason" + # field on the assignment if one existed. + logger.error( + 'The task failure resulted from exceeding the locally defined max number of retries ' + '(settings.TASK_MAX_RETRIES).' + ) + except assignment.DoesNotExist: + logger.error(f'LearnerContentAssignment not found with UUID: {learner_content_assignment_uuid}') + + +@shared_task(base=CreatePendingEnterpriseLearnerForAssignmentTaskBase) +def create_pending_enterprise_learner_for_assignment_task(learner_content_assignment_uuid): + """ + Create a pending enterprise learner for the email+content associated with the given LearnerContentAssignment. + + Args: + learner_content_assignment_uuid (str): + UUID of the LearnerContentAssignment object from which to obtain the learner email and enterprise customer. + + Raises: + HTTPError if LMS API call fails with an HTTPError. + """ + learner_content_assignment_model = apps.get_model('content_assignments.LearnerContentAssignment') + assignment = learner_content_assignment_model.objects.get(uuid=learner_content_assignment_uuid) + enterprise_customer_uuid = assignment.assignment_configuration.enterprise_customer_uuid + + # Intentionally not logging the learner email (PII). + logger.info(f'Creating a pending enterprise user for enterprise {enterprise_customer_uuid}.') + + lms_client = LmsApiClient() + # Could raise HTTPError and trigger task retry. Intentionally ignoring response since success should just not throw + # an exception. Two possible success statuses are 201 (created) and 200 (found), but there's no reason to + # distinguish them for the purpose of this task. + lms_client.create_pending_enterprise_users(enterprise_customer_uuid, [assignment.learner_email]) + + # TODO: ENT-7596: Save activity history on this assignment to represent that the learner is successfully linked to + # the enterprise. diff --git a/enterprise_access/apps/content_assignments/tests/test_tasks.py b/enterprise_access/apps/content_assignments/tests/test_tasks.py new file mode 100644 index 000000000..d8ad5bb64 --- /dev/null +++ b/enterprise_access/apps/content_assignments/tests/test_tasks.py @@ -0,0 +1,167 @@ +""" +Tests for Enterprise Access content_assignments tasks. +""" + +from unittest import mock +from uuid import uuid4 + +import ddt +from celery import states as celery_states +from django.conf import settings +from requests.exceptions import HTTPError +from rest_framework import status + +from enterprise_access.apps.api_client.tests.test_utils import MockResponse +from enterprise_access.apps.content_assignments.constants import LearnerContentAssignmentStateChoices +from enterprise_access.apps.content_assignments.tasks import create_pending_enterprise_learner_for_assignment_task +from enterprise_access.apps.content_assignments.tests.factories import ( + AssignmentConfigurationFactory, + LearnerContentAssignmentFactory +) +from enterprise_access.apps.subsidy_access_policy.tests.factories import AssignedLearnerCreditAccessPolicyFactory +from test_utils import APITestWithMocks + +TEST_ENTERPRISE_UUID = uuid4() +TEST_EMAIL = 'foo@bar.com' + + +@ddt.ddt +class TestCreatePendingEnterpriseLearnerForAssignmentTask(APITestWithMocks): + """ + Test create_pending_enterprise_learner_for_assignment_task(). + """ + + @classmethod + def setUpTestData(cls): + super().setUpTestData() + + # Create a pair of AssignmentConfiguration + SubsidyAccessPolicy for the main test customer. + cls.assignment_configuration = AssignmentConfigurationFactory( + enterprise_customer_uuid=TEST_ENTERPRISE_UUID, + ) + cls.assigned_learner_credit_policy = AssignedLearnerCreditAccessPolicyFactory( + display_name='An assigned learner credit policy, for the test customer.', + enterprise_customer_uuid=TEST_ENTERPRISE_UUID, + active=True, + assignment_configuration=cls.assignment_configuration, + spend_limit=1000000, + ) + + def setUp(self): + super().setUp() + + self.assignment = LearnerContentAssignmentFactory( + learner_email=TEST_EMAIL, + assignment_configuration=self.assignment_configuration, + ) + + @ddt.data( + # The LMS API did not find an existing PendingEnterpriseLearner, so it created one. + { + 'mock_lms_response_status': status.HTTP_201_CREATED, + 'mock_lms_response_body': { + 'enterprise_customer': str(TEST_ENTERPRISE_UUID), + 'user_email': TEST_EMAIL, + }, + }, + # The LMS API found an existing PendingEnterpriseLearner. + { + 'mock_lms_response_status': status.HTTP_204_NO_CONTENT, + 'mock_lms_response_body': None, + }, + ) + @ddt.unpack + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_happy_path(self, mock_oauth_client, mock_lms_response_status, mock_lms_response_body): + """ + 2xx response form the LMS API should cause the task to run successfully. + """ + mock_oauth_client.return_value.post.return_value = MockResponse( + mock_lms_response_body, + mock_lms_response_status, + ) + + task_result = create_pending_enterprise_learner_for_assignment_task.delay(self.assignment.uuid) + + # Celery thinks the task succeeded. + assert task_result.state == celery_states.SUCCESS + + # The LMS/enterprise API was called once only, and with the correct request body. + assert len(mock_oauth_client.return_value.post.call_args_list) == 1 + assert mock_oauth_client.return_value.post.call_args.kwargs['json'] == [{ + 'enterprise_customer': str(self.assignment.assignment_configuration.enterprise_customer_uuid), + 'user_email': self.assignment.learner_email, + }] + + @ddt.data( + # 503 is a prototypical "please retry this endpoint" status. + status.HTTP_503_SERVICE_UNAVAILABLE, + # 400 should really not trigger retry, but it does. We should improve LoggedTaskWithRetry to make it not retry! + status.HTTP_400_BAD_REQUEST, + ) + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_max_retries(self, response_status_that_triggers_retry, mock_oauth_client): + """ + On repeated error responses from the LMS/enterprise API, the celery worker should retry the task until the + maximum number of retries configured, then set the assignment state to ERRORED. + """ + mock_oauth_client.return_value.post.return_value = MockResponse( + { + 'enterprise_customer': str(TEST_ENTERPRISE_UUID), + 'user_email': TEST_EMAIL, + }, + response_status_that_triggers_retry, + ) + + task_result = create_pending_enterprise_learner_for_assignment_task.delay(self.assignment.uuid) + + # Celery thinks the task failed. + assert task_result.state == celery_states.FAILURE + + # The overall task result is just the HTTPError bubbled up from the API response. + assert isinstance(task_result.result, HTTPError) + assert task_result.result.response.status_code == response_status_that_triggers_retry + + # The LMS/enterprise API was called once plus the max number of retries, all with the correct request body. + assert len(mock_oauth_client.return_value.post.call_args_list) == 1 + settings.TASK_MAX_RETRIES + for call in mock_oauth_client.return_value.post.call_args_list: + assert call.kwargs['json'] == [{ + 'enterprise_customer': str(self.assignment.assignment_configuration.enterprise_customer_uuid), + 'user_email': self.assignment.learner_email, + }] + + self.assignment.refresh_from_db() + assert self.assignment.state == LearnerContentAssignmentStateChoices.ERRORED + + @mock.patch('enterprise_access.apps.api_client.base_oauth.OAuthAPIClient') + def test_last_retry_success(self, mock_oauth_client): + """ + Test a scenario where the API response keeps triggering a retry until the last attempt, then finally responds + successfully. + """ + # Mock multiple consecutive responses, only the last of which was successful. + retry_triggering_responses = [ + MockResponse(None, status.HTTP_503_SERVICE_UNAVAILABLE) + for _ in range(settings.TASK_MAX_RETRIES) + ] + final_success_response = MockResponse( + { + 'enterprise_customer': str(TEST_ENTERPRISE_UUID), + 'user_email': TEST_EMAIL, + }, + status.HTTP_201_CREATED, + ) + mock_oauth_client.return_value.post.side_effect = retry_triggering_responses + [final_success_response] + + task_result = create_pending_enterprise_learner_for_assignment_task.delay(self.assignment.uuid) + + # Celery thinks the task succeeded. + assert task_result.state == celery_states.SUCCESS + + # The LMS/enterprise API was called once plus the max number of retries, all with the correct request body. + assert len(mock_oauth_client.return_value.post.call_args_list) == 1 + settings.TASK_MAX_RETRIES + for call in mock_oauth_client.return_value.post.call_args_list: + assert call.kwargs['json'] == [{ + 'enterprise_customer': str(self.assignment.assignment_configuration.enterprise_customer_uuid), + 'user_email': self.assignment.learner_email, + }]