From 4959503fe762d0f42804b2df95eba8086b0c8c2e Mon Sep 17 00:00:00 2001 From: Alex Dusenbery Date: Wed, 31 Jan 2024 11:10:53 -0500 Subject: [PATCH 1/2] feat: lazy license pagination for large plans, sometimes --- license_manager/apps/api/pagination.py | 72 +++++++++++++++++++ .../apps/api/v1/tests/test_views.py | 29 ++++++++ license_manager/apps/api/v1/views.py | 59 +++++++++------ 3 files changed, 139 insertions(+), 21 deletions(-) create mode 100644 license_manager/apps/api/pagination.py diff --git a/license_manager/apps/api/pagination.py b/license_manager/apps/api/pagination.py new file mode 100644 index 00000000..2b5dd79d --- /dev/null +++ b/license_manager/apps/api/pagination.py @@ -0,0 +1,72 @@ +""" +Defines custom paginators used by subscription viewsets. +""" +from django.core.paginator import Paginator as DjangoPaginator +from django.utils.functional import cached_property +from rest_framework.pagination import PageNumberPagination + + +class PageNumberPaginationWithCount(PageNumberPagination): + """ + A PageNumber paginator that adds the total number of pages to the paginated response. + """ + + def get_paginated_response(self, data): + """ Adds a ``num_pages`` field into the paginated response. """ + response = super().get_paginated_response(data) + response.data['num_pages'] = self.page.paginator.num_pages + return response + + +class LicensePagination(PageNumberPaginationWithCount): + """ + A PageNumber paginator that allows the client to specify the page size, up to some maximum. + """ + page_size_query_param = 'page_size' + max_page_size = 500 + + +class EstimatedCountDjangoPaginator(DjangoPaginator): + """ + A lazy paginator that determines it's count from + the upstream `estimated_count` + """ + def __init__(self, *args, estimated_count=None, **kwargs): + self.estimated_count = estimated_count + super().__init__(*args, **kwargs) + + @cached_property + def count(self): + if self.estimated_count is None: + return super().count + return self.estimated_count + + +class EstimatedCountLicensePagination(LicensePagination): + """ + Allows the caller (probably the `paginator()` property + of an upstream Viewset) to provided an `estimated_count`, + which means the downstream django paginator does *not* + perform an additional query to get the count of the queryset. + """ + def __init__(self, *args, estimated_count=None, **kwargs): + """ + Optionally stores an `estimated_count` to pass along + to `EstimatedCountDjangoPaginator`. + """ + self.estimated_count = estimated_count + super().__init__(*args, **kwargs) + + def django_paginator_class(self, queryset, page_size): + """ + This only works because the implementation of `paginate_queryset` + treats `self.django_paginator_class` as if it is simply a callable, + and not necessarily a class, that returns a Django Paginator instance. + + It also (safely) relies on `self` having an instance variable called `estimated_count`. + """ + if self.estimated_count is not None: + return EstimatedCountDjangoPaginator( + queryset, page_size, estimated_count=self.estimated_count, + ) + return DjangoPaginator(queryset, page_size) diff --git a/license_manager/apps/api/v1/tests/test_views.py b/license_manager/apps/api/v1/tests/test_views.py index d1150d0d..9c7efc8a 100644 --- a/license_manager/apps/api/v1/tests/test_views.py +++ b/license_manager/apps/api/v1/tests/test_views.py @@ -36,6 +36,9 @@ LEARNER_ROLES, SUBSCRIPTION_RENEWAL_DAYS_OFFSET, ) +from license_manager.apps.api.v1.views import ( + ESTIMATED_COUNT_PAGINATOR_THRESHOLD, +) from license_manager.apps.core.models import User from license_manager.apps.subscriptions import constants from license_manager.apps.subscriptions.exceptions import LicenseRevocationError @@ -788,6 +791,32 @@ def test_license_list_staff_user_200_custom_page_size(api_client, staff_user): assert response.data['next'] is not None +@pytest.mark.django_db +def test_license_list_staff_user_200_estimated_license_count(api_client, staff_user): + subscription, _, _, _, _ = _subscription_and_licenses() + _assign_role_via_jwt_or_db( + api_client, + staff_user, + subscription.enterprise_customer_uuid, + True, + ) + subscription.desired_num_licenses = ESTIMATED_COUNT_PAGINATOR_THRESHOLD + 1 + subscription.save() + + response = _licenses_list_request( + api_client, subscription.uuid, page_size=1, + status=','.join([constants.UNASSIGNED, constants.ASSIGNED, constants.ACTIVATED]), + ) + + assert status.HTTP_200_OK == response.status_code + results_by_uuid = {item['uuid']: item for item in response.data['results']} + # We test for content in the test above, + # we're only worried about the response count matching `desired_num_licenses` here. + assert len(results_by_uuid) == 1 + assert response.data['count'] == ESTIMATED_COUNT_PAGINATOR_THRESHOLD + 1 + assert response.data['next'] is not None + + @pytest.mark.django_db def test_license_list_ignore_null_emails_query_param(api_client, staff_user, boolean_toggle): """ diff --git a/license_manager/apps/api/v1/views.py b/license_manager/apps/api/v1/views.py index 8c0b8940..f10ec41b 100644 --- a/license_manager/apps/api/v1/views.py +++ b/license_manager/apps/api/v1/views.py @@ -22,7 +22,6 @@ from rest_framework.decorators import action from rest_framework.exceptions import ParseError from rest_framework.mixins import ListModelMixin -from rest_framework.pagination import PageNumberPagination from rest_framework.response import Response from rest_framework.views import APIView from rest_framework_csv.renderers import CSVRenderer @@ -68,11 +67,15 @@ localized_utcnow, ) +from ..pagination import EstimatedCountLicensePagination, LicensePagination + logger = logging.getLogger(__name__) ASSIGNMENT_LOCK_TIMEOUT_SECONDS = 300 +ESTIMATED_COUNT_PAGINATOR_THRESHOLD = 10000 + class CustomerAgreementViewSet( PermissionRequiredForListingMixin, @@ -485,26 +488,6 @@ def base_queryset(self): return licenses -class PageNumberPaginationWithCount(PageNumberPagination): - """ - A PageNumber paginator that adds the total number of pages to the paginated response. - """ - - def get_paginated_response(self, data): - """ Adds a ``num_pages`` field into the paginated response. """ - response = super().get_paginated_response(data) - response.data['num_pages'] = self.page.paginator.num_pages - return response - - -class LicensePagination(PageNumberPaginationWithCount): - """ - A PageNumber paginator that allows the client to specify the page size, up to some maximum. - """ - page_size_query_param = 'page_size' - max_page_size = 500 - - class BaseLicenseViewSet(PermissionRequiredForListingMixin, viewsets.ReadOnlyModelViewSet): """ Base Viewset for read operations on individual licenses in a given subscription plan. @@ -609,6 +592,40 @@ class LicenseAdminViewSet(BaseLicenseViewSet): pagination_class = LicensePagination + @property + def paginator(self): + # pylint: disable=line-too-long + """ + If the caller has requested all usable licenses, we want to fall + back to grabbing the paginator's count from ``SubscriptionPlan.desired_num_licenses`` + for large plans, as determining the count dynamically is an expensive query. + + This is the only way to dynamically select a pagination class in DRF. + https://github.com/encode/django-rest-framework/issues/6397 + + Underlying implementation of the paginator() property: + https://github.com/encode/django-rest-framework/blob/7749e4e3bed56e0f3e7775b0b1158d300964f6c0/rest_framework/generics.py#L156 + """ + if hasattr(self, '_paginator'): + return self._paginator + + # If we don't have a subscription plan, or the requested + # status values aren't for all usable licenses, fall back to + # the normal LicensePagination class + self._paginator = super().paginator # pylint: disable=attribute-defined-outside-init + + usable_license_states = {constants.UNASSIGNED, constants.ASSIGNED, constants.ACTIVATED} + if value := self.request.query_params.get('status'): + status_values = value.strip().split(',') + if set(status_values) == usable_license_states: + if subscription_plan := self._get_subscription_plan(): + estimated_count = subscription_plan.desired_num_licenses + if estimated_count is not None and estimated_count > ESTIMATED_COUNT_PAGINATOR_THRESHOLD: + # pylint: disable=attribute-defined-outside-init + self._paginator = EstimatedCountLicensePagination(estimated_count=estimated_count) + + return self._paginator + @property def active_only(self): return int(self.request.query_params.get('active_only', 0)) From 79e3bfd5d0f744de8fe9134083aee469cf0e48dd Mon Sep 17 00:00:00 2001 From: Alex Dusenbery Date: Fri, 19 Jan 2024 10:29:05 -0500 Subject: [PATCH 2/2] feat: script for local, large-scale license assignment ENT-8270 | Adds a script to read an input file of email addresses and make synchronous calls to the license-assignment endpoint to assign licenses for those email addresses within a particular plan. fetch jwt from client id/secret before assignment request --- .gitignore | 2 - scripts/local_assignment.py | 247 ++++++++++++++++++++++ scripts/local_assignment_requirements.txt | 2 + 3 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 scripts/local_assignment.py create mode 100644 scripts/local_assignment_requirements.txt diff --git a/.gitignore b/.gitignore index 82a71bb8..fd19169e 100644 --- a/.gitignore +++ b/.gitignore @@ -92,8 +92,6 @@ private.py docs/_build/ -scripts/ - .vscode/ .dev/ diff --git a/scripts/local_assignment.py b/scripts/local_assignment.py new file mode 100644 index 00000000..6f1f3495 --- /dev/null +++ b/scripts/local_assignment.py @@ -0,0 +1,247 @@ +""" +Script designed for local execution. +Reads a CSV file of email addresses and target subscription plan uuid +as input, then chunks those up in calls to the ``assign`` view. + +To use: +``` +# os environ names are meaningful and should correspond to the requested environment +# this allows us to fetch a JWT before each request, so you don't have to +# worry about your JWT expiring in the middle of the script execution. +export CLIENT_SECRET_LOCAL=[your-client-secret] +export CLIENT_ID_LOCAL=[your-client-id] + +pip install -r scripts/local_assignment_requirements.txt + +python local_assignment.py \ + --input-file=your-input-file.csv \ + --subscription-plan-uuid=[the-plan-uuid] \ + --output-file=local-assignment-output.csv \ + --chunk-size=10 \ + --environment=local \ + --sleep-interval=5 \ + --fetch-jwt +``` + +Options: +* ``input-file`` is your input file - it should be a single-column csv +(or just a list delimited by newlines, really) of valid email addresses. This +script does not attempt to do any validation. Required. + +* ``subscription-plan-uuid`` is the uuid of the plan to assign license to. Required. + +* ``output-file`` is where results of the call to the assignment view are stored. +It'll be a CSV with three columns: the chunk id, email address, and assigned license uuid. + +* ``chunk-size`` is how many emails will be contained in each chunk. Default is 100. + +* ``environment`` Which environment to execute against. Choices are 'local', 'stage', or 'prod'. + +* ``sleep-interval`` is useful for not overwhelming the license-manager celery broker. +The assignment endpoints causes several different asychronous tasks to be submitted +downstream of successful assignment. +""" +import csv +import json +import os +import time + +import click +import requests + + +DEFAULT_CHUNK_SIZE = 100 + +DEFAULT_SLEEP_INTERVAL = 0.5 + +ENVIRONMENTS = { + 'local': 'http://localhost:18170/api/v1/subscriptions/{subscription_plan_uuid}/licenses/assign/', + 'stage': 'https://license-manager.stage.edx.org/api/v1/subscriptions/{subscription_plan_uuid}/licenses/assign/', + 'prod': 'https://license-manager.edx.org/api/v1/subscriptions/{subscription_plan_uuid}/licenses/assign/', +} + +ACCESS_TOKEN_URL_BY_ENVIRONMENT = { + 'local': 'http://localhost:18000/oauth2/access_token/', + 'stage': 'https://courses.stage.edx.org/oauth2/access_token/', + 'prod': 'https://courses.edx.org/oauth2/access_token/', +} + +def _get_jwt(fetch_jwt=False, environment='local'): + if fetch_jwt: + client_id = os.environ.get(f'CLIENT_ID_{environment}'.upper()) + client_secret = os.environ.get(f'CLIENT_SECRET_{environment}'.upper()) + assert client_id and client_secret, 'client_id and client_secret must be set if fetch_jwt is true' + request_payload = { + 'client_id': client_id, + 'client_secret': client_secret, + 'grant_type': 'client_credentials', + 'token_type': 'jwt', + } + # we want to sent with a Content-Type of 'application/x-www-form-urlencoded' + # so send in the `data` param instead of `json`. + response = requests.post( + ACCESS_TOKEN_URL_BY_ENVIRONMENT.get(environment), + data=request_payload, + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + ) + response.raise_for_status() + return response.json().get('access_token') + else: + return os.environ.get('LICENSE_MANAGER_JWT') + + +def get_already_processed_emails(results_file): + """ + Reads a headerless CSV with rows representing `chunk_id,email,assigned_license_uuid` + and returns a dictionary mapping already processed emails to their chunk_id. + """ + already_processed_emails = {} + with open(results_file, 'r') as f_in: + reader = csv.reader(f_in, delimiter=',') + for (chunk_id, email, license_uuid) in reader: + already_processed_emails[email] = chunk_id + return already_processed_emails + + +def get_email_chunks(input_file_path, chunk_size=DEFAULT_CHUNK_SIZE): + """ + Yield chunks of email addresses from the given input file. Given the same input file and chunk_size, + this will always yield rows with the same chunk id for each provided email. + """ + current_chunk = [] + chunk_id = 0 + with open(input_file_path, 'r') as f_in: + reader = csv.reader(f_in, delimiter=',') + for row in reader: + email = row[0] + current_chunk.append(email) + if len(current_chunk) == chunk_size: + yield chunk_id, current_chunk + current_chunk = [] + chunk_id += 1 + + if current_chunk: + yield chunk_id, current_chunk + + +def request_assignments(subscription_plan_uuid, chunk_id, emails_for_chunk, environment='local', fetch_jwt=False): + """ + Makes the request to the ``assign`` endpoint for the given subscription plan + to assign liceses for `emails_for_chunk`. + """ + print('\nSending assignment request for chunk id', chunk_id, 'with num emails', len(emails_for_chunk)) + + url_pattern = ENVIRONMENTS[environment] + url = url_pattern.format(subscription_plan_uuid=subscription_plan_uuid) + + payload = { + 'user_emails': emails_for_chunk, + 'notify_users': False, + } + headers = { + "Authorization": "JWT {}".format(_get_jwt(fetch_jwt, environment=environment)), + } + + response = requests.post(url, json=payload, headers=headers) + + response.raise_for_status() + response_data = response.json() + + results_for_chunk = [] + for assignment in response_data['license_assignments']: + results_for_chunk.append([str(chunk_id), assignment['user_email'], str(assignment['license'])]) + + print('Num assigned by assignment API:', response_data['num_successful_assignments']) + print('Num already associated from assignment API:', response_data['num_already_associated']) + print('Successfully sent assignment request for chunk id', chunk_id, 'with num emails', len(results_for_chunk)) + + return results_for_chunk + + +def do_assignment_for_chunk( + subscription_plan_uuid, chunk_id, email_chunk, + results_file, environment='local', fetch_jwt=False, sleep_interval=DEFAULT_SLEEP_INTERVAL +): + """ + Given a "chunk" list emails for which assignments should be requested, checks if the given + email has already been processed for the given email. If not, adds it to a list for this + chunk to be requested, then requests license assignment in the given subscription plan. + On successful request, appends results including chunk id, email, and license uuid + to results_file. + """ + already_processed = {} + if results_file: + already_processed = get_already_processed_emails(results_file) + + payload_for_chunk = [] + for email in email_chunk: + if email in already_processed: + continue + payload_for_chunk.append(email) + + results_for_chunk = [] + if payload_for_chunk: + results_for_chunk = request_assignments( + subscription_plan_uuid, chunk_id, payload_for_chunk, environment, fetch_jwt, + ) + with open(results_file, 'a') as f_out: + writer = csv.writer(f_out, delimiter=',') + writer.writerows(results_for_chunk) + if sleep_interval: + print(f'Sleeping for {sleep_interval} seconds.') + time.sleep(sleep_interval) + else: + print('No assignments need to be made for chunk_id', chunk_id, 'with size', len(email_chunk)) + + +@click.command() +@click.option( + '--input-file', + help='Path of local file containing email addresses to assign.', +) +@click.option( + '--subscription-plan-uuid', + help='Subscription plan to which licenses should be assigned.', +) +@click.option( + '--output-file', + default=None, + help='CSV file of emails that we have processed.', +) +@click.option( + '--chunk-size', + help='Size of email chunks to operate on.', + default=DEFAULT_CHUNK_SIZE, + show_default=True, +) +@click.option( + '--environment', + help='Which environment to operate in.', + default='local', + type=click.Choice(['local', 'stage', 'prod'], case_sensitive=False), + show_default=True, +) +@click.option( + '--sleep-interval', + help='How long, in seconds, to sleep between each chunk.', + default=DEFAULT_SLEEP_INTERVAL, + show_default=True, +) +@click.option( + '--fetch-jwt', + help='Whether to fetch JWT based on stored client id and secret.', + is_flag=True, +) + +def run(input_file, subscription_plan_uuid, output_file, chunk_size, environment, sleep_interval, fetch_jwt): + """ + Entry-point for this script. + """ + for chunk_id, email_chunk in get_email_chunks(input_file, chunk_size): + do_assignment_for_chunk( + subscription_plan_uuid, chunk_id, email_chunk, + output_file, environment, fetch_jwt, sleep_interval, + ) + +if __name__ == '__main__': + run() diff --git a/scripts/local_assignment_requirements.txt b/scripts/local_assignment_requirements.txt new file mode 100644 index 00000000..0d8c96eb --- /dev/null +++ b/scripts/local_assignment_requirements.txt @@ -0,0 +1,2 @@ +click +requests