diff --git a/.gitignore b/.gitignore
index e2ac8bd9..0f395a90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -85,6 +85,9 @@ private.py
 # Editor Temp Files
 *.swp
 
+# pyenv
+.python-version
+
 *.trace
 
 docs/_build/
diff --git a/license_manager/apps/subscriptions/admin.py b/license_manager/apps/subscriptions/admin.py
index ec501d0b..4eb55abb 100644
--- a/license_manager/apps/subscriptions/admin.py
+++ b/license_manager/apps/subscriptions/admin.py
@@ -34,6 +34,10 @@
     SubscriptionPlan,
     SubscriptionPlanRenewal,
 )
+from license_manager.apps.subscriptions.tasks import (
+    PROVISION_LICENSES_BATCH_SIZE,
+    provision_licenses_task,
+)
 
 
 def get_related_object_link(admin_viewname, object_pk, object_str):
@@ -334,10 +338,18 @@ def save_model(self, request, obj, form, change):
         customer_agreement_catalog = obj.customer_agreement.default_enterprise_catalog_uuid
         obj.enterprise_catalog_uuid = (obj.enterprise_catalog_uuid or customer_agreement_catalog)
 
-        # Create licenses to be associated with the subscription plan after creating the subscription plan
-        num_new_licenses = form.cleaned_data.get('num_licenses', 0) - obj.num_licenses
+        # Set desired_num_licenses which will lead to the eventual creation of those licenses.
+        obj.desired_num_licenses = form.cleaned_data.get('num_licenses', 0)
+
         super().save_model(request, obj, form, change)
-        SubscriptionPlan.increase_num_licenses(obj, num_new_licenses)
+
+        num_new_licenses = obj.desired_num_licenses - obj.num_licenses
+        if num_new_licenses <= PROVISION_LICENSES_BATCH_SIZE:
+            # We can handle just one batch synchronously.
+            SubscriptionPlan.increase_num_licenses(obj, num_new_licenses)
+        else:
+            # Multiple batches of licenses will need to be created, so provision them asynchronously.
+            provision_licenses_task.delay(subscription_plan_uuid=obj.uuid)
 
 
 @admin.register(CustomerAgreement)
diff --git a/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py b/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py
new file mode 100644
index 00000000..dba12492
--- /dev/null
+++ b/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py
@@ -0,0 +1,23 @@
+# Generated by Django 4.2.9 on 2024-01-24 03:39
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('subscriptions', '0063_transfer_all_licenses'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='historicalsubscriptionplan',
+            name='desired_num_licenses',
+            field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'),
+        ),
+        migrations.AddField(
+            model_name='subscriptionplan',
+            name='desired_num_licenses',
+            field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'),
+        ),
+    ]
diff --git a/license_manager/apps/subscriptions/models.py b/license_manager/apps/subscriptions/models.py
index 321f9e47..46685fc9 100644
--- a/license_manager/apps/subscriptions/models.py
+++ b/license_manager/apps/subscriptions/models.py
@@ -434,6 +434,17 @@ class SubscriptionPlan(TimeStampedModel):
         )
     )
 
+    desired_num_licenses = models.PositiveIntegerField(
+        blank=True,
+        null=True,
+        verbose_name="Desired Number of Licenses",
+        help_text=(
+            "Total number of licenses that should exist for this SubscriptionPlan. "
+            "The total license count (provisioned asynchronously) will reach the desired amount eventually. "
+            "Empty (NULL) means no attempts will be made to asynchronously provision licenses."
+        ),
+    )
+
     @property
     def days_until_expiration(self):
         """
diff --git a/license_manager/apps/subscriptions/tasks.py b/license_manager/apps/subscriptions/tasks.py
new file mode 100644
index 00000000..82f53b07
--- /dev/null
+++ b/license_manager/apps/subscriptions/tasks.py
@@ -0,0 +1,110 @@
+"""
+Celery tasks for the subscriptions app.
+"""
+import functools
+import logging
+
+from celery import shared_task
+from celery_utils.logged_task import LoggedTask
+from django.db import IntegrityError
+from django.db.utils import OperationalError
+
+from license_manager.apps.api.utils import (
+    acquire_subscription_plan_lock,
+    release_subscription_plan_lock,
+)
+from license_manager.apps.subscriptions.models import SubscriptionPlan
+from license_manager.apps.subscriptions.utils import batch_counts
+
+
+logger = logging.getLogger(__name__)
+
+TASK_RETRY_SECONDS = 60
+PROVISION_LICENSES_BATCH_SIZE = 300
+
+
+class RequiredTaskUnreadyError(Exception):
+    """
+    An exception representing a state where one type of task that is required
+    to be complete before another task is run is not in a ready state.
+    """
+
+
+def subscription_plan_semaphore():
+    """
+    Celery Task decorator that wraps a bound (bind=True) task. If another wrapped task with the same given
+    "subscription_plan_uuid" kwarg value is still running, defer running this task (by retrying until all other tasks
+    are completed).
+    """
+    def decorator(task):
+        @functools.wraps(task)
+        def wrapped_task(self, *args, **kwargs):
+            subscription_plan = SubscriptionPlan.objects.get(uuid=kwargs['subscription_plan_uuid'])
+            if not acquire_subscription_plan_lock(subscription_plan):
+                logger.info(
+                    f'Deferring task {self.name} with id {self.request.id} '
+                    f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
+                    'since another task run with the same subscription plan has not yet completed.'
+                )
+                raise self.retry(exc=RequiredTaskUnreadyError())
+            # Try to run the task. If it fails, release the lock and bubble up the exception.
+            try:
+                task_results = task(self, *args, **kwargs)
+            finally:
+                release_subscription_plan_lock(subscription_plan)
+            return task_results
+        return wrapped_task
+    return decorator
+
+
+class LoggedTaskWithRetry(LoggedTask):  # pylint: disable=abstract-method
+    """
+    Shared base task that allows tasks that raise some common exceptions to retry automatically.
+
+    See https://docs.celeryproject.org/en/stable/userguide/tasks.html#automatic-retry-for-known-exceptions for
+    more documentation.
+    """
+    autoretry_for = (
+        IntegrityError,
+        OperationalError,
+    )
+    retry_kwargs = {'max_retries': 5}
+    # Use exponential backoff for retrying tasks
+    retry_backoff = True
+    # Add randomness to backoff delays to prevent all tasks in queue from executing simultaneously
+    retry_jitter = True
+
+
+@shared_task(base=LoggedTaskWithRetry, bind=True, default_retry_delay=TASK_RETRY_SECONDS)
+@subscription_plan_semaphore()
+def provision_licenses_task(self, subscription_plan_uuid=None):  # pylint: disable=unused-argument
+    """
+    For a given subscription plan, try to make its count of licenses match the number defined by the
+    `desired_num_licenses` field of that subscription plan. Never decrease the count of licenses; if there are already
+    more licenses than `desired_num_licenses`, do nothing.
+
+    Args:
+        subscription_plan_uuid (str): UUID of the SubscriptionPlan object to provision licenses for.
+    """
+    subscription_plan = SubscriptionPlan.objects.get(uuid=subscription_plan_uuid)
+    if not subscription_plan.desired_num_licenses:
+        logger.info(
+            f'Skipping task {self.name} with id {self.request.id} '
+            f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
+            f'because desired_num_licenses is not set on this subscription plan.'
+        )
+        return
+    license_count_gap = subscription_plan.desired_num_licenses - subscription_plan.num_licenses
+    if license_count_gap <= 0:
+        logger.info(
+            f'Skipping task {self.name} with id {self.request.id} '
+            f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
+            f'because the actual license count ({subscription_plan.num_licenses}) '
+            f'already meets or exceeds the desired license count ({subscription_plan.desired_num_licenses}).'
+        )
+        return
+
+    # There's work to do, creating licenses! It should be safe to not re-check the license count between batches
+    # because we lock this subscription plan anyway (via @subscription_plan_semaphore decorator).
+    for batch_count in batch_counts(license_count_gap, batch_size=PROVISION_LICENSES_BATCH_SIZE):
+        subscription_plan.increase_num_licenses(batch_count)
diff --git a/license_manager/apps/subscriptions/tests/test_tasks.py b/license_manager/apps/subscriptions/tests/test_tasks.py
new file mode 100644
index 00000000..61a99e74
--- /dev/null
+++ b/license_manager/apps/subscriptions/tests/test_tasks.py
@@ -0,0 +1,121 @@
+"""
+Tests for subscriptions app celery tasks
+"""
+from datetime import datetime, timedelta
+from unittest import mock
+from uuid import uuid4
+
+import ddt
+import freezegun
+import pytest
+from braze.exceptions import BrazeClientError
+from django.conf import settings
+from django.test import TestCase
+from django.test.utils import override_settings
+from freezegun import freeze_time
+from requests import models
+
+from license_manager.apps.api.utils import (
+    acquire_subscription_plan_lock,
+    release_subscription_plan_lock,
+)
+from license_manager.apps.subscriptions import tasks
+from license_manager.apps.subscriptions.models import License, SubscriptionPlan
+from license_manager.apps.subscriptions.tests.factories import (
+    LicenseFactory,
+    SubscriptionPlanFactory,
+)
+
+
+# pylint: disable=unused-argument
+@ddt.ddt
+class ProvisionLicensesTaskTests(TestCase):
+    """
+    Tests for provision_licenses_task.
+    """
+    def setUp(self):
+        super().setUp()
+        self.subscription_plan = SubscriptionPlanFactory()
+
+    def tearDown(self):
+        super().tearDown()
+        release_subscription_plan_lock(self.subscription_plan)
+
+    # For all cases below, assume batch size of 5.
+    @ddt.data(
+        # Don't add licenses if none are desired.
+        {
+            'num_initial_licenses': 0,
+            'desired_num_licenses': 0,
+            'expected_num_licenses': 0,
+        },
+        # Create fewer licenses than one batch.
+        {
+            'num_initial_licenses': 0,
+            'desired_num_licenses': 1,
+            'expected_num_licenses': 1,
+        },
+        # Create licenses that span multiple batches.
+        {
+            'num_initial_licenses': 0,
+            'desired_num_licenses': 6,
+            'expected_num_licenses': 6,
+        },
+        # Don't add more licenses if the goal has already been reached.
+        {
+            'num_initial_licenses': 10,
+            'desired_num_licenses': 10,
+            'expected_num_licenses': 10,
+        },
+        # Create fewer licenses than one batch (starting with 10 initially).
+        {
+            'num_initial_licenses': 10,
+            'desired_num_licenses': 11,
+            'expected_num_licenses': 11,
+        },
+        # Create licenses that span multiple batches (starting with 10 initially).
+        {
+            'num_initial_licenses': 10,
+            'desired_num_licenses': 16,
+            'expected_num_licenses': 16,
+        },
+        # Don't remove licenses if the desired number of licenses is smaller than count of existing licenses.
+        {
+            'num_initial_licenses': 20,
+            'desired_num_licenses': 10,
+            'expected_num_licenses': 20,
+        },
+        # Don't add or remove licenses if the desired number of licenses is None.
+        {
+            'num_initial_licenses': 20,
+            'desired_num_licenses': None,
+            'expected_num_licenses': 20,
+        },
+    )
+    @ddt.unpack
+    @mock.patch('license_manager.apps.subscriptions.tasks.PROVISION_LICENSES_BATCH_SIZE', 5)
+    def test_provision_licenses_task(self, num_initial_licenses, desired_num_licenses, expected_num_licenses):
+        """
+        Test provision_licenses_task.
+        """
+        self.subscription_plan.desired_num_licenses = desired_num_licenses
+        self.subscription_plan.save()
+        self.subscription_plan.increase_num_licenses(num_initial_licenses)
+
+        tasks.provision_licenses_task(subscription_plan_uuid=self.subscription_plan.uuid)
+
+        assert self.subscription_plan.num_licenses == expected_num_licenses
+
+    def test_provision_licenses_task_locked(self):
+        """
+        Test provision_licenses_task throws an exception if the subscription is locked.
+        """
+        self.subscription_plan.desired_num_licenses = 5
+        self.subscription_plan.save()
+
+        acquire_subscription_plan_lock(self.subscription_plan)
+
+        with self.assertRaises(tasks.RequiredTaskUnreadyError):
+            tasks.provision_licenses_task(subscription_plan_uuid=self.subscription_plan.uuid)
+
+        assert self.subscription_plan.num_licenses == 0
diff --git a/license_manager/apps/subscriptions/tests/test_utils.py b/license_manager/apps/subscriptions/tests/test_utils.py
index 40de9aa8..13dd0883 100644
--- a/license_manager/apps/subscriptions/tests/test_utils.py
+++ b/license_manager/apps/subscriptions/tests/test_utils.py
@@ -5,7 +5,9 @@
 import hashlib
 import hmac
 import uuid
-from unittest import mock
+from unittest import TestCase, mock
+
+import ddt
 
 from license_manager.apps.subscriptions import utils
 
@@ -28,3 +30,51 @@ def test_get_subsidy_checksum():
         expected_checksum,
         utils.get_subsidy_checksum(lms_user_id, course_key, license_uuid),
     )
+
+
+@ddt.ddt
+class TestBatchCounts(TestCase):
+    """
+    Tests for batch_counts().
+    """
+
+    @ddt.data(
+        {
+            'total_count': 0,
+            'batch_size': 5,
+            'expected_batch_counts': [],
+        },
+        {
+            'total_count': 4,
+            'batch_size': 5,
+            'expected_batch_counts': [4],
+        },
+        {
+            'total_count': 5,
+            'batch_size': 5,
+            'expected_batch_counts': [5],
+        },
+        {
+            'total_count': 6,
+            'batch_size': 5,
+            'expected_batch_counts': [5, 1],
+        },
+        {
+            'total_count': 23,
+            'batch_size': 5,
+            'expected_batch_counts': [5, 5, 5, 5, 3],
+        },
+        # Just make sure something weird doesn't happen when the batch size is 1.
+        {
+            'total_count': 5,
+            'batch_size': 1,
+            'expected_batch_counts': [1, 1, 1, 1, 1],
+        },
+    )
+    @ddt.unpack
+    def test_batch_counts(self, total_count, batch_size, expected_batch_counts):
+        """
+        Test batch_counts().
+        """
+        actual_batch_counts = list(utils.batch_counts(total_count, batch_size=batch_size))
+        assert actual_batch_counts == expected_batch_counts
diff --git a/license_manager/apps/subscriptions/utils.py b/license_manager/apps/subscriptions/utils.py
index bbef85dd..94ca2b85 100644
--- a/license_manager/apps/subscriptions/utils.py
+++ b/license_manager/apps/subscriptions/utils.py
@@ -66,6 +66,23 @@ def chunks(a_list, chunk_size):
         yield a_list[i:i + chunk_size]
 
 
+def batch_counts(total_count, batch_size=1):
+    """
+    Break up a total count into equal-sized batch counts.
+
+    Arguments:
+        total_count (int): The total count to batch.
+        batch_size (int): The size of each batch. Defaults to 1.
+    Returns:
+        generator: returns the count for each batch.
+    """
+    num_full_batches, last_batch_count = divmod(total_count, batch_size)
+    for _ in range(num_full_batches):
+        yield batch_size
+    if last_batch_count > 0:
+        yield last_batch_count
+
+
 def get_learner_portal_url(enterprise_slug):
     """
     Returns the link to the learner portal, given an enterprise slug.
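For reference, a minimal illustrative sketch (not part of the patch) of how the new batch_counts helper partitions the license-count gap that provision_licenses_task works through; the desired and current counts below are made-up example values:

    from license_manager.apps.subscriptions.utils import batch_counts

    PROVISION_LICENSES_BATCH_SIZE = 300  # mirrors the constant defined in the new tasks.py

    desired_num_licenses = 750  # hypothetical desired total for a plan
    num_licenses = 100          # hypothetical count of existing licenses
    license_count_gap = desired_num_licenses - num_licenses  # 650 licenses still to create

    # divmod(650, 300) == (2, 50), so the generator yields 300, 300, 50;
    # provision_licenses_task makes one increase_num_licenses() call per yielded count.
    print(list(batch_counts(license_count_gap, batch_size=PROVISION_LICENSES_BATCH_SIZE)))
    # -> [300, 300, 50]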