diff --git a/conftest.py b/conftest.py index 8e42ffd..13f674c 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,11 @@ +import base64 + import pytest +from oauth2_provider.models import Application +from rest_framework.test import APIClient from users.factories import UserFactory, FCMDeviceFactory - +from messaging.factories import ServerFactory @pytest.fixture def user(db): @@ -11,3 +15,42 @@ def user(db): @pytest.fixture def fcm_device(user): return FCMDeviceFactory(user=user) + + +@pytest.fixture +def api_client(): + return APIClient() + + +@pytest.fixture +def auth_device(user, api_client): + """ + Create the Basic Authentication credentials for the test user. + """ + credentials = f"{user.username}:testpass".encode("utf-8") + base64_credentials = base64.b64encode(credentials).decode("utf-8") + cred = f"Basic {base64_credentials}" + api_client.credentials(HTTP_AUTHORIZATION=cred) + return api_client + + +@pytest.fixture +def oauth_app(user): + application = Application( + name="Test Application", + redirect_uris="http://localhost", + user=user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, + ) + application.raw_client_secret = application.client_secret + application.save() + return application + + +@pytest.fixture +def authed_client(api_client, oauth_app): + auth = f'{oauth_app.client_id}:{oauth_app.raw_client_secret}'.encode('utf-8') + credentials = base64.b64encode(auth).decode('utf-8') + api_client.defaults['HTTP_AUTHORIZATION'] = 'Basic ' + credentials + return api_client diff --git a/connectid/__init__.py b/connectid/__init__.py index e69de29..10f5014 100644 --- a/connectid/__init__.py +++ b/connectid/__init__.py @@ -0,0 +1,5 @@ +# This will make sure the app is always imported when +# Django starts so that shared_task will use this app. +from .celery_app import app as celery_app + +__all__ = ("celery_app",) diff --git a/connectid/celery_app.py b/connectid/celery_app.py new file mode 100644 index 0000000..7404f70 --- /dev/null +++ b/connectid/celery_app.py @@ -0,0 +1,17 @@ +import os + +from celery import Celery + +# set the default Django settings module for the 'celery' program. +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "connectid.settings") + +app = Celery("connectid") + +# Using a string here means the worker doesn't have to serialize +# the configuration object to child processes. +# - namespace='CELERY' means all celery-related configuration keys +# should have a `CELERY_` prefix. +app.config_from_object("django.conf:settings", namespace="CELERY") + +# Load task modules from all registered Django app configs. +app.autodiscover_tasks() diff --git a/connectid/settings.py b/connectid/settings.py index 5eef7c0..34393c8 100644 --- a/connectid/settings.py +++ b/connectid/settings.py @@ -16,6 +16,7 @@ # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent +env = os.environ # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/ @@ -34,6 +35,7 @@ 'users.apps.UsersConfig', 'messaging', 'oauth2_provider', + 'payments', 'rest_framework', 'axes', 'fcm_django', @@ -63,7 +65,7 @@ TEMPLATES = [ { 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], + "DIRS": [BASE_DIR / "templates"], 'APP_DIRS': True, 'OPTIONS': { 'context_processors': [ @@ -79,7 +81,11 @@ WSGI_APPLICATION = 'connectid.wsgi.application' - +TRUSTED_COMMCAREHQ_HOSTS = [ + "www.commcarehq.org", + "commcarehq.org", + "staging.commcarehq.org", +] # Password validation # https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators @@ -93,7 +99,6 @@ }, ] - # Internationalization # https://docs.djangoproject.com/en/4.1/topics/i18n/ @@ -105,7 +110,6 @@ USE_TZ = True - # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.1/howto/static-files/ @@ -214,8 +218,12 @@ "DELETE_INACTIVE_DEVICES": False, } +OAUTH2_PROVIDER_APPLICATION_MODEL = 'oauth2_provider.Application' + SITE_ID = 1 +APP_HASH = "apphash" + from .localsettings import * # Firebase @@ -223,3 +231,10 @@ from firebase_admin import credentials, initialize_app creds = credentials.Certificate(FCM_CREDENTIALS) default_app = initialize_app(credential=creds) + +CELERY_TASK_ALWAYS_EAGER = True +CELERY_TASK_EAGER_PROPAGATES = True + +CELERY_BROKER_URL = env.get("CELERY_BROKER_URL", default="redis://localhost:6379/0") + + diff --git a/connectid/urls.py b/connectid/urls.py index 9bf9497..afab018 100644 --- a/connectid/urls.py +++ b/connectid/urls.py @@ -15,11 +15,14 @@ """ from django.contrib import admin from django.urls import include, path +from django.views.generic import TemplateView +from . import views urlpatterns = [ path('users/', include('users.urls')), path('messaging/', include('messaging.urls')), path('admin/', admin.site.urls), path('o/', include('oauth2_provider.urls', namespace='oauth2_provider')), + path('hq_invite/', TemplateView.as_view(template_name="connectid/deeplink.html"), name='deeplink'), + path('.well-known/assetlinks.json', views.assetlinks_json, name='assetlinks_json'), ] - diff --git a/connectid/views.py b/connectid/views.py new file mode 100644 index 0000000..b5cb30e --- /dev/null +++ b/connectid/views.py @@ -0,0 +1,30 @@ +from django.http import HttpResponse, JsonResponse + + +def assetlinks_json(request): + assetfile = [ + { + "relation": ["delegate_permission/common.handle_all_urls"], + "target": { + "namespace": "android_app", + "package_name": "org.commcare.dalvik", + "sha256_cert_fingerprints": + [ + "88:57:18:F8:E8:7D:74:04:97:AE:83:65:74:ED:EF:10:40:D9:4C:E2:54:F0:E0:40:64:77:96:7F:D1:39:F9:81", + "89:55:DF:D8:0E:66:63:06:D2:6D:88:A4:A3:88:A4:D9:16:5A:C4:1A:7E:E1:C6:78:87:00:37:55:93:03:7B:03" + ] + } + }, + { + "relation": ["delegate_permission/common.handle_all_urls"], + "target": { + "namespace": "android_app", + "package_name": "org.commcare.dalvik.debug", + "sha256_cert_fingerprints": + [ + "88:57:18:F8:E8:7D:74:04:97:AE:83:65:74:ED:EF:10:40:D9:4C:E2:54:F0:E0:40:64:77:96:7F:D1:39:F9:81" + ] + } + }, + ] + return JsonResponse(assetfile, safe=False) diff --git a/messaging/admin.py b/messaging/admin.py new file mode 100644 index 0000000..9fdf7ff --- /dev/null +++ b/messaging/admin.py @@ -0,0 +1,10 @@ +from django.contrib import admin + +from .models import MessageServer + + +@admin.register(MessageServer) +class MessageServerAdmin(admin.ModelAdmin): + list_display = ('name', 'key_url', 'callback_url', 'delivery_url', 'consent_url', 'server_id', 'secret_key') + search_fields = ('name',) + diff --git a/messaging/factories.py b/messaging/factories.py new file mode 100644 index 0000000..b2aa590 --- /dev/null +++ b/messaging/factories.py @@ -0,0 +1,68 @@ +import base64 +import os +from uuid import uuid4 + +import factory +from django.utils import timezone +from factory import LazyFunction +from factory.django import DjangoModelFactory +from oauth2_provider.models import Application + +from messaging.models import Channel, Message, MessageServer +from users.factories import UserFactory + + +class ApplicationFactory(DjangoModelFactory): + class Meta: + model = Application + + client_id = factory.Faker("uuid4") + client_secret = factory.Faker("uuid4") + client_type = "confidential" + authorization_grant_type = factory.Faker("random_element", elements=["authorization-code", "implicit", "password", + "client-credentials"]) + name = factory.Faker("company") + + +class ServerFactory(DjangoModelFactory): + class Meta: + model = MessageServer + + delivery_url = factory.Faker("url") + consent_url = factory.Faker("url") + callback_url = factory.Faker("url") + key_url = factory.Faker("url") + oauth_application = factory.SubFactory(ApplicationFactory) + + +class ChannelFactory(DjangoModelFactory): + class Meta: + model = Channel + + channel_id = factory.LazyFunction(uuid4) + user_consent = True + connect_user = factory.SubFactory(UserFactory) + server = factory.SubFactory(ServerFactory) + + +def generate_random_content(): + nonce = base64.b64encode(os.urandom(12)).decode('utf-8') + tag = base64.b64encode(os.urandom(16)).decode('utf-8') + ciphertext = base64.b64encode(os.urandom(32)).decode('utf-8') + + return { + "nonce": nonce, + "tag": tag, + "ciphertext": ciphertext + } + + +class MessageFactory(DjangoModelFactory): + class Meta: + model = Message + + message_id = factory.LazyFunction(uuid4) + channel = factory.SubFactory(ChannelFactory) + content = LazyFunction(generate_random_content) + timestamp = factory.LazyFunction(timezone.now) + received = None diff --git a/messaging/migrations/0001_initial.py b/messaging/migrations/0001_initial.py new file mode 100644 index 0000000..d1f8618 --- /dev/null +++ b/messaging/migrations/0001_initial.py @@ -0,0 +1,114 @@ +# Generated by Django 4.1.7 on 2024-10-24 09:00 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import django.utils.timezone +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + migrations.swappable_dependency(settings.OAUTH2_PROVIDER_APPLICATION_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="Channel", + fields=[ + ( + "channel_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("user_consent", models.BooleanField(default=False)), + ("channel_source", models.TextField()), + ( + "connect_user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + ), + migrations.CreateModel( + name="MessageServer", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=255)), + ("key_url", models.URLField()), + ("callback_url", models.URLField()), + ("delivery_url", models.URLField()), + ("consent_url", models.URLField()), + ( + "oauth_application", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL, + ), + ), + ], + ), + migrations.CreateModel( + name="Message", + fields=[ + ( + "message_id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + ), + ), + ("content", models.JSONField()), + ("timestamp", models.DateTimeField(default=django.utils.timezone.now)), + ("received", models.DateTimeField(blank=True, null=True)), + ( + "status", + models.CharField( + choices=[ + ("PENDING", "Pending"), + ("SENT_TO_SERVICE", "Sent To Service"), + ("DELIVERED", "Delivered"), + ("CONFIRMED_RECEIVED", "Confirmed Received"), + ], + default="PENDING", + max_length=50, + ), + ), + ( + "channel", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="messaging.channel", + ), + ), + ], + ), + migrations.AddField( + model_name="channel", + name="server", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to="messaging.messageserver", + ), + ), + ] diff --git a/messaging/migrations/0002_remove_messageserver_oauth_application_and_more.py b/messaging/migrations/0002_remove_messageserver_oauth_application_and_more.py new file mode 100644 index 0000000..3701526 --- /dev/null +++ b/messaging/migrations/0002_remove_messageserver_oauth_application_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 4.1.7 on 2024-11-06 01:13 + +from django.db import migrations, models +import oauth2_provider.generators + + +class Migration(migrations.Migration): + dependencies = [ + ("messaging", "0001_initial"), + ] + + operations = [ + migrations.RemoveField( + model_name="messageserver", + name="oauth_application", + ), + migrations.AddField( + model_name="messageserver", + name="secret_key", + field=models.CharField( + default=oauth2_provider.generators.generate_client_secret, + max_length=255, + ), + ), + migrations.AddField( + model_name="messageserver", + name="server_id", + field=models.CharField( + db_index=True, + default=oauth2_provider.generators.generate_client_id, + max_length=100, + unique=True, + ), + ), + ] diff --git a/messaging/migrations/__init__.py b/messaging/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/messaging/models.py b/messaging/models.py new file mode 100644 index 0000000..985aa44 --- /dev/null +++ b/messaging/models.py @@ -0,0 +1,46 @@ +import uuid + +from django.db import models +from django.utils import timezone +from oauth2_provider.models import Application +from oauth2_provider.generators import generate_client_id, generate_client_secret + +from users.models import ConnectUser + + +class MessageServer(models.Model): + name = models.CharField(max_length=255) + key_url = models.URLField(max_length=200) + callback_url = models.URLField(max_length=200) + delivery_url = models.URLField(max_length=200) + consent_url = models.URLField(max_length=200) + server_id = models.CharField(max_length=100, unique=True, default=generate_client_id, db_index=True) + secret_key = models.CharField( + max_length=255, + default=generate_client_secret, + ) + + +class Channel(models.Model): + channel_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + user_consent = models.BooleanField(default=False) + channel_source = models.TextField() + connect_user = models.ForeignKey(ConnectUser, on_delete=models.CASCADE) + server = models.ForeignKey(MessageServer, on_delete=models.CASCADE) + + +class MessageStatus(models.TextChoices): + PENDING = "PENDING", # initially when message is received by connectid from mobile. + SENT_TO_SERVICE = "SENT_TO_SERVICE" # when message is sent to service + DELIVERED = "DELIVERED", # when mobile get the message and mark received on connectid + CONFIRMED_RECEIVED = "CONFIRMED_RECEIVED" # when message is mark received on service + + +class Message(models.Model): + message_id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + channel = models.ForeignKey(Channel, on_delete=models.CASCADE) + content = models.JSONField() + timestamp = models.DateTimeField(default=timezone.now) + received = models.DateTimeField(null=True, blank=True) + status = models.CharField( + max_length=50, choices=MessageStatus.choices, default=MessageStatus.PENDING) diff --git a/messaging/serializers.py b/messaging/serializers.py index 1de7483..78997c2 100644 --- a/messaging/serializers.py +++ b/messaging/serializers.py @@ -3,9 +3,11 @@ from rest_framework import serializers +from messaging.models import Message + @dataclasses.dataclass -class Message: +class MessageData: usernames: List[str] = None title: str = None body: str = None @@ -23,12 +25,30 @@ def create(self, validated_data): username = validated_data.pop('username', None) if username: validated_data["usernames"] = [username] - return Message(**validated_data) + return MessageData(**validated_data) class BulkMessageSerializer(serializers.Serializer): messages = serializers.ListField(child=SingleMessageSerializer()) def create(self, validated_data): - return [Message(**message) for message in validated_data["messages"]] + return [MessageData(**message) for message in validated_data["messages"]] + + +class MessageSerializer(serializers.ModelSerializer): + ciphertext = serializers.SerializerMethodField() + tag = serializers.SerializerMethodField() + nonce = serializers.SerializerMethodField() + + class Meta: + model = Message + fields = ["message_id", "channel", "ciphertext", "tag", "nonce", "timestamp", "received", "status"] + + def get_ciphertext(self, obj): + return obj.content["ciphertext"] + + def get_tag(self, obj): + return obj.content["tag"] + def get_nonce(self, obj): + return obj.content["nonce"] diff --git a/messaging/task.py b/messaging/task.py new file mode 100644 index 0000000..ea4814a --- /dev/null +++ b/messaging/task.py @@ -0,0 +1,60 @@ +import base64 +import hashlib +import hmac +import json + +import requests +from rest_framework import status +from rest_framework.generics import get_object_or_404 + +from messaging.models import Message, MessageStatus, Channel + + +class CommCareHQAPIException(Exception): + pass + + +def make_request(url, json_data, secret): + try: + data = json.dumps(json_data).encode('utf-8') + digest = hmac.new(secret.encode('utf-8'), data, hashlib.sha256).digest() + mac_digest = base64.b64encode(digest).decode('utf-8') + headers = { + "Content-Type": "application/json", + "X-MAC-DIGEST": mac_digest, + } + response = requests.post(url, json=json_data, headers=headers) + response.raise_for_status() + return response + except requests.exceptions.RequestException as e: + return CommCareHQAPIException({"status": "error", "message": str(e)}) + + +def send_messages_to_service_and_mark_status(channel_messages, + status_to_be_updated: MessageStatus): + sent_message_ids = [] + + for channel_id, data in channel_messages.items(): + url = data["url"] + messages = data["messages"] + + try: + channel = get_object_or_404(Channel, channel_id=channel_id) + + response = make_request( + url=url, + json_data={ + "channel_id": str(channel_id), + "messages": messages, + }, + secret=channel.server.secret_key + ) + if response == status.HTTP_200_OK: + sent_message_ids.extend(msg["message_id"] for msg in messages) + + except CommCareHQAPIException as e: + # To-Do: All the messages which gets failed should be sent again with some task. + pass + + if sent_message_ids: + Message.objects.filter(message_id__in=sent_message_ids).update(status=status_to_be_updated) diff --git a/messaging/tests.py b/messaging/tests.py index 5075552..47de89c 100644 --- a/messaging/tests.py +++ b/messaging/tests.py @@ -1,36 +1,28 @@ -import base64 import json +from collections import defaultdict from unittest import mock +from unittest.mock import Mock, patch +from uuid import uuid4 import pytest from django.urls import reverse from firebase_admin import messaging -from oauth2_provider.models import Application +from rest_framework import status +from rest_framework.test import APITestCase -from users.factories import FCMDeviceFactory +from messaging.factories import ChannelFactory, MessageFactory, ServerFactory +from messaging.models import Channel, Message, MessageStatus +from messaging.serializers import MessageData +from payments.models import PaymentProfile +from users.factories import FCMDeviceFactory, UserFactory -@pytest.fixture -def oauth_app(user): - application = Application( - name="Test Application", - redirect_uris="http://localhost", - user=user, - client_type=Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=Application.GRANT_CLIENT_CREDENTIALS, - ) - application.raw_client_secret = application.client_secret - application.save() - return application +APPLICATION_JSON = "application/json" @pytest.fixture -def authed_client(client, oauth_app): - auth = f'{oauth_app.client_id}:{oauth_app.raw_client_secret}'.encode('utf-8') - credentials = base64.b64encode(auth).decode('utf-8') - client.defaults['HTTP_AUTHORIZATION'] = 'Basic ' + credentials - return client - +def server(oauth_app): + return ServerFactory(oauth_application=oauth_app) def test_send_message(authed_client, fcm_device): url = reverse('messaging:send_message') @@ -40,7 +32,7 @@ def test_send_message(authed_client, fcm_device): "username": fcm_device.user.username, "body": "test message", "data": {"test": "data"}, - }, content_type="application/json") + }, content_type=APPLICATION_JSON) assert response.status_code == 200, response.content assert response.json() == { 'all_success': True, @@ -76,7 +68,7 @@ def test_send_message_bulk(authed_client, fcm_device): "data": {"test": "data2"}, } ] - }, content_type="application/json") + }, content_type=APPLICATION_JSON) assert response.status_code == 200, response.content assert mock_send_message.call_count == 2 @@ -110,7 +102,334 @@ def test_send_message_bulk(authed_client, fcm_device): def _fake_send(messages, **kwargs): - return messaging.BatchResponse([ - messaging.SendResponse({'name': f'message_id_{i}'}, None) - for i, message in enumerate(messages) - ]) + return messaging.BatchResponse( + [ + messaging.SendResponse({"name": f"message_id_{i}"}, None) + for i, message in enumerate(messages) + ] + ) + + +@pytest.fixture +def channel(user, server, consent=True): + return ChannelFactory(connect_user=user, user_consent=consent, server=server) + + +def rest_channel_data(user=None, consent=False): + return { + "user_consent": consent, + "connectid": str(user.id) if user else None, + "channel_source": "hq project space", + } + + +def rest_message(channel_id=None): + content = { + "nonce": "test_nonce_value", + "tag": "test_tag_value", + "ciphertext": "test_ciphertext_value" + } + return { + "channel": str(channel_id) if channel_id else None, + "content": content + } + + +@pytest.mark.django_db +class TestCreateChannelView: + @staticmethod + def post_channel_request(client, data, expected_status, expected_error_field=None): + url = reverse("messaging:create_channel") + response = client.post(url, data=data, content_type=APPLICATION_JSON) + + assert response.status_code == expected_status + + if expected_status == status.HTTP_400_BAD_REQUEST and expected_error_field: + json_data = response.json() + assert expected_error_field in json_data + + return response + + def test_create_channel_success(self, authed_client, fcm_device, oauth_app): + server = ServerFactory.create(oauth_application=oauth_app) + data = rest_channel_data(fcm_device.user) + + with mock.patch( + "fcm_django.models.messaging.send_all", wraps=_fake_send + ) as mock_send_message: + response = self.post_channel_request( + authed_client, data, status.HTTP_201_CREATED + ) + + json_data = response.json() + assert "channel_id" in json_data + + mock_send_message.assert_called_once() + messages = mock_send_message.call_args.args[0] + + assert len(messages) == 1 + message = messages[0] + assert message.token == fcm_device.registration_id + assert message.notification.title == "Channel created" + assert ( + message.notification.body + == "Please provide your consent to send/receive message." + ) + assert message.data == {"keyUrl": server.key_url} + + +@pytest.mark.django_db +def test_send_fcm_notification_view(authed_client, channel): + url = reverse("messaging:send_fcm") + data = rest_message(channel.channel_id) + + with mock.patch( + "messaging.views.send_bulk_message" + ) as mock_send_bulk_message: + response = authed_client.post(url, data=data, content_type=APPLICATION_JSON) + json_data = response.json() + assert response.status_code == status.HTTP_200_OK + assert "message_id" in json_data + + message_id = json_data["message_id"] + db_msg = Message.objects.get(message_id=message_id) + assert db_msg + + message_to_send = MessageData( + usernames=[channel.connect_user.username], + data={ + "message_id": db_msg.message_id, + "channel_id": str(channel.channel_id), + "content": db_msg.content, + }, + ) + + mock_send_bulk_message.assert_called_once_with(message_to_send) + + +@pytest.mark.django_db +class TestSendMessageView: + url = reverse("messaging:post_message") + + def test_send_message_from_mobile(self, auth_device, channel, server): + data = rest_message(channel.channel_id) + + with patch( + "messaging.views.send_messages_to_service_and_mark_status" + ) as mock_make_request: + response = auth_device.post(self.url, json.dumps(data), content_type=APPLICATION_JSON) + json_data = response.json() + assert response.status_code == status.HTTP_201_CREATED + assert "message_id" in json_data + + message_id = json_data["message_id"][0] + assert Message.objects.filter(message_id=message_id).exists() + + msg = Message.objects.filter(message_id=message_id).first() + + # Prepare the expected message data in a defaultdict format + expected_message_data = defaultdict(lambda: {"messages": [], "url": None}) + expected_message_data[str(channel.channel_id)] = { + "url": server.delivery_url, + "messages": [msg] + } + + mock_make_request.assert_called_once_with( + expected_message_data, + MessageStatus.SENT_TO_SERVICE + ) + + def test_multiple_messages(self, auth_device, channel, server): + data = [rest_message(channel.channel_id), rest_message(channel.channel_id)] + + with mock.patch( + "messaging.views.send_messages_to_service_and_mark_status" + ) as mock_send_bulk_message: + response = auth_device.post( + self.url, + data=json.dumps(data), + content_type=APPLICATION_JSON, + ) + json_data = response.json() + + assert response.status_code == status.HTTP_201_CREATED + assert "message_id" in json_data + + message_ids = json_data["message_id"] + assert len(message_ids) == 2 + + assert mock_send_bulk_message.call_count == 1 + + expected_message_data = defaultdict(lambda: {"messages": [], "url": None}) + expected_messages = [Message.objects.get(message_id=msg_id) for msg_id in message_ids] + expected_message_data[str(channel.channel_id)] = { + "url": server.delivery_url, + "messages": expected_messages + } + + mock_send_bulk_message.assert_called_once_with( + expected_message_data, + MessageStatus.SENT_TO_SERVICE + ) + + +@pytest.mark.django_db +class TestRetrieveMessagesView: + url = reverse("messaging:retrieve_messages") + + def test_retrieve_messages_success(self, auth_device, fcm_device): + ch = ChannelFactory.create(connect_user=fcm_device.user, server=ServerFactory.create()) + MessageFactory.create_batch(10, channel=ch) + + response = auth_device.get(self.url) + json_data = response.json() + + assert response.status_code == status.HTTP_200_OK + assert all(key in json_data for key in ['channels', 'messages']) + assert len(json_data['messages']) == 10 + + channel = json_data['channels'][0] + message = json_data['messages'][0] + + assert isinstance(message["content"], dict) + + assert all(key in channel for key in ['channel_id', 'channel_source', 'key_url']) + assert all(key in message for key in ['message_id', 'channel', 'timestamp', 'content']) + + def test_retrieve_messages_no_data(self, auth_device): + Channel.objects.all().delete() + Message.objects.all().delete() + + response = auth_device.get(self.url) + + response_data = response.json() + + assert response.status_code == status.HTTP_200_OK + assert all(key in response_data for key in ['channels', 'messages']) + assert all(not response_data[key] for key in ['channels', 'messages']) + + def test_retrieve_messages_multiple_channels(self, auth_device, fcm_device): + channels = ChannelFactory.create_batch(5, connect_user=fcm_device.user, server=ServerFactory.create()) + for channel in channels: + MessageFactory.create_batch(5, channel=channel) + + response = auth_device.get(self.url) + data = response.json() + assert response.status_code == status.HTTP_200_OK + assert all(len(data[key]) == expected for key, expected in [('channels', 5), ('messages', 25)]) + + +@pytest.mark.django_db +class TestUpdateConsentView: + url = reverse("messaging:update_consent") + + def test_consent(self, auth_device, channel, server, consent=False, ): + with patch( + "messaging.views.make_request" + ) as mock_make_request: + mock_make_request.return_value = Mock(status_code=status.HTTP_200_OK) + data = { + "channel": str(channel.channel_id), + "consent": consent, + } + json_data = json.dumps(data) + response = auth_device.post( + self.url, json_data, content_type=APPLICATION_JSON + ) + + assert response.status_code == status.HTTP_200_OK + channel.refresh_from_db() + + assert channel.user_consent == consent + + mock_make_request.assert_called_once_with( + url=server.consent_url, + json_data={ + "channel_id": str(channel.channel_id), + "consent": str(consent), + }, + secret=server.oauth_application.client_secret + ) + + def test_restrict_consent(self, auth_device, channel, server): + channel.user_consent = False + channel.save() + channel.refresh_from_db() + self.test_consent(auth_device, channel, server, True) + + def test_invalid_channel_id(self, auth_device): + url = reverse("messaging:update_consent") + data = {"channel": str(uuid4()), "consent": False} + data = json.dumps(data) + response = auth_device.post(url, data, content_type=APPLICATION_JSON) + assert response.status_code == status.HTTP_404_NOT_FOUND + + +@pytest.mark.django_db +class TestUpdateReceivedView: + url = reverse("messaging:update_received") + + def test_update_received(self, auth_device, channel): + messages = MessageFactory.create_batch(5, channel=channel) + message_ids = [str(message.message_id) for message in messages] + + data = {"messages": message_ids} + data = json.dumps(data) + response = auth_device.post(self.url, data, content_type=APPLICATION_JSON) + + assert response.status_code == status.HTTP_200_OK + + for message in messages: + message.refresh_from_db() + assert message.received is not None + + def test_empty_message_list(self, auth_device): + data = {"messages": []} + data = json.dumps(data) + response = auth_device.post(self.url, data, content_type=APPLICATION_JSON) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert Message.objects.filter(received__isnull=False).count() == 0 + + def test_invalid_message_ids(self, auth_device): + invalid_message_ids = [str(uuid4()), str(uuid4())] + data = {"messages": invalid_message_ids} + data = json.dumps(data) + response = auth_device.post(self.url, data, content_type=APPLICATION_JSON) + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert Message.objects.filter(received__isnull=False).count() == 0 + + @patch("messaging.views.send_messages_to_service_and_mark_status") + def test_grouped_channel_messages(self, mock_send_messages, auth_device): + channel1 = ChannelFactory.create(server=ServerFactory.create()) + channel2 = ChannelFactory.create(server=ServerFactory.create()) + messages1 = MessageFactory.create_batch(3, channel=channel1) + messages2 = MessageFactory.create_batch(2, channel=channel2) + + message_ids = [str(message.message_id) for message in messages1 + messages2] + + data = {"messages": message_ids} + data = json.dumps(data) + response = auth_device.post(self.url, data, content_type=APPLICATION_JSON) + + assert response.status_code == status.HTTP_200_OK + + for message in messages1: + message.refresh_from_db() + assert message.received is not None + assert message.status == MessageStatus.DELIVERED + + for message in messages2: + message.refresh_from_db() + assert message.received is not None + assert message.status == MessageStatus.DELIVERED + + # Validate the mock call + mock_send_messages.assert_called_once() + args, kwargs = mock_send_messages.call_args + data, msg_status = args + assert isinstance(data, defaultdict) and len(data) == 2 + assert all(str(ch.channel_id) in data for ch in [channel1, channel2]) + assert all(all(msg["received"] for msg in data[str(ch.channel_id)]["messages"]) for ch in [channel1, channel2]) + assert msg_status == MessageStatus.CONFIRMED_RECEIVED diff --git a/messaging/urls.py b/messaging/urls.py index 3eb395c..3194828 100644 --- a/messaging/urls.py +++ b/messaging/urls.py @@ -2,9 +2,21 @@ from messaging import views -app_name = 'messaging' +app_name = "messaging" urlpatterns = [ - path('send/', views.SendMessage.as_view(), name='send_message'), - path('send_bulk/', views.SendMessageBulk.as_view(), name='send_message_bulk'), + path("send/", views.SendMessage.as_view(), name="send_message"), + path("send_bulk/", views.SendMessageBulk.as_view(), name="send_message_bulk"), + path("create_channel/", views.CreateChannelView.as_view(), name="create_channel"), + path("send_message/", views.SendMobileConnectMessage.as_view(), name="post_message"), + path("send_fcm/", views.SendServerConnectMessage.as_view(), name="send_fcm"), + path("update_consent/", views.UpdateConsentView.as_view(), name="update_consent"), + path( + "retrieve_messages/", + views.RetrieveMessageView.as_view(), + name="retrieve_messages", + ), + path( + "update_received/", views.UpdateReceivedView.as_view(), name="update_received" + ), ] diff --git a/messaging/views.py b/messaging/views.py index e157d0f..b6dd1ed 100644 --- a/messaging/views.py +++ b/messaging/views.py @@ -1,10 +1,32 @@ +import base64 +from collections import defaultdict + +from django.db import transaction +from django.db.models import Prefetch from django.http import JsonResponse +from django.shortcuts import get_object_or_404 +from django.utils import timezone from fcm_django.models import FCMDevice from firebase_admin import messaging +from rest_framework import status +from rest_framework.exceptions import ValidationError from rest_framework.views import APIView -from messaging.serializers import SingleMessageSerializer, BulkMessageSerializer -from utils.rest_framework import ClientProtectedResourceAuth +from messaging.models import Channel, Message, MessageStatus, MessageServer +from messaging.serializers import SingleMessageSerializer, BulkMessageSerializer, MessageSerializer, \ + MessageData +from messaging.task import make_request, send_messages_to_service_and_mark_status +from users.models import ConnectUser +from utils.rest_framework import ClientProtectedResourceAuth, MessagingServerAuth + + +def get_current_message_server(request): + auth_header = request.META.get('HTTP_AUTHORIZATION') + encoded_credentials = auth_header.split(' ')[1] + decoded_credentials = base64.b64decode(encoded_credentials).decode('utf-8') + client_id, client_secret = decoded_credentials.split(':') + server = get_object_or_404(MessageServer, server_id=client_id) + return server class SendMessage(APIView): @@ -25,6 +47,7 @@ class SendMessage(APIView): ] } """ + authentication_classes = [ClientProtectedResourceAuth] def post(self, request, *args, **kwargs): @@ -127,10 +150,7 @@ def send_bulk_message(message): def _build_message(message): notification = _build_notification(message) - return messaging.Message( - data=message.data, - notification=notification - ) + return messaging.Message(data=message.data, notification=notification) def _build_notification(data): @@ -139,3 +159,222 @@ def _build_notification(data): title=data.title, body=data.body, ) + + +class CreateChannelView(APIView): + authentication_classes = [MessagingServerAuth] + + def post(self, request, *args, **kwargs): + data = request.data + connect_id = data["connectid"] + channel_source = data["channel_source"] + server = get_current_message_server(request) + user = get_object_or_404(ConnectUser, username=connect_id) + channel, created = Channel.objects.get_or_create(server=server, connect_user=user, channel_source=channel_source) + if created: + message = MessageData( + usernames=[channel.connect_user.username], + title="Channel created", + body="Please provide your consent to send/receive message.", + data={"keyUrl": server.key_url}, + ) + # send fcm notification. + send_bulk_message(message) + return JsonResponse( + {"channel_id": str(channel.channel_id)}, status=status.HTTP_201_CREATED + ) + else: + return JsonResponse( + {"channel_id": str(channel.channel_id)}, status=status.HTTP_200_OK + ) + + +class SendServerConnectMessage(APIView): + authentication_classes = [MessagingServerAuth] + + def post(self, request, *args, **kwargs): + data = request.data + content = data["content"] + for field in ("nonce", "tag", "ciphertext"): + if not content[field]: + return JsonResponse({"errors": "invalid message content"}, status=status.HTTP_400_BAD_REQUEST) + message_data = { + "channel_id": data["channel"], + "content": data["content"], + "message_id": data["message_id"] + } + message = Message(**message_data) + message.save() + channel = message.channel + message_to_send = MessageData( + usernames=[channel.connect_user.username], + data=MessageSerializer(message).data + ) + send_bulk_message(message_to_send) + return JsonResponse( + {"message_id": str(message.message_id)}, + status=status.HTTP_200_OK, + ) + + +class SendMobileConnectMessage(APIView): + + def post(self, request, *args, **kwargs): + data = request.data + if not isinstance(data, list): + data = [data] + messages = [] + errors = set() + for message in data: + if not message.get("message_id"): + errors.add("missing message_id") + + if not message.get("channel"): + errors.add("missing channel_id") + + for field in ("nonce", "tag", "ciphertext"): + if not message.get("content", {}).get(field): + errors.add("invalid message content") + + if errors: + break + + message_data = { + "message_id": message["message_id"], + "content": message["content"], + "channel_id": message["channel"] + } + messages.append(Message(**message_data)) + + if errors: + return JsonResponse({"errors": list(errors)}, status=status.HTTP_400_BAD_REQUEST) + + message_objs = Message.objects.bulk_create(messages) + messages_ready_to_be_sent = defaultdict(lambda: {"messages": [], "url": None}) + messages_ready_to_be_sent_ids = [] + + for msg in message_objs: + channel = msg.channel + server = channel.server + + channel_id = str(channel.channel_id) + messages_ready_to_be_sent[channel_id]["messages"].append(MessageSerializer(msg).data) + + if messages_ready_to_be_sent[channel_id]["url"] is None: + messages_ready_to_be_sent[channel_id][ + "url" + ] = server.delivery_url + + messages_ready_to_be_sent_ids.append(str(msg.message_id)) + + send_messages_to_service_and_mark_status(messages_ready_to_be_sent, MessageStatus.SENT_TO_SERVICE) + + return JsonResponse( + {"message_id": messages_ready_to_be_sent_ids}, + status=status.HTTP_201_CREATED, + ) + + +class RetrieveMessageView(APIView): + def get(self, request, *args, **kwargs): + user = request.user + channels = ( + Channel.objects.filter(connect_user=user) + .only("channel_id", "channel_source") + .prefetch_related( + Prefetch( + "message_set", + queryset=Message.objects.only("message_id", "channel", "timestamp", "content"), + ), + Prefetch( + "server", + queryset=MessageServer.objects.only("key_url") + ) + ) + ) + + channels_data = [] + messages = [] + for channel in channels: + channels_data.append({"channel_source": channel.channel_source, "channel_id": str(channel.channel_id), + "key_url": channel.server.key_url, "consent": channel.user_consent}) + channel_messages = channel.message_set.all() + messages.extend(channel_messages) + + messages_data = MessageSerializer(messages, many=True).data + + return JsonResponse({"channels": channels_data, "messages": messages_data}) + + +class UpdateConsentView(APIView): + def post(self, request, *args, **kwargs): + data = request.data + channel_id = data.get("channel") + consent = data.get("consent") + + if channel_id is None or consent is None: + raise ValidationError("Both 'channel' and 'consent' fields are required.") + + channel = get_object_or_404(Channel, channel_id=channel_id) + + channel.user_consent = consent + channel.save() + + json_data = { + "channel_id": str(channel.channel_id), + "consent": channel.user_consent, + } + + response = make_request(url=channel.server.consent_url, json_data=json_data, + secret=channel.server.secret_key) + + if response.status_code != status.HTTP_200_OK: + return JsonResponse( + {"error": "Failed to update consent service"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + return JsonResponse({}, status=status.HTTP_200_OK) + + +class UpdateReceivedView(APIView): + def post(self, request, *args, **kwargs): + message_ids = request.data.get("messages", []) + + if not message_ids: + return JsonResponse({}, status=status.HTTP_400_BAD_REQUEST) + + with transaction.atomic(): + messages = ( + Message.objects.select_for_update() + .filter(message_id__in=message_ids) + .select_related("channel") + ) + + if not messages.exists(): + return JsonResponse({}, status=status.HTTP_404_NOT_FOUND) + + current_time = timezone.now() + messages.update(received=current_time, status=MessageStatus.DELIVERED) + + # Group messages by their channel + channel_messages = defaultdict(lambda: {"messages": [], "url": None}) + for message in messages: + channel_id = str(message.channel.channel_id) + + channel_messages[channel_id]["messages"].append( + { + "message_id": str(message.message_id), + "received_on": str(current_time), + } + ) + + if channel_messages[channel_id]["url"] is None: + channel_messages[channel_id][ + "url" + ] = message.channel.server.callback_url + + # To-Do should be async. + send_messages_to_service_and_mark_status(channel_messages, MessageStatus.CONFIRMED_RECEIVED) + + return JsonResponse({}, status=status.HTTP_200_OK) diff --git a/payments/__init__.py b/payments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/payments/apps.py b/payments/apps.py new file mode 100644 index 0000000..4886655 --- /dev/null +++ b/payments/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class PaymentsConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'payments' diff --git a/payments/migrations/0001_initial.py b/payments/migrations/0001_initial.py new file mode 100644 index 0000000..f4a8cbe --- /dev/null +++ b/payments/migrations/0001_initial.py @@ -0,0 +1,32 @@ +# Generated by Django 4.1.7 on 2024-11-09 10:16 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import phonenumber_field.modelfields + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='PaymentProfile', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('phone_number', phonenumber_field.modelfields.PhoneNumberField(max_length=128, region=None)), + ('owner_name', models.TextField(max_length=150, blank=True)), + ('telecom_provider', models.CharField(blank=True, max_length=50, null=True)), + ('is_verified', models.BooleanField(default=False)), + ('status', models.CharField(choices=[('pending', 'Pending'), ('approved', 'Approved'), ('rejected', 'Rejected')], default='pending', max_length=10)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='payment_profile', to=settings.AUTH_USER_MODEL)), + ], + ), + ] diff --git a/payments/migrations/__init__.py b/payments/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/payments/models.py b/payments/models.py new file mode 100644 index 0000000..c4cebd5 --- /dev/null +++ b/payments/models.py @@ -0,0 +1,34 @@ +from django.db import models + +from phonenumber_field.modelfields import PhoneNumberField +from users.models import ConnectUser + + +class PaymentProfile(models.Model): + PENDING = 'pending' + APPROVED = 'approved' + REJECTED = 'rejected' + + STATUS_CHOICES = [ + (PENDING, 'Pending'), + (APPROVED, 'Approved'), + (REJECTED, 'Rejected'), + ] + + user = models.OneToOneField( + ConnectUser, + on_delete=models.CASCADE, + related_name='payment_profile' + ) + phone_number = PhoneNumberField() + owner_name = models.TextField(max_length=150, blank=True) + telecom_provider = models.CharField(max_length=50, blank=True, null=True) + # whether the number is verified using OTP + is_verified = models.BooleanField(default=False) + status = models.CharField( + max_length=10, + choices=STATUS_CHOICES, + default=PENDING, + ) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) diff --git a/payments/tests.py b/payments/tests.py new file mode 100644 index 0000000..62466c6 --- /dev/null +++ b/payments/tests.py @@ -0,0 +1,95 @@ +import base64 +import pytest + +from django.urls import reverse +from rest_framework import status + +from messaging.tests import APPLICATION_JSON +from payments.models import PaymentProfile +from users.factories import UserFactory + + +@pytest.mark.parametrize( + "data, expected_status, expected_user1_status, expected_user2_status, result", + [ + # Scenario 1: Update both statuses successfully + ( + [ + {"username": "user1", "status": "approved"}, + {"username": "user2", "status": "rejected"}, + ], + status.HTTP_200_OK, + "approved", + "rejected", + {"approved": 1, "rejected": 1, "pending": 0} + ), + # Scenario 2: No change in status + ( + [ + {"username": "user2", "status": "approved"}, + ], + status.HTTP_200_OK, + "pending", # Should remain unchanged + "approved", # Should remain unchanged + {"approved": 0, "rejected": 0, "pending": 0} + ), + # Scenario 3: Invalid user (user doesn't exist) + ( + [ + {"username": "nonexistent_user", "status": "rejected"}, + ], + status.HTTP_404_NOT_FOUND, + "pending", # No change + "approved", # No change + {} + ), + # Scenario 4: Multiple users, one invalid + ( + [ + {"username": "user1", "status": "approved"}, + {"username": "nonexistent_user", "status": "rejected"}, + ], + status.HTTP_404_NOT_FOUND, + "pending", # No change + "approved", # No change + {} + ), + ] +) +def test_validate_phone_numbers(authed_client, data, expected_status, expected_user1_status, expected_user2_status, result): + user1 = UserFactory(username="user1") + user2 = UserFactory(username="user2") + PaymentProfile.objects.create(user=user1, phone_number="12345", status="pending") + PaymentProfile.objects.create(user=user2, phone_number="67890", status="approved") + + url = reverse("validate_payment_phone_numbers") + + response = authed_client.post(url, {"updates": data}, content_type=APPLICATION_JSON) + + assert response.status_code == expected_status + + profile1 = PaymentProfile.objects.get(user=user1) + profile2 = PaymentProfile.objects.get(user=user2) + + assert profile1.status == expected_user1_status + assert profile2.status == expected_user2_status + if response.status_code == 200: + assert response.json()["result"] == result + + +def test_fetch_phone_numbers(authed_client): + user1 = UserFactory(username="user1") + user2 = UserFactory(username="user2") + PaymentProfile.objects.create(user=user1, phone_number="12345", status="pending") + PaymentProfile.objects.create(user=user2, phone_number="67890", status="approved") + + url = reverse("fetch_payment_phone_numbers") + + response = authed_client.get(url, {"usernames": ["user1", "user2"]}) + assert len(response.json()['found_payment_numbers']) == 2 + + response = authed_client.get(url, {"usernames": ["user1", "user2"], "status": "pending"}) + assert len(response.json()['found_payment_numbers']) == 1 + + response = authed_client.get(url, {"usernames": ["user1"], "status": "approved"}) + assert len(response.json()['found_payment_numbers']) == 0 diff --git a/payments/views.py b/payments/views.py new file mode 100644 index 0000000..ed96bca --- /dev/null +++ b/payments/views.py @@ -0,0 +1,138 @@ +from django.db import transaction +from django.db.models import Q +from django.http import JsonResponse, HttpResponse +from django.views.decorators.http import require_POST +from messaging.views import send_bulk_message +from messaging.serializers import MessageData +from oauth2_provider.decorators import protected_resource +from utils.rest_framework import ClientProtectedResourceAuth +from rest_framework import status as drf_status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.views import APIView + +from users.models import ConnectUser, PhoneDevice +from utils.twilio import lookup_telecom_provider +from .models import PaymentProfile + + +@api_view(['POST']) +def update_payment_profile_phone(request): + user = request.user + phone_number = request.data.get('phone_number') + owner_name = request.data.get('owner_name') + telecom_provider = lookup_telecom_provider(phone_number) + payment_profile, created = PaymentProfile.objects.update_or_create( + user=user, + defaults={ + 'phone_number': phone_number, + 'owner_name': owner_name, + 'telecom_provider': telecom_provider, + 'is_verified': False, + 'status': PaymentProfile.PENDING + } + ) + return PhoneDevice.send_otp_httpresponse(phone_number=payment_profile.phone_number, user=payment_profile.user) + + +@api_view(['POST']) +def confirm_payment_profile_otp(request): + payment_profile = request.user.payment_profile + device = PhoneDevice.objects.get(phone_number=payment_profile.phone_number, user=payment_profile.user) + if not device.verify_token(request.data.get('token')): + return JsonResponse({"error": "OTP token is incorrect"}, status=401) + + payment_profile.is_verified = True + payment_profile.save() + return JsonResponse({"success": True}) + + +class FetchPhoneNumbers(APIView): + authentication_classes = [ClientProtectedResourceAuth] + + def get(self, request, *args, **kwargs): + usernames = request.GET.getlist('usernames') + status = request.GET.get("status") + results = {} + profiles = PaymentProfile.objects.filter( + user__username__in=usernames) + if status: + profiles = profiles.filter(status=status) + profiles = profiles.select_related("user") + results["found_payment_numbers"] = [ + { + "username": p.user.username, + "phone_number": str(p.phone_number), + "status": p.status, + } + for p in profiles + ] + return JsonResponse(results) + + +class ValidatePhoneNumbers(APIView): + authentication_classes = [ClientProtectedResourceAuth] + + def post(self, request, *args, **kwargs): + # List of dictionaries: [{"username": ..., "phone_number": ..., "status": ...}, ...] + users_data = request.data["updates"] + + usernames = [data["username"] for data in users_data] + status_map = {data["username"]: data["status"] for data in users_data} + + profiles = PaymentProfile.objects.filter(user__username__in=usernames).select_related("user") + if len(profiles) != len(users_data): + return Response(status=drf_status.HTTP_404_NOT_FOUND) + + profiles_to_update = [] + usernames_by_states = { + "pending": [], + "approved": [], + "rejected": [], + } + + for profile in profiles: + username = profile.user.username + requested_status = status_map.get(username) + + if profile.status != requested_status: + profile.status = requested_status + profiles_to_update.append(profile) + + usernames_by_states[requested_status].append(username) + + if profiles_to_update: + PaymentProfile.objects.bulk_update(profiles_to_update, ["status"]) + + if usernames_by_states["approved"]: + send_bulk_message( + MessageData( + usernames=usernames_by_states["approved"], + title="Your Payment Phone Number is approved", + body="Your payment phone number is approved and future payments will be made to this number.", + data={"action": "ccc_payment_info_confirmation", "confirmation_status": "approved"} + ) + ) + if usernames_by_states["rejected"]: + send_bulk_message( + MessageData( + usernames=usernames_by_states["rejected"], + title="Your Payment Phone Number did not work", + body="Your payment number did not work. Please try to change to a different payment phone number", + data={"action": "ccc_payment_info_confirmation", "confirmation_status": "approved"} + ) + ) + if usernames_by_states["pending"]: + send_bulk_message( + MessageData( + usernames=usernames_by_states["pending"], + title="Your Payment Phone Number is pending review", + body="Your payment phone number is pending review. Please wait for further updates.", + data={"action": "ccc_payment_info_confirmation", "confirmation_status": "pending"} + ) + ) + result = { + state: len(usernames_by_states[state]) + for state in ["approved", "rejected", "pending"] + } + return JsonResponse({"success": True, "result": result}, status=200) diff --git a/requirements/requirements.in b/requirements/requirements.in index 2821f56..4d05a71 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -9,4 +9,6 @@ phonenumberslite psycopg2 twilio zxcvbn -fcm-django \ No newline at end of file +fcm-django>=2.2 +redis +celery \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 4fa669c..cc8b564 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,17 +1,23 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # # pip-compile --allow-unsafe --output-file=requirements.txt requirements.in # +amqp==5.3.1 + # via kombu asgiref==3.6.0 # via django -build==0.10.0 - # via pip-tools +async-timeout==5.0.1 + # via redis +billiard==4.2.1 + # via celery cachecontrol==0.13.1 # via firebase-admin cachetools==5.3.1 # via google-auth +celery==5.4.0 + # via -r requirements.in certifi==2022.12.7 # via requests cffi==1.15.1 @@ -19,7 +25,17 @@ cffi==1.15.1 charset-normalizer==3.1.0 # via requests click==8.1.3 - # via pip-tools + # via + # celery + # click-didyoumean + # click-plugins + # click-repl +click-didyoumean==0.3.1 + # via celery +click-plugins==1.1.1 + # via celery +click-repl==0.3.0 + # via celery cryptography==41.0.2 # via # jwcrypto @@ -30,18 +46,15 @@ django==4.1.7 # via # -r requirements.in # django-axes + # django-oauth-toolkit # django-otp # django-phonenumber-field # djangorestframework + # fcm-django django-axes[ipware]==6.0.3 # via -r requirements.in django-ipware==5.0.0 # via django-axes - # django-oauth-toolkit - # django-otp - # django-phonenumber-field - # djangorestframework - # fcm-django django-oauth-toolkit==2.3.0 # via -r requirements.in django-otp==1.1.6 @@ -50,7 +63,7 @@ django-phonenumber-field==7.0.2 # via -r requirements.in djangorestframework==3.14.0 # via -r requirements.in -fcm-django==2.0.0 +fcm-django==2.2.1 # via -r requirements.in firebase-admin==6.2.0 # via fcm-django @@ -104,16 +117,16 @@ idna==3.4 # via requests jwcrypto==1.5.0 # via django-oauth-toolkit +kombu==5.4.2 + # via celery msgpack==1.0.7 # via cachecontrol oauthlib==3.2.2 # via django-oauth-toolkit -packaging==23.1 - # via build phonenumberslite==8.13.11 # via -r requirements.in -pip-tools==6.13.0 - # via -r requirements.in +prompt-toolkit==3.0.48 + # via click-repl proto-plus==1.22.3 # via google-cloud-firestore protobuf==4.24.4 @@ -139,12 +152,14 @@ pyjwt[crypto]==2.6.0 # twilio pyparsing==3.1.1 # via httplib2 -pyproject-hooks==1.0.0 - # via build +python-dateutil==2.9.0.post0 + # via celery pytz==2022.7.1 # via # djangorestframework # twilio +redis==5.1.1 + # via -r requirements.in requests==2.28.2 # via # cachecontrol @@ -154,26 +169,36 @@ requests==2.28.2 # twilio rsa==4.9 # via google-auth +six==1.16.0 + # via python-dateutil sqlparse==0.4.3 # via django +swapper==1.4.0 + # via fcm-django twilio==7.16.5 # via -r requirements.in +tzdata==2024.2 + # via + # celery + # kombu uritemplate==4.1.1 # via google-api-python-client urllib3==1.26.15 # via requests -wheel==0.40.0 - # via pip-tools +vine==5.1.0 + # via + # amqp + # celery + # kombu +wcwidth==0.2.13 + # via prompt-toolkit wrapt==1.15.0 # via deprecated zxcvbn==4.4.28 # via -r requirements.in # The following packages are considered to be unsafe in a requirements file: -pip==23.1.2 - # via pip-tools setuptools==67.8.0 # via # django-axes # gunicorn - # pip-tools diff --git a/templates/connectid/deeplink.html b/templates/connectid/deeplink.html new file mode 100644 index 0000000..c09cb81 --- /dev/null +++ b/templates/connectid/deeplink.html @@ -0,0 +1,36 @@ + + + + + + Redirecting to CommCare + + + +
+ You are being redirected to install CommCare app, please install CommCare and re-open this link +
+ + Click here if you are not redirected automatically + + + diff --git a/users/admin.py b/users/admin.py index f45c614..7d4104f 100644 --- a/users/admin.py +++ b/users/admin.py @@ -20,6 +20,7 @@ class ConnectUserAdmin(UserAdmin): }, ), (_("Important dates"), {"fields": ("last_login", "date_joined")}), + (_("Extras"), {"fields": ("deactivation_token", "recovery_phone", "recovery_phone_validated")}), ) list_display = ("username", "phone_number", "name", "is_staff") search_fields = ("username", "name", "phone_number") diff --git a/users/factories.py b/users/factories.py index 4d78b7a..97175b0 100644 --- a/users/factories.py +++ b/users/factories.py @@ -2,7 +2,7 @@ from factory.django import DjangoModelFactory from fcm_django.models import FCMDevice -from users.models import ConnectUser +from users.models import ConnectUser, Credential class UserFactory(DjangoModelFactory): @@ -22,3 +22,12 @@ class Meta: registration_id = factory.Faker('uuid4') type = 'android' active = True + + +class CredentialFactory(DjangoModelFactory): + class Meta: + model = Credential + + name = factory.Faker('name') + slug = factory.Faker('slug') + organization_slug = factory.Faker('slug') diff --git a/users/migrations/0009_connectuser_deactivation_token_and_more.py b/users/migrations/0009_connectuser_deactivation_token_and_more.py new file mode 100644 index 0000000..ecaee44 --- /dev/null +++ b/users/migrations/0009_connectuser_deactivation_token_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.7 on 2024-07-24 13:27 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0008_credential_usercredential"), + ] + + operations = [ + migrations.AddField( + model_name="connectuser", + name="deactivation_token", + field=models.CharField(blank=True, max_length=25, null=True), + ), + migrations.AddField( + model_name="connectuser", + name="deactivation_token_valid_until", + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/users/migrations/0010_alter_connectuser_options_and_more.py b/users/migrations/0010_alter_connectuser_options_and_more.py new file mode 100644 index 0000000..721b422 --- /dev/null +++ b/users/migrations/0010_alter_connectuser_options_and_more.py @@ -0,0 +1,25 @@ +# Generated by Django 4.1.7 on 2024-08-08 09:05 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0009_connectuser_deactivation_token_and_more"), + ] + + operations = [ + migrations.AlterModelOptions( + name="connectuser", + options={}, + ), + migrations.AddConstraint( + model_name="connectuser", + constraint=models.UniqueConstraint( + condition=models.Q(("is_active", True)), + fields=("phone_number",), + name="phone_number_active_user", + ), + ), + ] diff --git a/users/migrations/0011_alter_connectuser_phone_number.py b/users/migrations/0011_alter_connectuser_phone_number.py new file mode 100644 index 0000000..8bd6805 --- /dev/null +++ b/users/migrations/0011_alter_connectuser_phone_number.py @@ -0,0 +1,21 @@ +# Generated by Django 4.1.7 on 2024-10-10 10:53 + +from django.db import migrations +import phonenumber_field.modelfields + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0010_alter_connectuser_options_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="connectuser", + name="phone_number", + field=phonenumber_field.modelfields.PhoneNumberField( + max_length=128, region=None + ), + ), + ] diff --git a/users/models.py b/users/models.py index f7bc673..37cf19f 100644 --- a/users/models.py +++ b/users/models.py @@ -1,28 +1,31 @@ from datetime import timedelta import base64 +from datetime import timedelta import os from uuid import uuid4 +from django.conf import settings from django.contrib.auth.hashers import check_password, make_password from django.contrib.auth.models import AbstractUser from django.contrib.sites.models import Site from django.db import models +from django.http import HttpResponse from django.utils.timezone import now from django.urls import reverse +from django.utils.timezone import now from django_otp.models import SideChannelDevice +from django_otp.util import random_hex from phonenumber_field.modelfields import PhoneNumberField from utils import get_sms_sender, send_sms from .const import TEST_NUMBER_PREFIX -# Create your models here. - class ConnectUser(AbstractUser): - phone_number = PhoneNumberField(unique=True) + phone_number = PhoneNumberField() phone_validated = models.BooleanField(default=False) recovery_phone = PhoneNumberField(blank=True) recovery_phone_validated = models.BooleanField(default=False) @@ -33,6 +36,8 @@ class ConnectUser(AbstractUser): # store a hashed value rather than setting it directly recovery_pin = models.CharField(null=True, blank=True, max_length=128) recovery_phone_validation_deadline = models.DateField(blank=True, null=True) + deactivation_token = models.CharField(max_length=25, blank=True, null=True) + deactivation_token_valid_until = models.DateTimeField(blank=True, null=True) # removed from base class first_name = None @@ -47,6 +52,28 @@ def set_recovery_pin(self, pin): def check_recovery_pin(self, pin): return check_password(pin, self.recovery_pin) + def initiate_deactivation(self): + self.deactivation_token = random_hex(7) + self.deactivation_token_valid_until = now() + timedelta(seconds=600) + self.save() + message = ( + f"Your account deactivation request is pending. Please enter this token {self.deactivation_token} to confirm account deactivation." + f"Warning: This action is irreversible. If you didn't request deactivation, please ignore this message. \n\n {settings.APP_HASH}" + ) + if not self.phone_number.raw_input.startswith(TEST_NUMBER_PREFIX): + sender = get_sms_sender(self.phone_number.country_code) + send_sms(self.phone_number.as_e164, message, sender) + return message + + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["phone_number"], + condition=models.Q(is_active=True), + name="phone_number_active_user", + ) + ] + class UserKey(models.Model): user = models.ForeignKey(ConnectUser, on_delete=models.CASCADE) @@ -76,7 +103,7 @@ def generate_challenge(self): if self.valid_until - now() <= timedelta(minutes=5): self.otp_last_sent = None self.generate_token(valid_secs=600) - message = f"Your verification token from commcare connect is {self.token}" + message = f"Your verification token from commcare connect is {self.token} \n\n {settings.APP_HASH}" # send the OTP if last sent message is not within the last 2 minutes if self.otp_last_sent is None or ( self.otp_last_sent and now() - self.otp_last_sent >= timedelta(minutes=2) @@ -86,8 +113,18 @@ def generate_challenge(self): send_sms(self.phone_number.as_e164, message, sender) self.otp_last_sent = now() self.save() + return message + @classmethod + def send_otp_httpresponse(cls, phone_number, user): + # create otp device for user + # send otp code via twilio + otp_device, _ = cls.objects.get_or_create(phone_number=phone_number, user=user) + otp_device.save() + otp_device.generate_challenge() + return HttpResponse() + class Meta: constraints = [ models.UniqueConstraint( diff --git a/users/oauth.py b/users/oauth.py index abbe68c..28ce80c 100644 --- a/users/oauth.py +++ b/users/oauth.py @@ -1,10 +1,16 @@ from oauth2_provider.oauth2_validators import OAuth2Validator + class ConnectOAuth2Validator(OAuth2Validator): + oidc_claim_scope = OAuth2Validator.oidc_claim_scope + oidc_claim_scope.update( + {"is_active": "openid", "phone": "openid", "name": "openid"} + ) def get_additional_claims(self, request): claims = {} claims["sub"] = request.user.username claims["name"] = request.user.name - claims["phone"] = request.user.phone_number + claims["phone"] = request.user.phone_number.as_e164 + claims["is_active"] = request.user.is_active return claims diff --git a/users/tests.py b/users/tests.py index e1088dc..1565d38 100644 --- a/users/tests.py +++ b/users/tests.py @@ -2,9 +2,11 @@ from re import A from django.utils.timezone import now import pytest +from django.urls import reverse from fcm_django.models import FCMDevice from unittest import mock +from users.factories import CredentialFactory from users.fcm_utils import create_update_device from users.models import ConnectUser, PhoneDevice @@ -87,6 +89,7 @@ def test_otp_generation(user): assert phone_device.token is not None assert phone_device.otp_last_sent is not None + def test_otp_generation_after_two_minutes(user): with mock.patch("users.models.send_sms") as send_sms: phone_device, _ = PhoneDevice.objects.get_or_create(phone_number=user.phone_number, user=user) @@ -129,3 +132,28 @@ def test_otp_generation_after_five_minutes(user): assert phone_device.token is not None assert phone_device.token != token assert send_sms.call_count == 3 + + +@pytest.mark.django_db +class TestFetchCredentials: + + def setup_method(self): + self.url = "/users/fetch_credentials" + CredentialFactory.create_batch(3, organization_slug="test_slug") + CredentialFactory.create_batch(10) + + def assert_statements(self, response, expected_count): + assert response.status_code == 200 + response_data = response.json() + assert "credentials" in response_data + assert len(response_data["credentials"]) == expected_count + for credential in response_data["credentials"]: + assert set(credential.keys()) == {"name", "slug"} + + def test_fetch_credential_with_org_slug(self, authed_client): + response = authed_client.get(self.url + "?org_slug=test_slug") + self.assert_statements(response, expected_count=3) + + def test_fetch_credential_without_org_slug(self, authed_client): + response = authed_client.get(self.url) + self.assert_statements(response, expected_count=13) \ No newline at end of file diff --git a/users/urls.py b/users/urls.py index c6a3eee..1f3316f 100644 --- a/users/urls.py +++ b/users/urls.py @@ -1,6 +1,7 @@ from django.urls import path from . import views +from payments import views as payment_views urlpatterns = [ path('', views.test, name='test'), @@ -30,4 +31,12 @@ path('accept_credential/', views.accept_credential, name='accept_credential'), path('fetch_credentials', views.FetchCredentials.as_view(), name='fetch_credentials'), path('fetch_db_key', views.fetch_db_key, name='fetch_db_key'), + path('recover/initiate_deactivation', views.initiate_deactivation, name='initiate_deactivation'), + path('recover/confirm_deactivation', views.confirm_deactivation, name='confirm_deactivation'), + path('profile/payment_phone_number', payment_views.update_payment_profile_phone, name='update_payment_profile_phone'), + path('profile/confirm_payment_otp', payment_views.confirm_payment_profile_otp, name='confirm_payment_profile_otp'), + path('fetch_payment_phone_numbers', payment_views.FetchPhoneNumbers.as_view(), name='fetch_payment_phone_numbers'), + path('validate_payment_phone_numbers', payment_views.ValidatePhoneNumbers.as_view(), name='validate_payment_phone_numbers'), + path('forward_hq_invite', views.ForwardHQInvite.as_view(), name='forward_hq_invite'), + path('confirm_hq_invite', views.ConfirmHQInviteCallback.as_view(), name='confirm_hq_invite'), ] diff --git a/users/views.py b/users/views.py index 193585a..6ab9e7c 100644 --- a/users/views.py +++ b/users/views.py @@ -1,21 +1,32 @@ +import requests from datetime import timedelta from secrets import token_hex +from urllib.parse import urlparse, urlencode +from django.conf import settings from django.contrib.auth.hashers import check_password from django.contrib.auth.password_validation import validate_password -from django.core.exceptions import ValidationError +from django.core.exceptions import ValidationError, ObjectDoesNotExist from django.http import HttpResponse, JsonResponse from django.utils.timezone import now from django.views import View +from oauth2_provider.models import AccessToken, RefreshToken from oauth2_provider.views.mixins import ClientProtectedResourceMixin from rest_framework.decorators import api_view, permission_classes from rest_framework.views import APIView -from utils import get_ip +from utils import get_ip, get_sms_sender, send_sms from utils.rest_framework import ClientProtectedResourceAuth from .const import TEST_NUMBER_PREFIX from .fcm_utils import create_update_device -from .models import ConnectUser, Credential, PhoneDevice, RecoveryStatus, UserCredential, UserKey +from .models import ( + ConnectUser, + Credential, + PhoneDevice, + RecoveryStatus, + UserCredential, + UserKey, +) # Create your views here. @@ -56,13 +67,8 @@ def test(request): @api_view(['POST']) def validate_phone(request): - # create otp device for user - # send otp code via twilio user = request.user - otp_device, _ = PhoneDevice.objects.get_or_create(phone_number=user.phone_number, user=user) - otp_device.save() - otp_device.generate_challenge() - return HttpResponse() + return PhoneDevice.send_otp_httpresponse(phone_number=user.phone_number, user=user) @api_view(['POST']) @@ -82,13 +88,8 @@ def confirm_otp(request): @api_view(['POST']) def validate_secondary_phone(request): - # create otp device for user - # send otp code via twilio user = request.user - otp_device, _ = PhoneDevice.objects.get_or_create(phone_number=user.recovery_phone, user=user) - otp_device.save() - otp_device.generate_challenge() - return HttpResponse() + return PhoneDevice.send_otp_httpresponse(phone_number=user.recovery_phone, user=user) @api_view(['POST']) @@ -96,7 +97,7 @@ def confirm_secondary_otp(request): # check otp code for user # mark phone as confirmed on user model user = request.user - device = PhoneDevice.objects.get(phone_number=user.recovery_phone, user=user) + device, _ = PhoneDevice.objects.get_or_create(phone_number=user.recovery_phone, user=user) data = request.data verified = device.verify_token(data.get('token')) if not verified: @@ -111,7 +112,7 @@ def confirm_secondary_otp(request): @permission_classes([]) def recover_account(request): data = request.data - user = ConnectUser.objects.get(phone_number=data['phone']) + user = ConnectUser.objects.get(phone_number=data["phone"], is_active=True) device = PhoneDevice.objects.get(phone_number=user.phone_number, user=user) device.generate_challenge() secret = token_hex() @@ -128,7 +129,7 @@ def confirm_recovery_otp(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -149,7 +150,7 @@ def recover_secondary_phone(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -169,7 +170,7 @@ def confirm_secondary_recovery_otp(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -182,7 +183,9 @@ def confirm_secondary_recovery_otp(request): status.step = RecoveryStatus.RecoverySteps.RESET_PASSWORD status.save() db_key = UserKey.get_or_create_key_for_user(user) - return JsonResponse({"name": user.name, "username": user.username, "db_key": db_key.key}) + user_data = {"name": user.name, "username": user.username, "db_key": db_key.key} + user_data.update(user_payment_profile(user)) + return JsonResponse(user_data) @api_view(['POST']) @@ -191,7 +194,7 @@ def confirm_password(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -201,8 +204,7 @@ def confirm_password(request): if not check_password(password, user.password): return HttpResponse(status=401) status.delete() - db_key = UserKey.get_or_create_key_for_user(user) - return JsonResponse({"name": user.name, "username": user.username, "secondary_phone_validate_by": user.recovery_phone_validation_deadline, "db_key": db_key.key}) + return JsonResponse(user_data(user)) @api_view(['POST']) @@ -211,7 +213,7 @@ def reset_password(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -235,7 +237,7 @@ def phone_available(request): if not phone_number: return HttpResponse(status=400) try: - ConnectUser.objects.get(phone_number=phone_number) + ConnectUser.objects.get(phone_number=phone_number, is_active=True) except ConnectUser.DoesNotExist: return HttpResponse() else: @@ -307,13 +309,34 @@ def set_recovery_pin(request): return HttpResponse() +def user_data(user): + db_key = UserKey.get_or_create_key_for_user(user) + user_data = {"name": user.name, "username": user.username, "secondary_phone_validate_by": user.recovery_phone_validation_deadline, "db_key": db_key.key} + user_data.update(user_payment_profile(user)) + return user_data + + +def user_payment_profile(user): + try: + profile = user.payment_profile + return {"payment_profile": { + "phone_number": profile.phone_number.as_e164, + "owner_name": profile.owner_name, + "telecom_provider": profile.telecom_provider, + "is_verified": profile.is_verified, + "status": profile.status, + }} + except ObjectDoesNotExist: + return {"payment_profile": {}} + + @api_view(['POST']) @permission_classes([]) def confirm_recovery_pin(request): data = request.data phone_number = data["phone"] secret_key = data["secret_key"] - user = ConnectUser.objects.get(phone_number=phone_number) + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) status = RecoveryStatus.objects.get(user=user) if status.secret_key != secret_key: return HttpResponse(status=401) @@ -324,8 +347,7 @@ def confirm_recovery_pin(request): return JsonResponse({"error": "Recovery PIN is incorrect"}, status=401) status.step = RecoveryStatus.RecoverySteps.RESET_PASSWORD status.save() - db_key = UserKey.get_or_create_key_for_user(user) - return JsonResponse({"name": user.name, "username": user.username, "secondary_phone_validate_by": user.recovery_phone_validation_deadline, "db_key": db_key.key}) + return JsonResponse(user_data(user)) @api_view(['GET']) @@ -351,7 +373,11 @@ class FetchUsers(ClientProtectedResourceMixin, View): def get(self, request, *args, **kwargs): numbers = request.GET.getlist('phone_numbers') results = {} - found_users = list(ConnectUser.objects.filter(phone_number__in=numbers).values('username', 'phone_number', 'name')) + found_users = list( + ConnectUser.objects.filter(phone_number__in=numbers, is_active=True).values( + "username", "phone_number", "name" + ) + ) results["found_users"] = found_users return JsonResponse(results) @@ -377,7 +403,9 @@ def get(self, request, *args, **kwargs): if credential is not None: query = query.filter(credential__slug=credential) if country is not None: - query = query.filter(user__phone_number__startswith=country) + query = query.filter( + user__phone_number__startswith=country, user__is_active=True + ) users = query.select_related("user") user_list = [{"username": u.user.username, "phone_number": u.user.phone_number.as_e164, "name": u.user.name} for u in users] result = {"found_users": user_list} @@ -395,12 +423,81 @@ def post(self, request, *args, **kwargs): credential_name = request.data["credential"] slug = f"{credential_name.lower().replace(' ', '_')}_{org_slug}" credential, _ = Credential.objects.get_or_create(name=credential_name, organization_slug=org_slug, defaults={"slug": slug}) - users = ConnectUser.objects.filter(phone_number__in=phone_numbers) + users = ConnectUser.objects.filter( + phone_number__in=phone_numbers, is_active=True + ) for user in users: UserCredential.add_credential(user, credential, request) return HttpResponse() +class ForwardHQInvite(APIView): + """ + This view gets called by CommCareHQ to invite + a ConnectID User. It takes invite metadata + and fowards it as a deeplink SMS to mobile + """ + authentication_classes = [ClientProtectedResourceAuth] + + def post(self, request, *args, **kwargs): + phone_number = request.data["phone_number"] + callback_url = request.data["callback_url"] + if not is_trusted_hqinvite_url(callback_url): + return JsonResponse({"error": "Unauthorized callback URL"}, status=400) + try: + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) + except ConnectUser.DoesNotExist: + # We don't want to make this a user lookup service + # So fake a success message + return JsonResponse({"success": True}) + + query_string = urlencode({ + "hq_username": request.data["username"], + "hq_domain": request.data["user_domain"], + "connect_username": user.username, + "invite_code": request.data["invite_code"], + "callback_url": callback_url, + }) + deeplink = f"https://connectid.dimagi.com/hq_invite/?{query_string}" + + message = f""" + You are invited to join a CommCare project ({request.data["user_domain"]}) + Please click on {deeplink} to join using your ConnectID + account. + Once you confirm, you will be able to login using your + ConnectID account. Your username is ({request.data["username"]}) + Thanks. + -The ConnectID Team. + """ + sender = get_sms_sender(user.phone_number.country_code) + send_sms(user.phone_number.as_e164, message, sender) + return JsonResponse({"success": True}) + + +def is_trusted_hqinvite_url(url): + parsed_url = urlparse(url) + return parsed_url.netloc in settings.TRUSTED_COMMCAREHQ_HOSTS + + +class ConfirmHQInviteCallback(APIView): + + def post(self, request, *args, **kwargs): + invite_code = request.data["invite_code"] + user_token = request.data["user_token"] + callback_url = request.data["callback_url"] + + # Validate callback_url + if not is_trusted_hqinvite_url(callback_url): + return JsonResponse({"error": "Unauthorized callback URL"}, status=400) + + try: + response = requests.post(callback_url, data={"invite_code": invite_code, "token": user_token}) + response.raise_for_status() + except requests.RequestException as e: + return JsonResponse({"error": "Failed to reach callback URL"}, status=500) + return JsonResponse({"success": True}) + + @api_view(['GET']) @permission_classes([]) def accept_credential(request, invite_id): @@ -419,7 +516,55 @@ def accept_credential(request, invite_id): class FetchCredentials(ClientProtectedResourceMixin, View): required_scopes = ['user_fetch'] - def get(self, request, *args, **kwargs): - credentials = Credential.objects.all().values('name', 'slug') - results = {"credentials": list(credentials)} + def get(self, request): + org_slug = request.GET.get('org_slug', None) + queryset = Credential.objects.all() + if org_slug: + queryset = queryset.filter(organization_slug=org_slug) + + credentials = queryset.values('name', 'slug') + results = {"credentials": list(credentials)} return JsonResponse(results) + + +@api_view(["POST"]) +@permission_classes([]) +def initiate_deactivation(request): + data = request.data + phone_number = data["phone_number"] + secret_key = data["secret_key"] + try: + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) + except ConnectUser.DoesNotExist: + return JsonResponse({"success": False}) + status = RecoveryStatus.objects.get(user=user) + if status.secret_key != secret_key: + return HttpResponse(status=401) + user.initiate_deactivation() + return JsonResponse({"success": True}) + + +@api_view(["POST"]) +@permission_classes([]) +def confirm_deactivation(request): + data = request.data + phone_number = data["phone_number"] + secret_key = data["secret_key"] + deactivation_token = data["token"] + try: + user = ConnectUser.objects.get(phone_number=phone_number, is_active=True) + except ConnectUser.DoesNotExist: + return JsonResponse({"success": False}) + status = RecoveryStatus.objects.get(user=user) + if status.secret_key != secret_key: + return HttpResponse(status=401) + if user.deactivation_token == deactivation_token: + user.is_active = False + user.save() + tokens = list(AccessToken.objects.filter(user=user)) + list( + RefreshToken.objects.filter(user=user) + ) + for token in tokens: + token.revoke() + return JsonResponse({"success": True}) + return JsonResponse({"success": False}) diff --git a/utils/__init__.py b/utils/__init__.py index 1c47d04..965a098 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -25,7 +25,8 @@ def get_ip(request): def get_sms_sender(country_code): SMS_SENDERS = { - "265": "ConnectID" + "265": "ConnectID", + "258": "ConnectID" } return SMS_SENDERS.get(str(country_code)) diff --git a/utils/rest_framework.py b/utils/rest_framework.py index b688055..d49746c 100644 --- a/utils/rest_framework.py +++ b/utils/rest_framework.py @@ -1,6 +1,8 @@ from django.contrib.auth.models import AnonymousUser from oauth2_provider.views.mixins import OAuthLibMixin -from rest_framework.authentication import BaseAuthentication +from rest_framework.authentication import BaseAuthentication, BasicAuthentication + +from messaging.models import MessageServer class ClientProtectedResourceAuth(OAuthLibMixin, BaseAuthentication): @@ -21,3 +23,26 @@ def is_authenticated(self): def __str__(self): return "OauthClientUser" + + +class MessagingServerAuth(BasicAuthentication): + """Authenticate request using Client credentials (as in the OAuth2 spec). + """ + + def authenticate_credentials(self, userid, password, request=None): + try: + server = MessageServer.objects.get(server_id=userid) + except MessageServer.DoesNotExist: + return None + valid = (password == server.secret_key) + if valid: + return MessagingServerUser(), None + + +class MessagingServerUser(AnonymousUser): + """Fake user used for requests authenticated via Client credentials""" + def is_authenticated(self): + return True + + def __str__(self): + return "MessagingServerUser" diff --git a/utils/twilio.py b/utils/twilio.py new file mode 100644 index 0000000..3591143 --- /dev/null +++ b/utils/twilio.py @@ -0,0 +1,17 @@ +import logging + +from twilio.rest import Client +from django.conf import settings + + +logger = logging.getLogger(__name__) + + +def lookup_telecom_provider(phone_number): + client = Client(settings.TWILIO_ACCOUNT_SID, settings.TWILIO_AUTH_TOKEN) + try: + phone_info = client.lookups.v1.phone_numbers(phone_number).fetch(type="carrier") + return phone_info.carrier.get("name") + except Exception as e: + logger.exception("Error occurred during Twilio call for phone number %s: %s", phone_number, str(e)) + return None