Skip to content

Commit

Permalink
Merge branch 'master' of github.com:openedx/license-manager into ENT-…
Browse files Browse the repository at this point in the history
…8306/make-var-field-optional
  • Loading branch information
hamzawaleed01 committed Feb 14, 2024
2 parents 8a27968 + 2cc9228 commit c7128ed
Show file tree
Hide file tree
Showing 12 changed files with 942 additions and 11 deletions.
15 changes: 13 additions & 2 deletions license_manager/apps/api/mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from functools import cached_property

from rest_framework.exceptions import ParseError

from license_manager.apps.api import utils
from license_manager.apps.api_client.lms import LMSApiClient


class UserDetailsFromJwtMixin:
Expand All @@ -18,9 +21,17 @@ def decoded_jwt(self):

return utils.get_decoded_jwt(self.request)

@property
@cached_property
def lms_user_id(self):
return utils.get_key_from_jwt(self.decoded_jwt, 'user_id')
"""
Retrieve the LMS user ID.
"""
try:
return utils.get_key_from_jwt(self.decoded_jwt, 'user_id')
except ParseError:
lms_client = LMSApiClient()
user_id = lms_client.fetch_lms_user_id(self.request.user.email)
return user_id

@property
def user_email(self):
Expand Down
2 changes: 2 additions & 0 deletions license_manager/apps/api/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.db import IntegrityError, transaction
from django.db.utils import OperationalError
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import HTTPError
from requests.exceptions import JSONDecodeError as RequestsJSONDecodeError
from requests.exceptions import Timeout as RequestsTimeoutError

Expand Down Expand Up @@ -70,6 +71,7 @@ class LoggedTaskWithRetry(LoggedTask): # pylint: disable=abstract-method
IntegrityError,
OperationalError,
BrazeClientError,
HTTPError,
)
retry_kwargs = {'max_retries': 3}
# Use exponential backoff for retrying tasks
Expand Down
27 changes: 22 additions & 5 deletions license_manager/apps/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import urllib
import uuid
from collections import defaultdict

import boto3
from django.http import Http404
Expand Down Expand Up @@ -151,12 +152,22 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub

logger.info('[check_missing_licenses] Starting to iterate over all `user_emails`...')

# Map licenses by email across all user_emails in a single DB query.
# Also, join the plans into the queryset, so that we don't do
# one query per license down in the loops below.
license_queryset = License.objects.filter(
subscription_plan__in=subscription_plan_filter,
user_email__in=user_emails,
).select_related(
'subscription_plan',
)
licenses_by_email = defaultdict(list)
for license_record in license_queryset:
licenses_by_email[license_record.user_email].append(license_record)

for email in set(user_emails):
logger.info(f'[check_missing_licenses] handling user email {email}')
filtered_licenses = License.objects.filter(
subscription_plan__in=subscription_plan_filter,
user_email=email,
)
filtered_licenses = licenses_by_email.get(email, [])

logger.info('[check_missing_licenses] user licenses for email %s: %s', email, filtered_licenses)

Expand All @@ -173,11 +184,17 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub
logger.info('[check_missing_licenses] handling user license %s', str(user_license.uuid))
subscription_plan = user_license.subscription_plan
plan_key = f'{subscription_plan.uuid}_{course_key}'

# TODO AED 2024-02-09: I think this chunk of code is defective.
# It's only mapping plan ids to booleans, but what we really want
# to know is, for each plan *and course*, if the plan's associated catalog
# contains the course.
if plan_key in subscription_plan_course_map:
plan_contains_content = subscription_plan_course_map.get(plan_key)
else:
plan_contains_content = subscription_plan.contains_content([course_key])
subscription_plan_course_map[plan_key] = plan_contains_content

logger.info(
'[check_missing_licenses] does plan (%s) contain content?: %s',
str(subscription_plan.uuid),
Expand All @@ -189,7 +206,7 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub
'course_run_key': course_key,
'license_uuid': str(user_license.uuid)
}
# assigned, not yet activated, incliude activation URL
# assigned, not yet activated, include activation URL
if user_license.status == constants.ASSIGNED:
this_enrollment['activation_link'] = get_license_activation_link(
enterprise_slug,
Expand Down
14 changes: 10 additions & 4 deletions license_manager/apps/api/v1/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3227,17 +3227,23 @@ def test_get_subsidy_missing_course_key(self):
assert response.status_code == status.HTTP_400_BAD_REQUEST

@mock.patch('license_manager.apps.api.v1.views.utils.get_decoded_jwt')
def test_get_subsidy_no_jwt(self, mock_get_decoded_jwt):
@mock.patch('license_manager.apps.api.mixins.LMSApiClient')
def test_get_subsidy_no_jwt(self, MockLMSApiClient, mock_get_decoded_jwt):
"""
Verify the view returns a 400 if the user_id could not be found in the JWT.
Verify the view makes an API call to fetch lmsUserId if user_id could not be found in the JWT.
"""
self._assign_learner_roles()
mock_get_decoded_jwt.return_value = {}
url = self._get_url_with_params()

# Mock the behavior of LMSApiClient to return a sample user ID
mock_lms_client = MockLMSApiClient.return_value
mock_lms_client.fetch_lms_user_id.return_value = 443
response = self.api_client.get(url)
assert status.HTTP_404_NOT_FOUND == response.status_code

assert status.HTTP_400_BAD_REQUEST == response.status_code
assert '`user_id` is required and could not be found in your jwt' in str(response.content)
# Assert that the LMSApiClient.fetch_lms_user_id method was called once with the correct argument
mock_lms_client.fetch_lms_user_id.assert_called_once_with(self.user.email)

def test_get_subsidy_no_subscription_for_enterprise_customer(self):
"""
Expand Down
43 changes: 43 additions & 0 deletions license_manager/apps/api_client/lms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging

import requests
from django.conf import settings

from license_manager.apps.api_client.base_oauth import BaseOAuthClient


logger = logging.getLogger(__name__)


class LMSApiClient(BaseOAuthClient):
"""
API client for calls to the LMS.
"""
api_base_url = settings.LMS_URL
user_details_endpoint = api_base_url + '/api/user/v1/accounts'

def fetch_lms_user_id(self, email):
"""
Fetch user details for the specified user email.
Arguments:
email (str): Email of the user for which we want to fetch details for.
Returns:
str: lms_user_id of the user.
"""
# {base_api_url}/api/user/v1/[email protected]
try:
query_params = {'email': email}
response = self.client.get(self.user_details_endpoint, params=query_params)
response.raise_for_status()
response_json = response.json()
return response_json[0].get('id')
except requests.exceptions.HTTPError as exc:
logger.error(
'Failed to fetch user details for user {email} because {reason}'.format(
email=email,
reason=str(exc),
)
)
raise exc
14 changes: 14 additions & 0 deletions license_manager/apps/subscriptions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from uuid import uuid4

from django.conf import settings
from django.core.cache import cache
from django.core.serializers.json import DjangoJSONEncoder
from django.core.validators import MinLengthValidator
from django.db import models, transaction
Expand Down Expand Up @@ -62,6 +63,10 @@

logger = getLogger(__name__)

CONTAINS_CONTENT_CACHE_TIMEOUT = 60 * 60

_CACHE_MISS = object()


class CustomerAgreement(TimeStampedModel):
"""
Expand Down Expand Up @@ -736,13 +741,22 @@ def contains_content(self, content_ids):
Returns:
bool: Whether the given content_ids are part of the subscription.
"""
cache_key = self.get_contains_content_cache_key(content_ids)
cached_value = cache.get(cache_key, _CACHE_MISS)
if cached_value is not _CACHE_MISS:
return cached_value

enterprise_catalog_client = EnterpriseCatalogApiClient()
content_in_catalog = enterprise_catalog_client.contains_content_items(
self.enterprise_catalog_uuid,
content_ids,
)
cache.set(cache_key, content_in_catalog, timeout=CONTAINS_CONTENT_CACHE_TIMEOUT)
return content_in_catalog

def get_contains_content_cache_key(self, content_ids):
return f'plan_contains_content:{self.uuid}:{content_ids}'

history = HistoricalRecords()

class Meta:
Expand Down
9 changes: 9 additions & 0 deletions license_manager/apps/subscriptions/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ddt
import freezegun
import pytest
from django.core.cache import cache
from django.forms import ValidationError
from django.test import TestCase
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -55,7 +56,15 @@ def test_contains_content(self, contains_content, mock_enterprise_catalog_client
# Mock the value from the enterprise catalog client
mock_enterprise_catalog_client().contains_content_items.return_value = contains_content
content_ids = ['test-key', 'another-key']

cache.delete(self.subscription_plan.get_contains_content_cache_key(content_ids))

assert self.subscription_plan.contains_content(content_ids) == contains_content

# call it again to utilize the cache
assert self.subscription_plan.contains_content(content_ids) == contains_content

# ...but assert we only used the catalog client once
mock_enterprise_catalog_client().contains_content_items.assert_called_with(
self.subscription_plan.enterprise_catalog_uuid,
content_ids,
Expand Down
72 changes: 72 additions & 0 deletions scripts/assignment_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Script to help validate input file before
consumption by ``local_assignment_multi.py``
To use:
```
pip install -r scripts/local_assignment_requirements.txt
python assignment_validation.py print_duplicates --input-file=your-input-file.csv
# or
python assignment_validation.py print_plan_counts --input-file=your-input-file.csv
"""
import csv
from collections import defaultdict, Counter

import click

INPUT_FIELDNAMES = ['email', 'university_name']


def _iterate_csv(input_file):
with open(input_file, 'r') as f_in:
reader = csv.DictReader(f_in, fieldnames=INPUT_FIELDNAMES, delimiter=',')
# read and skip the header
next(reader, None)
breakpoint()
for row in reader:
yield row


@click.command()
@click.option(
'--input-file',
help='Path of local file containing email addresses to assign.',
)
def print_duplicates(input_file):
unis_by_email = defaultdict(list)
for row in _iterate_csv(input_file):
unis_by_email[row['email']].append(row['university_name'])

for email, uni_list in unis_by_email.items():
if len(uni_list) > 1:
print(email, uni_list)


@click.command()
@click.option(
'--input-file',
help='Path of local file containing email addresses to assign.',
)
def print_plan_counts(input_file):
counts_by_plan = Counter()
for row in _iterate_csv(input_file):
counts_by_plan[row['university_name']] += 1

for plan, count in counts_by_plan.items():
print(plan, count)


@click.group()
def run():
pass


run.add_command(print_duplicates)
run.add_command(print_plan_counts)


if __name__ == '__main__':
run()
76 changes: 76 additions & 0 deletions scripts/generate_csvs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Helper to generate assignment input
CSVs of fake email data.
"""
import csv
import math

import click

DEFAULT_EMAIL_TEMPLATE = 'testuser+{}@example.com'


def generate_multi_plan_input(
subscription_plan_identifiers, number_in_plan,
email_template, filename, subscription_plan_fieldname='university_name',
):
total = len(subscription_plan_identifiers) * sum(number_in_plan)
order_mag = math.ceil(math.log(total, 10))

with open(filename, 'w') as file_out:
fieldnames = ['email', subscription_plan_fieldname]
writer = csv.DictWriter(file_out, fieldnames)
writer.writeheader()

# This offset helps us generate emails
# that are unique across all sub plan identifiers that we iterate.
offset = 0
for plan_id, num_emails_for_plan in zip(subscription_plan_identifiers, number_in_plan):
for index in range(offset, offset + num_emails_for_plan):
email = email_template.format(
str(index).zfill(order_mag)
)
writer.writerow({'email': email, subscription_plan_fieldname: plan_id})
offset = index + 1


@click.command
@click.option(
'--subscription-plan-identifier', '-s',
multiple=True,
help='One or more subscription plan identifier, comma-separated. Could be a uuid or an external name.',
)
@click.option(
'--subscription-plan-fieldname', '-n',
help='Name of output field corresponding to subscription plans',
default='university_name',
show_default=True,
)
@click.option(
'--number-in-plan', '-n',
multiple=True,
help='One or more: Number of emails to generate in each plan.',
show_default=True,
)
@click.option(
'--email-template',
default=DEFAULT_EMAIL_TEMPLATE,
help='Optional python string template to use for email address generation, must take exactly one argument',
)
@click.option(
'--filename',
help='Where to write the generated file.',
)
def run(
subscription_plan_identifier, subscription_plan_fieldname, number_in_plan,
email_template, filename,
):
number_in_plan = [int(s) for s in number_in_plan]
generate_multi_plan_input(
subscription_plan_identifier, number_in_plan, email_template,
filename, subscription_plan_fieldname,
)


if __name__ == '__main__':
run()
Loading

0 comments on commit c7128ed

Please sign in to comment.