From 5c67922722da49a32eea2704a6283202f878a5c0 Mon Sep 17 00:00:00 2001 From: muhammad-ammar Date: Thu, 31 Oct 2024 12:46:56 +0500 Subject: [PATCH] feat: unlink expired licenses --- license_manager/apps/api_client/enterprise.py | 8 + .../apps/subscriptions/constants.py | 2 + .../tests/test_unlink_expired_licenses.py | 140 +++++++++++++++ .../commands/unlink_expired_licenses.py | 160 ++++++++++++++++++ license_manager/settings/base.py | 2 + 5 files changed, 312 insertions(+) create mode 100644 license_manager/apps/subscriptions/management/commands/tests/test_unlink_expired_licenses.py create mode 100644 license_manager/apps/subscriptions/management/commands/unlink_expired_licenses.py diff --git a/license_manager/apps/api_client/enterprise.py b/license_manager/apps/api_client/enterprise.py index 300cb614..d3e6f5df 100644 --- a/license_manager/apps/api_client/enterprise.py +++ b/license_manager/apps/api_client/enterprise.py @@ -20,6 +20,7 @@ class EnterpriseApiClient(BaseOAuthClient): course_enrollments_revoke_endpoint = api_base_url + 'licensed-enterprise-course-enrollment/license_revoke/' bulk_licensed_enrollments_expiration_endpoint = api_base_url \ + 'licensed-enterprise-course-enrollment/bulk_licensed_enrollments_expiration/' + unlink_users_endpoint = api_base_url + 'enterprise-customer/' def get_enterprise_customer_data(self, enterprise_customer_uuid): """ @@ -189,3 +190,10 @@ def bulk_enroll_enterprise_learners(self, enterprise_id, options): """ enrollment_url = '{}{}/enroll_learners_in_courses/'.format(self.enterprise_customer_endpoint, enterprise_id) return self.client.post(enrollment_url, json=options, timeout=settings.BULK_ENROLL_REQUEST_TIMEOUT_SECONDS) + + def bulk_unlink_enterprise_users(self, enterprise_uuid, options): + """ + Calls the Enterprise `unlink_users` API to unlink learners for an enterprise. + """ + enrollment_url = '{}{}/unlink_users/'.format(self.unlink_users_endpoint, enterprise_uuid) + return self.client.post(enrollment_url, json=options, timeout=settings.BULK_UNLINK_REQUEST_TIMEOUT_SECONDS) diff --git a/license_manager/apps/subscriptions/constants.py b/license_manager/apps/subscriptions/constants.py index 8782e9c7..ccaf5cba 100644 --- a/license_manager/apps/subscriptions/constants.py +++ b/license_manager/apps/subscriptions/constants.py @@ -151,3 +151,5 @@ class SegmentEvents: } ENTERPRISE_BRAZE_ALIAS_LABEL = 'Enterprise' # Do Not change this, this is consistent with other uses across edX repos. + +EXPIRED_LICENSE_UNLINKED = 'edx.server.license-manager.expired.license.unlinked' diff --git a/license_manager/apps/subscriptions/management/commands/tests/test_unlink_expired_licenses.py b/license_manager/apps/subscriptions/management/commands/tests/test_unlink_expired_licenses.py new file mode 100644 index 00000000..383b1e07 --- /dev/null +++ b/license_manager/apps/subscriptions/management/commands/tests/test_unlink_expired_licenses.py @@ -0,0 +1,140 @@ +from datetime import timedelta +from unittest import mock + +import pytest +from django.core.management import call_command +from django.test import TestCase +from django.test.utils import override_settings + +from license_manager.apps.subscriptions.constants import ( + ACTIVATED, + ASSIGNED, + EXPIRED_LICENSE_UNLINKED, + REVOKED, + UNASSIGNED, +) +from license_manager.apps.subscriptions.models import LicenseEvent +from license_manager.apps.subscriptions.tests.factories import ( + CustomerAgreementFactory, + LicenseFactory, + SubscriptionPlanFactory, +) +from license_manager.apps.subscriptions.utils import localized_utcnow + + +@pytest.mark.django_db +class UnlinkExpiredLicensesCommandTests(TestCase): + command_name = 'unlink_expired_licenses' + today = localized_utcnow() + customer_uuid = '76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae' + + def _create_expired_plan_with_licenses( + self, + unassigned_licenses_count=1, + assigned_licenses_count=2, + activated_licenses_count=3, + revoked_licenses_count=4, + start_date=today - timedelta(days=7), + expiration_date=today, + expiration_processed=False + ): + """ + Creates a plan with licenses. The plan is expired by default. + """ + customer_agreement = CustomerAgreementFactory(enterprise_customer_uuid=self.customer_uuid) + expired_plan = SubscriptionPlanFactory.create( + customer_agreement=customer_agreement, + start_date=start_date, + expiration_date=expiration_date, + expiration_processed=expiration_processed + ) + + LicenseFactory.create_batch(unassigned_licenses_count, status=UNASSIGNED, subscription_plan=expired_plan) + LicenseFactory.create_batch(assigned_licenses_count, status=ASSIGNED, subscription_plan=expired_plan) + LicenseFactory.create_batch(activated_licenses_count, status=ACTIVATED, subscription_plan=expired_plan) + LicenseFactory.create_batch(revoked_licenses_count, status=REVOKED, subscription_plan=expired_plan) + + return expired_plan + + def _get_allocated_license_uuids(self, subscription_plan): + return [str(license.uuid) for license in subscription_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED])] + + @override_settings( + CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae'] + ) + @mock.patch( + 'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient', + return_value=mock.MagicMock() + ) + def test_expired_licenses_unlinking(self, mock_enterprise_client): + """ + Verify that expired licenses unlinking working as expected. + """ + today = localized_utcnow() + + # create a plan that is expired but difference between expiration_date and today is less than 90 + self._create_expired_plan_with_licenses() + # create a plan that is expired 90 days ago + plan_expired_90_days_ago = self._create_expired_plan_with_licenses( + start_date=today - timedelta(days=150), + expiration_date=today - timedelta(days=90) + ) + + call_command(self.command_name) + + # verify that correct licenses from desired subscription plan were recorded in database + for license_event in LicenseEvent.objects.all(): + assert license_event.license.subscription_plan.uuid == plan_expired_90_days_ago.uuid + assert license_event.event_name == EXPIRED_LICENSE_UNLINKED + + # verify that call to unlink_users endpoint has correct user emails + mock_client_call_args = mock_enterprise_client().bulk_unlink_enterprise_users.call_args_list[0] + assert mock_client_call_args.args[0] == self.customer_uuid + assert sorted(mock_client_call_args.args[1]['user_emails']) == sorted([ + license.user_email for license in plan_expired_90_days_ago.licenses.filter( + status__in=[ASSIGNED, ACTIVATED] + ) + ]) + + @override_settings( + CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae'] + ) + @mock.patch( + 'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient', + return_value=mock.MagicMock() + ) + def test_expired_licenses_other_active_licenses(self, mock_enterprise_client): + """ + Verify that no unlinking happens when all expired licenses has other active licenses. + """ + assert LicenseEvent.objects.count() == 0 + today = localized_utcnow() + + # create a plan that is expired 90 days ago + plan_expired_90_days_ago = self._create_expired_plan_with_licenses( + start_date=today - timedelta(days=150), + expiration_date=today - timedelta(days=90) + ) + # just another plan + another_plan = self._create_expired_plan_with_licenses( + start_date=today - timedelta(days=150), + expiration_date=today + timedelta(days=10) + ) + + # fetch user emails from the expired plan + user_emails = list(plan_expired_90_days_ago.licenses.filter( + status__in=[ASSIGNED, ACTIVATED] + ).values_list('user_email', flat=True)) + + # assigned the above emails to licenses to create the test scenario where a learner has other active licenses + for license in another_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED]): + license.user_email = user_emails.pop() + license.save() + + call_command(self.command_name) + + # verify that no records were created in database for LicenseEvent + assert LicenseEvent.objects.count() == 0 + + # verify that no calls have been made to the unlink_users endpoint. + assert mock_enterprise_client().bulk_unlink_enterprise_users.call_count == 0 diff --git a/license_manager/apps/subscriptions/management/commands/unlink_expired_licenses.py b/license_manager/apps/subscriptions/management/commands/unlink_expired_licenses.py new file mode 100644 index 00000000..23ac76b7 --- /dev/null +++ b/license_manager/apps/subscriptions/management/commands/unlink_expired_licenses.py @@ -0,0 +1,160 @@ + +import logging +from datetime import timedelta + +from django.conf import settings +from django.core.management.base import BaseCommand +from django.core.paginator import Paginator +from django.db.models import Exists, OuterRef + +from license_manager.apps.api_client.enterprise import EnterpriseApiClient +from license_manager.apps.subscriptions.constants import ( + ACTIVATED, + ASSIGNED, + EXPIRED_LICENSE_UNLINKED, +) +from license_manager.apps.subscriptions.models import ( + CustomerAgreement, + License, + LicenseEvent, + SubscriptionPlan, +) +from license_manager.apps.subscriptions.utils import localized_utcnow + + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = ( + 'Unlink expired licenses.' + ) + + def add_arguments(self, parser): + """ + Entry point to add arguments. + """ + parser.add_argument( + '--dry-run', + action='store_true', + dest='dry_run', + default=False, + help='Dry Run, print log messages without firing the segment event.', + ) + + def expired_licenses(self, log_prefix, enterprise_customer_uuid): + """ + Get expired licenses. + """ + now = localized_utcnow() + expired_subscription_plan_uuids = [] + + customer_agreement = CustomerAgreement.objects.get(enterprise_customer_uuid=enterprise_customer_uuid) + + # fetch expired subscription plans where the expiration date is older than 90 days. + expired_subscription_plan_uuids = list(SubscriptionPlan.objects.filter( + customer_agreement=customer_agreement, + expiration_date__lt=now - timedelta(days=90), + ).prefetch_related( + 'licenses' + ).values_list('uuid', flat=True)) + + logger.info('%s Expired plans. UUIDs: %s', log_prefix, expired_subscription_plan_uuids) + + queryset = License.objects.filter( + status__in=[ASSIGNED, ACTIVATED], + renewed_to=None, + subscription_plan__uuid__in=expired_subscription_plan_uuids, + ).select_related( + 'subscription_plan', + ).values('uuid', 'lms_user_id', 'user_email') + + # subquery to check for the existence of `EXPIRED_LICENSE_UNLINKED` + event_exists_subquery = LicenseEvent.objects.filter( + license=OuterRef('pk'), + event_name=EXPIRED_LICENSE_UNLINKED + ).values('pk') + + # exclude previously processed licenses. + queryset = queryset.exclude(Exists(event_exists_subquery)) + + return queryset + + def handle(self, *args, **options): + """ + Unlink expired licenses. + """ + unlink = not options['dry_run'] + + log_prefix = '[UNLINK_EXPIRED_LICENSES]' + if not unlink: + log_prefix = '[DRY RUN]' + + logger.info('%s Command started.', log_prefix) + + enterprise_customer_uuids = settings.CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED + for enterprise_customer_uuid in enterprise_customer_uuids: + logger.info('%s Unlinking started for licenses. Enterprise: [%s]', log_prefix, enterprise_customer_uuid) + self.unlink_expired_licenses(enterprise_customer_uuid, log_prefix, unlink) + logger.info('%s Unlinking completed for licenses. Enterprise: [%s]', log_prefix, enterprise_customer_uuid) + + logger.info('%s Command completed.', log_prefix) + + def unlink_expired_licenses(self, enterprise_customer_uuid, log_prefix, unlink): + """ + Unlink expired licenses. + """ + expired_licenses = self.expired_licenses(log_prefix, enterprise_customer_uuid) + + if not expired_licenses: + logger.info( + '%s No expired licenses were found for enterprise: [%s].', + log_prefix, enterprise_customer_uuid + ) + return + + paginator = Paginator(expired_licenses, 100) + for page_number in paginator.page_range: + licenses = paginator.page(page_number) + + license_uuids = [] + user_emails = [] + + for license in licenses: + # check if the user associated with the expired license + # has any other active licenses with the same customer + other_active_licenses = License.for_user_and_customer( + user_email=license.get('user_email'), + lms_user_id=license.get('lms_user_id'), + enterprise_customer_uuid=enterprise_customer_uuid, + active_plans_only=True, + current_plans_only=True, + ).exists() + if other_active_licenses: + continue + + license_uuids.append(license.get('uuid')) + user_emails.append(license.get('user_email')) + + if unlink and user_emails: + EnterpriseApiClient().bulk_unlink_enterprise_users( + enterprise_customer_uuid, + { + 'user_emails': user_emails, + 'is_relinkable': True + }, + ) + + # Create license events for unlinked licenses to avoid processing them again. + unlinked_license_events = [ + LicenseEvent(license_id=license_uuid, event_name=EXPIRED_LICENSE_UNLINKED) + for license_uuid in license_uuids + ] + LicenseEvent.objects.bulk_create(unlinked_license_events, batch_size=100) + + logger.info( + "%s learners unlinked for licenses. Enterprise: [%s], LicenseUUIDs: [%s].", + log_prefix, + enterprise_customer_uuid, + license_uuids + ) diff --git a/license_manager/settings/base.py b/license_manager/settings/base.py index 808d78c4..0fb297fd 100644 --- a/license_manager/settings/base.py +++ b/license_manager/settings/base.py @@ -477,3 +477,5 @@ ] CUSTOMERS_WITH_CUSTOM_LICENSE_EVENTS = ['00000000-1111-2222-3333-444444444444'] +CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED = [] +BULK_UNLINK_REQUEST_TIMEOUT_SECONDS = os.environ.get('BULK_UNLINK_REQUEST_TIMEOUT_SECONDS', 120)