diff --git a/api/custom_auth/oauth/serializers.py b/api/custom_auth/oauth/serializers.py index 6a1e80ab90af..026db5fc93c5 100644 --- a/api/custom_auth/oauth/serializers.py +++ b/api/custom_auth/oauth/serializers.py @@ -13,6 +13,7 @@ from users.models import SignUpType from ..constants import USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE +from ..serializers import InviteLinkValidationMixin from .github import GithubUser from .google import get_user_info @@ -20,7 +21,7 @@ UserModel = get_user_model() -class OAuthLoginSerializer(serializers.Serializer): +class OAuthLoginSerializer(InviteLinkValidationMixin, serializers.Serializer): access_token = serializers.CharField( required=True, help_text="Code or access token returned from the FE interaction with the third party login provider.", diff --git a/api/custom_auth/serializers.py b/api/custom_auth/serializers.py index 55bb43e595ae..9b56877e63a2 100644 --- a/api/custom_auth/serializers.py +++ b/api/custom_auth/serializers.py @@ -1,3 +1,5 @@ +from typing import Any + from django.conf import settings from djoser.serializers import UserCreateSerializer from rest_framework import serializers @@ -5,7 +7,7 @@ from rest_framework.exceptions import PermissionDenied from rest_framework.validators import UniqueValidator -from organisations.invites.models import Invite +from organisations.invites.models import Invite, InviteLink from users.auth_type import AuthType from users.constants import DEFAULT_DELETE_ORPHAN_ORGANISATIONS_VALUE from users.models import FFAdminUser, SignUpType @@ -23,7 +25,35 @@ class Meta: fields = ("key",) -class CustomUserCreateSerializer(UserCreateSerializer): +class InviteLinkValidationMixin: + invite_hash = serializers.CharField(required=False, write_only=True) + + def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: + attrs = super().validate(attrs) + + if not settings.ALLOW_REGISTRATION_WITHOUT_INVITE: + self._validate_registration_invite(attrs) + + return attrs + + def _validate_registration_invite(self, attrs: dict[str, Any]) -> None: + valid = False + + match attrs.get("sign_up_type"): + case SignUpType.INVITE_LINK.value: + valid = InviteLink.objects.filter( + hash=self.initial_data.get("invite_hash") + ).exists() + case SignUpType.INVITE_EMAIL.value: + valid = Invite.objects.filter( + email__iexact=attrs.get("email", "").lower() + ).exists() + + if not valid: + raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) + + +class CustomUserCreateSerializer(UserCreateSerializer, InviteLinkValidationMixin): key = serializers.SerializerMethodField() class Meta(UserCreateSerializer.Meta): @@ -66,16 +96,6 @@ def get_key(instance): token, _ = Token.objects.get_or_create(user=instance) return token.key - def save(self, **kwargs): - if not ( - settings.ALLOW_REGISTRATION_WITHOUT_INVITE - or self.validated_data.get("sign_up_type") == SignUpType.INVITE_LINK.value - or Invite.objects.filter(email=self.validated_data.get("email")) - ): - raise PermissionDenied(USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE) - - return super(CustomUserCreateSerializer, self).save(**kwargs) - class CustomUserDelete(serializers.Serializer): current_password = serializers.CharField( diff --git a/api/tests/unit/custom_auth/conftest.py b/api/tests/unit/custom_auth/conftest.py new file mode 100644 index 000000000000..17d5f760c4c1 --- /dev/null +++ b/api/tests/unit/custom_auth/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from organisations.invites.models import InviteLink +from organisations.models import Organisation + + +@pytest.fixture() +def invite_link(organisation: Organisation) -> InviteLink: + return InviteLink.objects.create(organisation=organisation) diff --git a/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py b/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py index 00f099e1ace6..63731e0f7df3 100644 --- a/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py +++ b/api/tests/unit/custom_auth/test_unit_custom_auth_serializer.py @@ -1,7 +1,17 @@ +import pytest from django.test import RequestFactory from pytest_django.fixtures import SettingsWrapper - -from custom_auth.serializers import CustomUserCreateSerializer +from rest_framework.exceptions import PermissionDenied +from rest_framework.serializers import ModelSerializer + +from custom_auth.constants import ( + USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE, +) +from custom_auth.serializers import ( + CustomUserCreateSerializer, + InviteLinkValidationMixin, +) +from organisations.invites.models import InviteLink from users.models import FFAdminUser, SignUpType user_dict = { @@ -70,6 +80,7 @@ def test_CustomUserCreateSerializer_calls_is_authentication_method_valid_correct def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invite_link( + invite_link: InviteLink, db: None, settings: SettingsWrapper, rf: RequestFactory, @@ -80,6 +91,7 @@ def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invit data = { **user_dict, "sign_up_type": SignUpType.INVITE_LINK.value, + "invite_hash": invite_link.hash, } serializer = CustomUserCreateSerializer( @@ -92,3 +104,52 @@ def test_CustomUserCreateSerializer_allows_registration_if_sign_up_type_is_invit # Then assert user + + +def test_invite_link_validation_mixin_validate_fails_if_invite_link_hash_not_provided( + settings: SettingsWrapper, + db: None, +) -> None: + # Given + settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False + + class TestSerializer(InviteLinkValidationMixin, ModelSerializer): + class Meta: + model = FFAdminUser + fields = ("sign_up_type",) + + serializer = TestSerializer(data={"sign_up_type": SignUpType.INVITE_LINK.value}) + + # When + with pytest.raises(PermissionDenied) as exc_info: + serializer.is_valid(raise_exception=True) + + # Then + assert exc_info.value.detail == USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE + + +def test_invite_link_validation_mixin_validate_fails_if_invite_link_hash_not_valid( + invite_link: InviteLink, + settings: SettingsWrapper, +) -> None: + # Given + settings.ALLOW_REGISTRATION_WITHOUT_INVITE = False + + class TestSerializer(InviteLinkValidationMixin, ModelSerializer): + class Meta: + model = FFAdminUser + fields = ("sign_up_type",) + + serializer = TestSerializer( + data={ + "sign_up_type": SignUpType.INVITE_LINK.value, + "invite_hash": "invalid-hash", + } + ) + + # When + with pytest.raises(PermissionDenied) as exc_info: + serializer.is_valid(raise_exception=True) + + # Then + assert exc_info.value.detail == USER_REGISTRATION_WITHOUT_INVITE_ERROR_MESSAGE