diff --git a/license_manager/apps/subscriptions/admin.py b/license_manager/apps/subscriptions/admin.py index ec501d0b..76e468f6 100644 --- a/license_manager/apps/subscriptions/admin.py +++ b/license_manager/apps/subscriptions/admin.py @@ -34,6 +34,7 @@ SubscriptionPlan, SubscriptionPlanRenewal, ) +from license_manager.apps.subscriptions.tasks import provision_licenses_task def get_related_object_link(admin_viewname, object_pk, object_str): @@ -334,10 +335,13 @@ def save_model(self, request, obj, form, change): customer_agreement_catalog = obj.customer_agreement.default_enterprise_catalog_uuid obj.enterprise_catalog_uuid = (obj.enterprise_catalog_uuid or customer_agreement_catalog) - # Create licenses to be associated with the subscription plan after creating the subscription plan - num_new_licenses = form.cleaned_data.get('num_licenses', 0) - obj.num_licenses + # Set desired_num_licenses which will lead to the eventual creation of those licenses. + obj.desired_num_licenses = form.cleaned_data.get('num_licenses', 0) + super().save_model(request, obj, form, change) - SubscriptionPlan.increase_num_licenses(obj, num_new_licenses) + + # Kick off the asynchronous process to provision licenses. + provision_licenses_task.delay(subscription_plan_uuid=obj.uuid) @admin.register(CustomerAgreement) diff --git a/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py b/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py new file mode 100644 index 00000000..dba12492 --- /dev/null +++ b/license_manager/apps/subscriptions/migrations/0064_subscriptionplan_desired_num_licenses.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.9 on 2024-01-24 03:39 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('subscriptions', '0063_transfer_all_licenses'), + ] + + operations = [ + migrations.AddField( + model_name='historicalsubscriptionplan', + name='desired_num_licenses', + field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'), + ), + migrations.AddField( + model_name='subscriptionplan', + name='desired_num_licenses', + field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'), + ), + ] diff --git a/license_manager/apps/subscriptions/models.py b/license_manager/apps/subscriptions/models.py index 321f9e47..46685fc9 100644 --- a/license_manager/apps/subscriptions/models.py +++ b/license_manager/apps/subscriptions/models.py @@ -434,6 +434,17 @@ class SubscriptionPlan(TimeStampedModel): ) ) + desired_num_licenses = models.PositiveIntegerField( + blank=True, + null=True, + verbose_name="Desired Number of Licenses", + help_text=( + "Total number of licenses that should exist for this SubscriptionPlan. " + "The total license count (provisioned asynchronously) will reach the desired amount eventually. " + "Empty (NULL) means no attempts will be made to asynchronously provision licenses." + ), + ) + @property def days_until_expiration(self): """ diff --git a/license_manager/apps/subscriptions/tasks.py b/license_manager/apps/subscriptions/tasks.py new file mode 100644 index 00000000..8adfee26 --- /dev/null +++ b/license_manager/apps/subscriptions/tasks.py @@ -0,0 +1,116 @@ +""" +Celery tasks for the subscriptions app. +""" +from datetime import timedelta + +from celery import shared_task, states +from celery_utils.logged_task import LoggedTask +from django.db import IntegrityError +from django.db.utils import OperationalError +from django_celery_results.models import TaskResult +import functools +from license_manager.apps.subscriptions.utils import batch_counts, localized_utcnow + + +ONE_HOUR = timedelta(hours=1) +TASK_RETRY_SECONDS = 60 +PROVISION_LICENSES_BATCH_SIZE = 300 + + +class RequiredTaskUnreadyError(Exception): + """ + An exception representing a state where one type of task that is required + to be complete before another task is run is not in a ready state. + """ + + +def unready_tasks(celery_task, time_delta): + """ + Returns any unready tasks with the name of the given celery task + that were created within the given (now - time_delta, now) range. + The unready celery states are + {'RECEIVED', 'REJECTED', 'STARTED', 'PENDING', 'RETRY'}. + https://docs.celeryproject.org/en/v5.0.5/reference/celery.states.html#unready-states + + Args: + celery_task: A celery task definition or "type" (not an applied task "instance"), + for example, ``update_catalog_metadata_task``. + time_delta: A datetime.timedelta indicating how for back to look for unready tasks of this type. + """ + return TaskResult.objects.filter( + task_name=celery_task.name, + date_created__gte=localized_utcnow() - time_delta, + status__in=states.UNREADY_STATES, + ) + +def already_running_semaphore(time_delta=None): + """ + Celery Task decorator that wraps a bound (bind=True) task. If another task with the same (name, args, kwargs) as + the given task was executed in the time between `time_delta` and now, and the task still has not completed, defer + running this task (by retrying until all other tasks are completed). + + `time_delta` defaults to one hour. + + Args: + time_delta (datetime.timedelta): An optional timedelta that specifies how far back + to look for the same task. + """ + def decorator(task): + @functools.wraps(task) + def wrapped_task(self, *args, **kwargs): + delta = time_delta or ONE_HOUR + if unready_tasks(update_catalog_metadata_task, delta).exists(): + logger.info( + f'Deferring task {self.name} with id {self.request.id} ' + f'and args: {self.request.args}, kwargs: {self.request.kwargs}, ' + 'since another task run with the same arguments has not yet completed.' + ) + raise self.retry(exc=RequiredTaskUnreadyError()) + return task(self, *args, **kwargs) + return wrapped_task + return decorator + + +class LoggedTaskWithRetry(LoggedTask): # pylint: disable=abstract-method + """ + Shared base task that allows tasks that raise some common exceptions to retry automatically. + + See https://docs.celeryproject.org/en/stable/userguide/tasks.html#automatic-retry-for-known-exceptions for + more documentation. + """ + autoretry_for = ( + IntegrityError, + OperationalError, + ) + retry_kwargs = {'max_retries': 5} + # Use exponential backoff for retrying tasks + retry_backoff = True + # Add randomness to backoff delays to prevent all tasks in queue from executing simultaneously + retry_jitter = True + + +@shared_task(base=LoggedTaskWithRetry, bind=True, default_retry_delay=TASK_RETRY_SECONDS) +@already_running_semaphore() +def provision_licenses_task(self, subscription_plan_uuid=None): # pylint: disable=unused-argument + """ + For a given subscription plan, try to make its count of licenses match the number defined by the + `desired_num_licenses` field of that subscription plan. Never decrease the count of licenses; if there are already + more licenses than `desired_num_licenses`, do nothing. + + Args: + subscription_plan_uuid (str): UUID of the SubscriptionPlan object to provision licenses for. + """ + subscription_plan = SubscriptionPlan.objects.get(uuid=subscription_plan_uuid) + license_count_gap = subscription_plan.desired_num_licenses - subscription_plan.num_licenses + if license_count_gap <= 0: + logger.info( + f'Skipping task {self.name} with id {self.request.id} ' + f'and args: {self.request.args}, kwargs: {self.request.kwargs}, ' + f'because the actual license count ({subscription_plan.num_licenses}) ' + f'already meets or exceeds the desired license count ({subscription_plan.desired_num_licenses}).' + ) + else: + # There's work to do, creating licenses! It should be safe to not re-check the license count between batches + # because we lock this subscription plan anyway (via @already_running_semaphore decorator). + for batch_count in batch_counts(license_count_gap, batch_size=PROVISION_LICENSES_BATCH_SIZE): + subscription_plan.increase_num_licenses(batch_count) diff --git a/license_manager/apps/subscriptions/utils.py b/license_manager/apps/subscriptions/utils.py index bbef85dd..236375a5 100644 --- a/license_manager/apps/subscriptions/utils.py +++ b/license_manager/apps/subscriptions/utils.py @@ -66,6 +66,22 @@ def chunks(a_list, chunk_size): yield a_list[i:i + chunk_size] +def batch_counts(total_count, batch_size=1): + """ + Break up a total count into equal-sized batch counts. + + Arguments: + total_count (int): The total count to batch. + batch_size (int): The size of each batch. Defaults to 1. + Returns: + generator: returns the count for each batch. + """ + num_full_batches, last_batch_count = divmod(total_count, batch_size) + for _ in range(num_full_batches): + yield batch_size + yield last_batch_count + + def get_learner_portal_url(enterprise_slug): """ Returns the link to the learner portal, given an enterprise slug.