Skip to content

Commit

Permalink
feat: make subscription license provisioning eventually consistent
Browse files Browse the repository at this point in the history
ENT-8269
  • Loading branch information
pwnage101 committed Jan 24, 2024
1 parent 52f02b7 commit f378e2e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 3 deletions.
10 changes: 7 additions & 3 deletions license_manager/apps/subscriptions/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
SubscriptionPlan,
SubscriptionPlanRenewal,
)
from license_manager.apps.subscriptions.tasks import provision_licenses_task


def get_related_object_link(admin_viewname, object_pk, object_str):
Expand Down Expand Up @@ -334,10 +335,13 @@ def save_model(self, request, obj, form, change):
customer_agreement_catalog = obj.customer_agreement.default_enterprise_catalog_uuid
obj.enterprise_catalog_uuid = (obj.enterprise_catalog_uuid or customer_agreement_catalog)

# Create licenses to be associated with the subscription plan after creating the subscription plan
num_new_licenses = form.cleaned_data.get('num_licenses', 0) - obj.num_licenses
# Set desired_num_licenses which will lead to the eventual creation of those licenses.
obj.desired_num_licenses = form.cleaned_data.get('num_licenses', 0)

super().save_model(request, obj, form, change)
SubscriptionPlan.increase_num_licenses(obj, num_new_licenses)

# Kick off the asynchronous process to provision licenses.
provision_licenses_task.delay(subscription_plan_uuid=obj.uuid)


@admin.register(CustomerAgreement)
Expand Down
11 changes: 11 additions & 0 deletions license_manager/apps/subscriptions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,17 @@ class SubscriptionPlan(TimeStampedModel):
)
)

desired_num_licenses = models.PositiveIntegerField(
blank=True,
null=True,
verbose_name="Desired Number of Licenses",
help_text=(
"Total number of licenses that should exist for this SubscriptionPlan. "
"The total license count (provisioned asynchronously) will reach the desired amount eventually. "
"Empty (NULL) means no attempts will be made to asynchronously provision licenses."
),
)

@property
def days_until_expiration(self):
"""
Expand Down
131 changes: 131 additions & 0 deletions license_manager/apps/subscriptions/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
Celery tasks for the subscriptions app.
"""
from datetime import timedelta

from celery import shared_task, states
from celery_utils.logged_task import LoggedTask
from django.db import IntegrityError
from django.db.utils import OperationalError
from django_celery_results.models import TaskResult
from license_manager.apps.subscriptions.utils import localized_utcnow


ONE_HOUR = timedelta(hours=1)
TASK_RETRY_SECONDS = 60
PROVISION_LICENSES_BATCH_SIZE = 300


class RequiredTaskUnreadyError(Exception):
"""
An exception representing a state where one type of task that is required
to be complete before another task is run is not in a ready state.
"""


def unready_tasks(celery_task, time_delta):
"""
Returns any unready tasks with the name of the given celery task
that were created within the given (now - time_delta, now) range.
The unready celery states are
{'RECEIVED', 'REJECTED', 'STARTED', 'PENDING', 'RETRY'}.
https://docs.celeryproject.org/en/v5.0.5/reference/celery.states.html#unready-states
Args:
celery_task: A celery task definition or "type" (not an applied task "instance"),
for example, ``update_catalog_metadata_task``.
time_delta: A datetime.timedelta indicating how for back to look for unready tasks of this type.
"""
return TaskResult.objects.filter(
task_name=celery_task.name,
date_created__gte=localized_utcnow() - time_delta,
status__in=states.UNREADY_STATES,
)

def already_running_semaphore(time_delta=None):
"""
Celery Task decorator that wraps a bound (bind=True) task. If another task with the same (name, args, kwargs) as
the given task was executed in the time between `time_delta` and now, and the task still has not completed, defer
running this task (by retrying until all other tasks are completed).
`time_delta` defaults to one hour.
Args:
time_delta (datetime.timedelta): An optional timedelta that specifies how far back
to look for the same task.
"""
def decorator(task):
@functools.wraps(task)
def wrapped_task(self, *args, **kwargs):
delta = time_delta or ONE_HOUR
if unready_tasks(update_catalog_metadata_task, delta).exists():
logger.info(
f'Deferring task {self.name} with id {self.request.id} '
f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
'since another task run with the same arguments has not yet completed.'
)
raise self.retry(exc=RequiredTaskUnreadyError())
return task(self, *args, **kwargs)
return wrapped_task
return decorator


class LoggedTaskWithRetry(LoggedTask): # pylint: disable=abstract-method
"""
Shared base task that allows tasks that raise some common exceptions to retry automatically.
See https://docs.celeryproject.org/en/stable/userguide/tasks.html#automatic-retry-for-known-exceptions for
more documentation.
"""
autoretry_for = (
IntegrityError,
OperationalError,
)
retry_kwargs = {'max_retries': 5}
# Use exponential backoff for retrying tasks
retry_backoff = True
# Add randomness to backoff delays to prevent all tasks in queue from executing simultaneously
retry_jitter = True


def batch_counts(total_count, batch_size=1):
"""
Break up a total count into equal-sized batch counts.
Arguments:
total_count (int): The total count to batch.
batch_size (int): The size of each batch. Defaults to 1.
Returns:
generator: returns the count for each batch.
"""
num_full_batches, last_batch_count = divmod(total_count, batch_size)
for _ in range(num_full_batches):
yield batch_size
yield last_batch_count


@shared_task(base=LoggedTaskWithRetry, bind=True, default_retry_delay=TASK_RETRY_SECONDS)
@already_running_semaphore()
def provision_licenses_task(self, subscription_plan_uuid=None): # pylint: disable=unused-argument
"""
For a given subscription plan, try to make its count of licenses match the number defined by the
`desired_num_licenses` field of that subscription plan. Never decrease the count of licenses; if there are already
more licenses than `desired_num_licenses`, do nothing.
Args:
subscription_plan_uuid (str): UUID of the SubscriptionPlan object to provision licenses for.
"""
subscription_plan = SubscriptionPlan.objects.get(uuid=subscription_plan_uuid)
license_count_gap = subscription_plan.desired_num_licenses - subscription_plan.num_licenses
if license_count_gap <= 0:
logger.info(
f'Skipping task {self.name} with id {self.request.id} '
f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
f'because the actual license count ({subscription_plan.num_licenses}) '
f'already meets or exceeds the desired license count ({subscription_plan.desired_num_licenses}).'
)
else:
# There's work to do, creating licenses! It should be safe to not re-check the license count between batches
# because we lock this subscription plan anyway (via @already_running_semaphore decorator).
for batch_count in batch_counts(license_count_gap, batch_size=PROVISION_LICENSES_BATCH_SIZE):
subscription_plan.increase_num_licenses(batch_count)

0 comments on commit f378e2e

Please sign in to comment.