Skip to content

Commit

Permalink
feat: make subscription license provisioning asynchronous and batched
Browse files Browse the repository at this point in the history
ENT-8269
  • Loading branch information
pwnage101 committed Jan 25, 2024
1 parent 31a4ccf commit 971eec7
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ private.py
# Editor Temp Files
*.swp

# pyenv
.python-version

*.trace

docs/_build/
Expand Down
18 changes: 15 additions & 3 deletions license_manager/apps/subscriptions/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
SubscriptionPlan,
SubscriptionPlanRenewal,
)
from license_manager.apps.subscriptions.tasks import (
PROVISION_LICENSES_BATCH_SIZE,
provision_licenses_task,
)


def get_related_object_link(admin_viewname, object_pk, object_str):
Expand Down Expand Up @@ -334,10 +338,18 @@ def save_model(self, request, obj, form, change):
customer_agreement_catalog = obj.customer_agreement.default_enterprise_catalog_uuid
obj.enterprise_catalog_uuid = (obj.enterprise_catalog_uuid or customer_agreement_catalog)

# Create licenses to be associated with the subscription plan after creating the subscription plan
num_new_licenses = form.cleaned_data.get('num_licenses', 0) - obj.num_licenses
# Set desired_num_licenses which will lead to the eventual creation of those licenses.
obj.desired_num_licenses = form.cleaned_data.get('num_licenses', 0)

super().save_model(request, obj, form, change)
SubscriptionPlan.increase_num_licenses(obj, num_new_licenses)

num_new_licenses = obj.desired_num_licenses - obj.num_licenses
if num_new_licenses <= PROVISION_LICENSES_BATCH_SIZE:
# We can handle just one batch synchronously.
SubscriptionPlan.increase_num_licenses(obj, num_new_licenses)
else:
# Multiple batches of licenses will need to be created, so provision them asynchronously.
provision_licenses_task.delay(subscription_plan_uuid=obj.uuid)


@admin.register(CustomerAgreement)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 4.2.9 on 2024-01-24 03:39

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('subscriptions', '0063_transfer_all_licenses'),
]

operations = [
migrations.AddField(
model_name='historicalsubscriptionplan',
name='desired_num_licenses',
field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'),
),
migrations.AddField(
model_name='subscriptionplan',
name='desired_num_licenses',
field=models.PositiveIntegerField(blank=True, help_text='Total number of licenses that should exist for this SubscriptionPlan. The total license count (provisioned asynchronously) will reach the desired amount eventually. Empty (NULL) means no attempts will be made to asynchronously provision licenses.', null=True, verbose_name='Desired Number of Licenses'),
),
]
11 changes: 11 additions & 0 deletions license_manager/apps/subscriptions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,17 @@ class SubscriptionPlan(TimeStampedModel):
)
)

desired_num_licenses = models.PositiveIntegerField(
blank=True,
null=True,
verbose_name="Desired Number of Licenses",
help_text=(
"Total number of licenses that should exist for this SubscriptionPlan. "
"The total license count (provisioned asynchronously) will reach the desired amount eventually. "
"Empty (NULL) means no attempts will be made to asynchronously provision licenses."
),
)

@property
def days_until_expiration(self):
"""
Expand Down
110 changes: 110 additions & 0 deletions license_manager/apps/subscriptions/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Celery tasks for the subscriptions app.
"""
import functools
import logging

from celery import shared_task
from celery_utils.logged_task import LoggedTask
from django.db import IntegrityError
from django.db.utils import OperationalError

from license_manager.apps.api.utils import (
acquire_subscription_plan_lock,
release_subscription_plan_lock,
)
from license_manager.apps.subscriptions.models import SubscriptionPlan
from license_manager.apps.subscriptions.utils import batch_counts


logger = logging.getLogger(__name__)

TASK_RETRY_SECONDS = 60
PROVISION_LICENSES_BATCH_SIZE = 300


class RequiredTaskUnreadyError(Exception):
"""
An exception representing a state where one type of task that is required
to be complete before another task is run is not in a ready state.
"""


def subscription_plan_semaphore():
"""
Celery Task decorator that wraps a bound (bind=True) task. If another wrapped task with the same given
"subscription_plan_uuid" kwarg value is still running, defer running this task (by retrying until all other tasks
are completed).
"""
def decorator(task):
@functools.wraps(task)
def wrapped_task(self, *args, **kwargs):
subscription_plan = SubscriptionPlan.objects.get(uuid=kwargs['subscription_plan_uuid'])
if not acquire_subscription_plan_lock(subscription_plan):
logger.info(
f'Deferring task {self.name} with id {self.request.id} '
f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
'since another task run with the same subscription plan has not yet completed.'
)
raise self.retry(exc=RequiredTaskUnreadyError())
# Try to run the task. If it fails, release the lock and bubble up the exception.
try:
task_results = task(self, *args, **kwargs)
finally:
release_subscription_plan_lock(subscription_plan)
return task_results
return wrapped_task
return decorator


class LoggedTaskWithRetry(LoggedTask): # pylint: disable=abstract-method
"""
Shared base task that allows tasks that raise some common exceptions to retry automatically.
See https://docs.celeryproject.org/en/stable/userguide/tasks.html#automatic-retry-for-known-exceptions for
more documentation.
"""
autoretry_for = (
IntegrityError,
OperationalError,
)
retry_kwargs = {'max_retries': 5}
# Use exponential backoff for retrying tasks
retry_backoff = True
# Add randomness to backoff delays to prevent all tasks in queue from executing simultaneously
retry_jitter = True


@shared_task(base=LoggedTaskWithRetry, bind=True, default_retry_delay=TASK_RETRY_SECONDS)
@subscription_plan_semaphore()
def provision_licenses_task(self, subscription_plan_uuid=None): # pylint: disable=unused-argument
"""
For a given subscription plan, try to make its count of licenses match the number defined by the
`desired_num_licenses` field of that subscription plan. Never decrease the count of licenses; if there are already
more licenses than `desired_num_licenses`, do nothing.
Args:
subscription_plan_uuid (str): UUID of the SubscriptionPlan object to provision licenses for.
"""
subscription_plan = SubscriptionPlan.objects.get(uuid=subscription_plan_uuid)
if not subscription_plan.desired_num_licenses:
logger.info(
f'Skipping task {self.name} with id {self.request.id} '
f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
f'because desired_num_licenses is not set on this subscription plan.'
)
return
license_count_gap = subscription_plan.desired_num_licenses - subscription_plan.num_licenses
if license_count_gap <= 0:
logger.info(
f'Skipping task {self.name} with id {self.request.id} '
f'and args: {self.request.args}, kwargs: {self.request.kwargs}, '
f'because the actual license count ({subscription_plan.num_licenses}) '
f'already meets or exceeds the desired license count ({subscription_plan.desired_num_licenses}).'
)
return

# There's work to do, creating licenses! It should be safe to not re-check the license count between batches
# because we lock this subscription plan anyway (via @subscription_plan_semaphore decorator).
for batch_count in batch_counts(license_count_gap, batch_size=PROVISION_LICENSES_BATCH_SIZE):
subscription_plan.increase_num_licenses(batch_count)
121 changes: 121 additions & 0 deletions license_manager/apps/subscriptions/tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Tests for subscriptions app celery tasks
"""
from datetime import datetime, timedelta
from unittest import mock
from uuid import uuid4

import ddt
import freezegun
import pytest
from braze.exceptions import BrazeClientError
from django.conf import settings
from django.test import TestCase
from django.test.utils import override_settings
from freezegun import freeze_time
from requests import models

from license_manager.apps.api.utils import (
acquire_subscription_plan_lock,
release_subscription_plan_lock,
)
from license_manager.apps.subscriptions import tasks
from license_manager.apps.subscriptions.models import License, SubscriptionPlan
from license_manager.apps.subscriptions.tests.factories import (
LicenseFactory,
SubscriptionPlanFactory,
)


# pylint: disable=unused-argument
@ddt.ddt
class ProvisionLicensesTaskTests(TestCase):
"""
Tests for provision_licenses_task.
"""
def setUp(self):
super().setUp()
self.subscription_plan = SubscriptionPlanFactory()

def tearDown(self):
super().tearDown()
release_subscription_plan_lock(self.subscription_plan)

# For all cases below, assume batch size of 5.
@ddt.data(
# Don't add licenses if none are desired.
{
'num_initial_licenses': 0,
'desired_num_licenses': 0,
'expected_num_licenses': 0,
},
# Create fewer licenses than one batch.
{
'num_initial_licenses': 0,
'desired_num_licenses': 1,
'expected_num_licenses': 1,
},
# Create licenses that span multiple batches.
{
'num_initial_licenses': 0,
'desired_num_licenses': 6,
'expected_num_licenses': 6,
},
# Don't add more licenses if the goal has already been reached.
{
'num_initial_licenses': 10,
'desired_num_licenses': 10,
'expected_num_licenses': 10,
},
# Create fewer licenses than one batch (starting with 10 initially).
{
'num_initial_licenses': 10,
'desired_num_licenses': 11,
'expected_num_licenses': 11,
},
# Create licenses that span multiple batches (starting with 10 initially).
{
'num_initial_licenses': 10,
'desired_num_licenses': 16,
'expected_num_licenses': 16,
},
# Don't remove licenses if the desired number of licenses is smaller than count of existing licenses.
{
'num_initial_licenses': 20,
'desired_num_licenses': 10,
'expected_num_licenses': 20,
},
# Don't add or remove licenses if the desired number of licenses is None.
{
'num_initial_licenses': 20,
'desired_num_licenses': None,
'expected_num_licenses': 20,
},
)
@ddt.unpack
@mock.patch('license_manager.apps.subscriptions.tasks.PROVISION_LICENSES_BATCH_SIZE', 5)
def test_provision_licenses_task(self, num_initial_licenses, desired_num_licenses, expected_num_licenses):
"""
Test provision_licenses_task.
"""
self.subscription_plan.desired_num_licenses = desired_num_licenses
self.subscription_plan.save()
self.subscription_plan.increase_num_licenses(num_initial_licenses)

tasks.provision_licenses_task(subscription_plan_uuid=self.subscription_plan.uuid)

assert self.subscription_plan.num_licenses == expected_num_licenses

def test_provision_licenses_task_locked(self):
"""
Test provision_licenses_task throws an exception if the subscription is locked.
"""
self.subscription_plan.desired_num_licenses = 5
self.subscription_plan.save()

acquire_subscription_plan_lock(self.subscription_plan)

with self.assertRaises(tasks.RequiredTaskUnreadyError):
tasks.provision_licenses_task(subscription_plan_uuid=self.subscription_plan.uuid)

assert self.subscription_plan.num_licenses == 0
52 changes: 51 additions & 1 deletion license_manager/apps/subscriptions/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import hashlib
import hmac
import uuid
from unittest import mock
from unittest import TestCase, mock

import ddt

from license_manager.apps.subscriptions import utils

Expand All @@ -28,3 +30,51 @@ def test_get_subsidy_checksum():
expected_checksum,
utils.get_subsidy_checksum(lms_user_id, course_key, license_uuid),
)


@ddt.ddt
class TestBatchCounts(TestCase):
"""
Tests for batch_counts().
"""

@ddt.data(
{
'total_count': 0,
'batch_size': 5,
'expected_batch_counts': [],
},
{
'total_count': 4,
'batch_size': 5,
'expected_batch_counts': [4],
},
{
'total_count': 5,
'batch_size': 5,
'expected_batch_counts': [5],
},
{
'total_count': 6,
'batch_size': 5,
'expected_batch_counts': [5, 1],
},
{
'total_count': 23,
'batch_size': 5,
'expected_batch_counts': [5, 5, 5, 5, 3],
},
# Just make sure something weird doesn't happen when the batch size is 1.
{
'total_count': 5,
'batch_size': 1,
'expected_batch_counts': [1, 1, 1, 1, 1],
},
)
@ddt.unpack
def test_batch_counts(self, total_count, batch_size, expected_batch_counts):
"""
Test batch_counts().
"""
actual_batch_counts = list(utils.batch_counts(total_count, batch_size=batch_size))
assert actual_batch_counts == expected_batch_counts
Loading

0 comments on commit 971eec7

Please sign in to comment.