Skip to content

Commit

Permalink
feat: refactor and add unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-ammar committed Nov 5, 2024
1 parent a9e3883 commit e923f3a
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 33 deletions.
2 changes: 1 addition & 1 deletion license_manager/apps/subscriptions/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,4 +152,4 @@ class SegmentEvents:

ENTERPRISE_BRAZE_ALIAS_LABEL = 'Enterprise' # Do Not change this, this is consistent with other uses across edX repos.

EXPIRED_LICENSE_PROCESSED = 'edx.server.license-manager.expired.license.processed'
EXPIRED_LICENSE_UNLINKED = 'edx.server.license-manager.expired.license.unlinked'
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from datetime import timedelta
from unittest import mock

import pytest
from django.core.management import call_command
from django.test import TestCase
from django.test.utils import override_settings

from license_manager.apps.subscriptions.constants import (
ACTIVATED,
ASSIGNED,
EXPIRED_LICENSE_UNLINKED,
REVOKED,
UNASSIGNED,
)
from license_manager.apps.subscriptions.models import LicenseEvent
from license_manager.apps.subscriptions.tests.factories import (
CustomerAgreementFactory,
LicenseFactory,
SubscriptionPlanFactory,
)
from license_manager.apps.subscriptions.utils import localized_utcnow


@pytest.mark.django_db
class UnlinkExpiredLicensesCommandTests(TestCase):
command_name = 'unlink_expired_licenses'
today = localized_utcnow()
customer_uuid = '76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae'

def _create_expired_plan_with_licenses(
self,
unassigned_licenses_count=1,
assigned_licenses_count=2,
activated_licenses_count=3,
revoked_licenses_count=4,
start_date=today - timedelta(days=7),
expiration_date=today,
expiration_processed=False
):
"""
Creates a plan with licenses. The plan is expired by default.
"""
customer_agreement = CustomerAgreementFactory(enterprise_customer_uuid=self.customer_uuid)
expired_plan = SubscriptionPlanFactory.create(
customer_agreement=customer_agreement,
start_date=start_date,
expiration_date=expiration_date,
expiration_processed=expiration_processed
)

LicenseFactory.create_batch(unassigned_licenses_count, status=UNASSIGNED, subscription_plan=expired_plan)
LicenseFactory.create_batch(assigned_licenses_count, status=ASSIGNED, subscription_plan=expired_plan)
LicenseFactory.create_batch(activated_licenses_count, status=ACTIVATED, subscription_plan=expired_plan)
LicenseFactory.create_batch(revoked_licenses_count, status=REVOKED, subscription_plan=expired_plan)

return expired_plan

def _get_allocated_license_uuids(self, subscription_plan):
return [str(license.uuid) for license in subscription_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED])]

@override_settings(
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae']
)
@mock.patch(
'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient',
return_value=mock.MagicMock()
)
def test_expired_licenses_unlinking(self, mock_enterprise_client):
"""
Verify that expired licenses unlinking working as expected.
"""
today = localized_utcnow()

# create a plan that is expired but difference between expiration_date and today is less than 90
self._create_expired_plan_with_licenses()
# create a plan that is expired 90 days ago
plan_expired_90_days_ago = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today - timedelta(days=90)
)

call_command(self.command_name)

# verify that correct licenses from desired subscription plan were recorded in database
for license_event in LicenseEvent.objects.all():
assert license_event.license.subscription_plan.uuid == plan_expired_90_days_ago.uuid
assert license_event.event_name == EXPIRED_LICENSE_UNLINKED

# verify that call to unlink_users endpoint has correct user emails
mock_client_call_args = mock_enterprise_client().bulk_unlink_enterprise_users.call_args_list[0]
assert mock_client_call_args.args[0] == self.customer_uuid
assert sorted(mock_client_call_args.args[1]['user_emails']) == sorted([
license.user_email for license in plan_expired_90_days_ago.licenses.filter(
status__in=[ASSIGNED, ACTIVATED]
)
])

@override_settings(
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED=['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae']
)
@mock.patch(
'license_manager.apps.subscriptions.management.commands.unlink_expired_licenses.EnterpriseApiClient',
return_value=mock.MagicMock()
)
def test_expired_licenses_other_active_licenses(self, mock_enterprise_client):
"""
Verify that no unlinking happens when all expired licenses has other active licenses.
"""
assert LicenseEvent.objects.count() == 0
today = localized_utcnow()

# create a plan that is expired 90 days ago
plan_expired_90_days_ago = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today - timedelta(days=90)
)
# just another plan
another_plan = self._create_expired_plan_with_licenses(
start_date=today - timedelta(days=150),
expiration_date=today + timedelta(days=10)
)

# fetch user emails from the expired plan
user_emails = list(plan_expired_90_days_ago.licenses.filter(
status__in=[ASSIGNED, ACTIVATED]
).values_list('user_email', flat=True))

# assigned the above emails to licenses to create the test scenario where a learner has other active licenses
for license in another_plan.licenses.filter(status__in=[ASSIGNED, ACTIVATED]):
license.user_email = user_emails.pop()
license.save()

call_command(self.command_name)

# verify that no records were created in database for LicenseEvent
assert LicenseEvent.objects.count() == 0

# verify that no calls have been made to the unlink_users endpoint.
assert mock_enterprise_client().bulk_unlink_enterprise_users.call_count == 0
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from license_manager.apps.subscriptions.constants import (
ACTIVATED,
ASSIGNED,
EXPIRED_LICENSE_PROCESSED,
EXPIRED_LICENSE_UNLINKED,
)
from license_manager.apps.subscriptions.models import (
CustomerAgreement,
Expand Down Expand Up @@ -42,7 +42,7 @@ def add_arguments(self, parser):
help='Dry Run, print log messages without firing the segment event.',
)

def expired_licenses(self, enterprise_customer_uuid):
def expired_licenses(self, log_prefix, enterprise_customer_uuid):
"""
Get expired licenses.
"""
Expand All @@ -52,27 +52,14 @@ def expired_licenses(self, enterprise_customer_uuid):
customer_agreement = CustomerAgreement.objects.get(enterprise_customer_uuid=enterprise_customer_uuid)

# fetch expired subscription plans where the expiration date is older than 90 days.
expired_subscription_plans = list(
SubscriptionPlan.objects.filter(
customer_agreement=customer_agreement,
expiration_date__lt=now - timedelta(days=90),
).select_related(
'customer_agreement'
).prefetch_related(
'licenses'
)
)

for expired_subscription_plan in expired_subscription_plans:
# exclude subscription plan if there is a renewal
if expired_subscription_plan.get_renewal():
continue
expired_subscription_plan_uuids = list(SubscriptionPlan.objects.filter(
customer_agreement=customer_agreement,
expiration_date__lt=now - timedelta(days=90),
).prefetch_related(
'licenses'
).values_list('uuid', flat=True))

expired_subscription_plan_uuids.append(expired_subscription_plan.uuid)

# include prior renewals
for prior_renewal in expired_subscription_plan.prior_renewals:
expired_subscription_plan_uuids.append(prior_renewal.prior_subscription_plan.uuid)
logger.info('%s Expired plans. UUIDs: %s', log_prefix, expired_subscription_plan_uuids)

queryset = License.objects.filter(
status__in=[ASSIGNED, ACTIVATED],
Expand All @@ -82,10 +69,10 @@ def expired_licenses(self, enterprise_customer_uuid):
'subscription_plan',
).values('uuid', 'lms_user_id', 'user_email')

# subquery to check for the existence of `EXPIRED_LICENSE_PROCESSED`
# subquery to check for the existence of `EXPIRED_LICENSE_UNLINKED`
event_exists_subquery = LicenseEvent.objects.filter(
license=OuterRef('pk'),
event_name=EXPIRED_LICENSE_PROCESSED
event_name=EXPIRED_LICENSE_UNLINKED
).values('pk')

# exclude previously processed licenses.
Expand Down Expand Up @@ -117,7 +104,7 @@ def unlink_expired_licenses(self, enterprise_customer_uuid, log_prefix, unlink):
"""
Unlink expired licenses.
"""
expired_licenses = self.expired_licenses(enterprise_customer_uuid)
expired_licenses = self.expired_licenses(log_prefix, enterprise_customer_uuid)

if not expired_licenses:
logger.info(
Expand All @@ -137,8 +124,8 @@ def unlink_expired_licenses(self, enterprise_customer_uuid, log_prefix, unlink):
# check if the user associated with the expired license
# has any other active licenses with the same customer
other_active_licenses = License.for_user_and_customer(
user_email=license.user_email,
lms_user_id=license.lms_user_id,
user_email=license.get('user_email'),
lms_user_id=license.get('lms_user_id'),
enterprise_customer_uuid=enterprise_customer_uuid,
active_plans_only=True,
current_plans_only=True,
Expand All @@ -149,19 +136,18 @@ def unlink_expired_licenses(self, enterprise_customer_uuid, log_prefix, unlink):
license_uuids.append(license.get('uuid'))
user_emails.append(license.get('user_email'))

if unlink:
if unlink and user_emails:
EnterpriseApiClient().bulk_unlink_enterprise_users(
enterprise_customer_uuid,
{
'user_emails': user_emails,
'is_relinkable': False
'is_relinkable': True
},

)

# Create license events for unlinked licenses to avoid processing them again.
unlinked_license_events = [
LicenseEvent(license_id=license_uuid, event_name=EXPIRED_LICENSE_PROCESSED)
LicenseEvent(license_id=license_uuid, event_name=EXPIRED_LICENSE_UNLINKED)
for license_uuid in license_uuids
]
LicenseEvent.objects.bulk_create(unlinked_license_events, batch_size=100)
Expand Down
2 changes: 1 addition & 1 deletion license_manager/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,5 +477,5 @@
]

CUSTOMERS_WITH_CUSTOM_LICENSE_EVENTS = ['00000000-1111-2222-3333-444444444444']
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED = ['76b933cb-bf2a-4c1e-bf44-4e8a58cc37ae']
CUSTOMERS_WITH_EXPIRED_LICENSES_UNLINKING_ENABLED = []
BULK_UNLINK_REQUEST_TIMEOUT_SECONDS = os.environ.get('BULK_UNLINK_REQUEST_TIMEOUT_SECONDS', 120)

0 comments on commit e923f3a

Please sign in to comment.