diff --git a/authentication/helpers.py b/authentication/helpers.py index 4a5c7af8..ae3b5980 100644 --- a/authentication/helpers.py +++ b/authentication/helpers.py @@ -1,20 +1,19 @@ -import requests +import base64 import json import time -import requests -import base64 + import ed25519 -from eth_account.messages import encode_defunct -from web3 import Web3 -from eth_account import Account +import requests +from django.apps import apps from django.core.exceptions import ValidationError from django.core.validators import RegexValidator -from django.apps import apps +from eth_account import Account +from eth_account.messages import encode_defunct -def verify_signature_eth_scheme(address, signature): +def verify_signature_eth_scheme(address, message, signature): try: - digest = encode_defunct(text=address) + digest = encode_defunct(text=message) signer = Account.recover_message(digest, signature=signature) if signer == address: return True @@ -30,10 +29,12 @@ def __init__(self, app) -> None: self.app = app def create_verification_link(self, contextId): - return f"https://app.brightid.org/link-verification/http:%2F%2Fnode.brightid.org/{self.app}/{str(contextId).lower()}" + return f"https://app.brightid.org/link-verification/http:%2F\ + %2Fnode.brightid.org/{self.app}/{str(contextId).lower()}" def create_qr_content(self, contextId): - return f"brightid://link-verification/http:%2f%2fnode.brightid.org/{self.app}/{str(contextId).lower()}" # TODO + return f"brightid://link-verification/http:%2f%2fnode.bright\ + id.org/{self.app}/{str(contextId).lower()}" # TODO def get_verification_status(self, context_id, verification): if verification == "BrightID" or verification == "Meet": @@ -44,7 +45,8 @@ def get_verification_status(self, context_id, verification): raise ValueError("Invalid verification type") # get list of context ids from brightId - endpoint = f"https://aura-node.brightid.org/brightid/v5/verifications/{self.app}/{context_id}?verification={verification_type}" + endpoint = f"https://aura-node.brightid.org/brightid/v5/veri\ + fications/{self.app}/{context_id}?verification={verification_type}" # print("endpoint: ", endpoint) bright_response = requests.get(endpoint) # decode response @@ -109,7 +111,8 @@ def is_username_valid_and_available(username): # Check if the string matches the required format validator = RegexValidator( regex=r"^(?=.*[a-zA-Z])([\w.@+-]{3,150})$", - message="Username must be more than 2 characters, contain at least one letter, and only contain letters, digits and @/./+/-/_.", + message="Username must be more than 2 characters, contain at \ + least one letter, and only contain letters, digits and @/./+/-/_.", ) try: @@ -117,7 +120,8 @@ def is_username_valid_and_available(username): except ValidationError: return ( False, - "Username must be more than 2 characters, contain at least one letter, and only contain letters, digits and @/./+/-/_.", + "Username must be more than 2 characters, contain at least one \ + letter, and only contain letters, digits and @/./+/-/_.", "validation_error", ) diff --git a/authentication/migrations/0018_alter_wallet_unique_together_wallet_primary.py b/authentication/migrations/0018_alter_wallet_unique_together_wallet_primary.py new file mode 100644 index 00000000..19b2de1e --- /dev/null +++ b/authentication/migrations/0018_alter_wallet_unique_together_wallet_primary.py @@ -0,0 +1,22 @@ +# Generated by Django 4.0.4 on 2023-09-27 11:03 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0017_alter_userprofile_username'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='wallet', + unique_together={('wallet_type', 'address')}, + ), + migrations.AddField( + model_name='wallet', + name='primary', + field=models.BooleanField(default=False), + ), + ] diff --git a/authentication/migrations/0019_alter_wallet_unique_together.py b/authentication/migrations/0019_alter_wallet_unique_together.py new file mode 100644 index 00000000..599d3da1 --- /dev/null +++ b/authentication/migrations/0019_alter_wallet_unique_together.py @@ -0,0 +1,17 @@ +# Generated by Django 4.0.4 on 2023-09-28 03:27 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0018_alter_wallet_unique_together_wallet_primary'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='wallet', + unique_together=set(), + ), + ] diff --git a/authentication/migrations/0021_merge_20231126_1858.py b/authentication/migrations/0021_merge_20231126_1858.py new file mode 100644 index 00000000..61fb510c --- /dev/null +++ b/authentication/migrations/0021_merge_20231126_1858.py @@ -0,0 +1,14 @@ +# Generated by Django 4.0.4 on 2023-11-26 18:58 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0019_alter_wallet_unique_together'), + ('authentication', '0020_remove_userprofile_is_new_by_wallet'), + ] + + operations = [ + ] diff --git a/authentication/migrations/0022_remove_wallet_primary.py b/authentication/migrations/0022_remove_wallet_primary.py new file mode 100644 index 00000000..b42c8c8a --- /dev/null +++ b/authentication/migrations/0022_remove_wallet_primary.py @@ -0,0 +1,17 @@ +# Generated by Django 4.0.4 on 2023-11-30 12:02 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0021_merge_20231126_1858'), + ] + + operations = [ + migrations.RemoveField( + model_name='wallet', + name='primary', + ), + ] diff --git a/authentication/migrations/0023_wallet_created_at.py b/authentication/migrations/0023_wallet_created_at.py new file mode 100644 index 00000000..59f5f9e8 --- /dev/null +++ b/authentication/migrations/0023_wallet_created_at.py @@ -0,0 +1,18 @@ +# Generated by Django 4.0.4 on 2023-11-30 15:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0022_remove_wallet_primary'), + ] + + operations = [ + migrations.AddField( + model_name='wallet', + name='created_at', + field=models.DateTimeField(auto_now_add=True, null=True), + ), + ] diff --git a/authentication/models.py b/authentication/models.py index 5bcba020..48958590 100644 --- a/authentication/models.py +++ b/authentication/models.py @@ -19,6 +19,14 @@ def get_or_create(self, first_context_id): return _profile +# class WalletManager(models.Manager): +# def get_primary_wallet(self): +# try: +# self.get(primary=True, wallet_type="EVM") +# except Wallet.DoesNotExist: +# return None + + class UserProfile(models.Model): user = models.OneToOneField(User, on_delete=models.PROTECT, related_name="profile") initial_context_id = models.CharField(max_length=512, unique=True) @@ -68,6 +76,9 @@ def is_aura_verified(self): return is_verified + def owns_wallet(self, wallet_address): + return self.wallets.filter(address=wallet_address).exists() + def save(self, *args, **kwargs): super().save(*args, **kwargs) @@ -94,12 +105,11 @@ class Wallet(models.Model): UserProfile, on_delete=models.PROTECT, related_name="wallets" ) address = models.CharField(max_length=512, unique=True) + # primary = models.BooleanField(default=False, null=False, blank=False) + created_at = models.DateTimeField(auto_now_add=True, null=True, blank=True) - class Meta: - unique_together = (("wallet_type", "user_profile"),) + # objects = WalletManager() def __str__(self): - return ( - f"{self.wallet_type} Wallet for profile with contextId " - f"{self.user_profile.initial_context_id}" - ) + return f"{self.wallet_type} Wallet for profile with contextId \ + {self.user_profile.initial_context_id}" diff --git a/authentication/permissions.py b/authentication/permissions.py index c871a94c..15862660 100644 --- a/authentication/permissions.py +++ b/authentication/permissions.py @@ -17,3 +17,12 @@ class IsAuraVerified(BasePermission): def has_permission(self, request, view): return bool(request.user.profile.is_aura_verified) + + +class IsOwner(BasePermission): + """ + Just owner has can access + """ + + def has_object_permission(self, request, view, obj): + return obj.user_profile == request.user.profile diff --git a/authentication/serializers.py b/authentication/serializers.py index eac73fc0..16c4ffaf 100644 --- a/authentication/serializers.py +++ b/authentication/serializers.py @@ -1,6 +1,7 @@ from rest_framework import serializers from rest_framework.authtoken.models import Token +from authentication.helpers import verify_signature_eth_scheme from authentication.models import UserProfile, Wallet @@ -22,36 +23,34 @@ def update(self, instance, validated_data): pass -# class SetUsernameSerializer(serializers.Serializer): -# username = UsernameRequestSerializer.username +class WalletSerializer(serializers.ModelSerializer): + signature = serializers.CharField(required=True, max_length=150, write_only=True) + message = serializers.CharField(required=True, max_length=150, write_only=True) -# def save(self, user_profile): -# username = self.validated_data.get("username") + class Meta: + model = Wallet + fields = ["pk", "wallet_type", "address", "signature", "message"] -# try: -# user_profile.username = username -# user_profile.save() -# return {"message": "Username Set"} + def is_valid(self, raise_exception=False): + super_is_validated = super().is_valid(raise_exception) -# except IntegrityError: -# raise ValidationError( -# {"message": "This username already exists. Try another one."} -# ) + address = self.validated_data.get("address") + message = self.validated_data.get("message") + signature = self.validated_data.get("signature") + signature_is_valid = verify_signature_eth_scheme(address, message, signature) -class WalletSerializer(serializers.ModelSerializer): - class Meta: - model = Wallet - fields = [ - "pk", - "wallet_type", - "address", - ] + if not signature_is_valid and raise_exception: + raise serializers.ValidationError("Signature is not valid") + + self.validated_data.pop("signature", None) + self.validated_data.pop("message", None) + + return super_is_validated and signature_is_valid class ProfileSerializer(serializers.ModelSerializer): wallets = WalletSerializer(many=True, read_only=True) - # total_round_claims_remaining = serializers.SerializerMethodField() token = serializers.SerializerMethodField() class Meta: @@ -63,7 +62,6 @@ class Meta: "initial_context_id", "is_meet_verified", "is_aura_verified", - # "total_round_claims_remaining", "wallets", ] @@ -71,11 +69,6 @@ def get_token(self, instance): token, bol = Token.objects.get_or_create(user=instance.user) return token.key - # def get_total_round_claims_remaining(self, instance): - # gs = GlobalSettings.objects.first() - # if gs is not None: - # return gs.gastap_round_claim_limit - LimitedChainClaimManager.get_total_round_claims(instance) - class SimpleProfilerSerializer(serializers.ModelSerializer): wallets = WalletSerializer(many=True, read_only=True) diff --git a/authentication/tests.py b/authentication/tests.py index c42a414a..080e74b1 100644 --- a/authentication/tests.py +++ b/authentication/tests.py @@ -1,17 +1,25 @@ from unittest.mock import patch + from django.urls import reverse from django.utils import timezone -from django.urls import reverse -from django.contrib.auth.models import User -from rest_framework.test import APITestCase +from eth_account.messages import encode_defunct from rest_framework.authtoken.models import Token -from rest_framework.status import HTTP_403_FORBIDDEN, HTTP_409_CONFLICT, HTTP_200_OK -from authentication.models import UserProfile +from rest_framework.status import ( + HTTP_200_OK, + HTTP_201_CREATED, + HTTP_400_BAD_REQUEST, + HTTP_403_FORBIDDEN, + HTTP_409_CONFLICT, +) +from rest_framework.test import APITestCase +from web3 import Account + +from authentication.models import UserProfile, Wallet from faucet.models import ClaimReceipt -### get address as username and signed address as password and verify signature +# get address as username and signed address as password and verify signature -### retrieve address from brightID +# retrieve address from brightID address = "0x90F8bf6A479f320ead074411a4B0e7944Ea8c9C1" fund_manager = "0x5802f1035AbB8B191bc12Ce4668E3815e8B7Efa0" @@ -50,6 +58,13 @@ def create_verified_user() -> UserProfile: return user +def create_new_wallet(user_profile, _address, wallet_type) -> Wallet: + wallet, is_create = Wallet.objects.get_or_create( + user_profile=user_profile, address=_address, wallet_type=wallet_type + ) + return wallet + + class CheckUsernameTestCase(APITestCase): def setUp(self) -> None: self.endpoint = "AUTHENTICATION:check-username" @@ -120,7 +135,6 @@ def setUp(self) -> None: self._address = "0x3E5e9111Ae8eB78Fe1CC3bb8915d5D461F3Ef9A9" self.endpoint = reverse("AUTHENTICATION:login-user") - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (True, None), @@ -219,75 +233,143 @@ def test_become_sponsor(self): self.assertEqual(response.status_code, HTTP_200_OK) -class TestSetWalletAddress(APITestCase): +class TestListCreateWallet(APITestCase): def setUp(self) -> None: self.password = "test" self._address = "0x3E5e9111Ae8eB78Fe1CC3bb8915d5D461G3Ef9A9" - self.endpoint = reverse("AUTHENTICATION:set-wallet-user") + self.private_key_test1 = ( + "2534fa7456f3aaf0f72ece66434a7d380d08cee47d8a2db56c08a3048890b50f" + ) + self.public_key_test1 = "0xD8Be96705B9fb518eEb2758719831BAF6C6E5E05" + self.endpoint = reverse("AUTHENTICATION:wallets-user") self.user_profile = create_new_user() self.client.force_authenticate(user=self.user_profile.user) def test_invalid_arguments_provided_should_fail(self): response = self.client.post(self.endpoint) - self.assertEqual(response.status_code, HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST) response = self.client.post(self.endpoint, data={"address": False}) - self.assertEqual(response.status_code, HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST) response = self.client.post(self.endpoint, data={"wallet_type": False}) - self.assertEqual(response.status_code, HTTP_403_FORBIDDEN) + self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST) + + def test_create_wallet_address(self): + message = "test-message" + + hashed_message = encode_defunct(text=message) + account = Account.from_key(self.private_key_test1) + signed_message = account.sign_message(hashed_message) + signature = signed_message.signature.hex() - def test_set_same_address_for_multiple_users_should_fail(self): response = self.client.post( - self.endpoint, data={"address": self._address, "wallet_type": "EVM"} + self.endpoint, + data={ + "address": self.public_key_test1, + "wallet_type": "EVM", + "message": message, + "signature": signature, + }, ) - self.assertEqual(response.status_code, HTTP_200_OK) + self.assertEqual(response.status_code, HTTP_201_CREATED) + + def test_create_wallet_address_wrong_signature(self): + message = "test-message" response = self.client.post( - self.endpoint, data={"address": self._address, "wallet_type": "Solana"} + self.endpoint, + data={ + "address": self.public_key_test1, + "wallet_type": "EVM", + "message": message, + "signature": message, + }, ) - self.assertEqual(response.status_code, HTTP_403_FORBIDDEN) - def test_not_existing_wallet_then_create_and_set_address_for_that_is_ok(self): + self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST) + + def test_create_same_address_twice(self): + message = "test-message" + + hashed_message = encode_defunct(text=message) + account = Account.from_key(self.private_key_test1) + signed_message = account.sign_message(hashed_message) + signature = signed_message.signature.hex() + response = self.client.post( - self.endpoint, data={"address": self._address, "wallet_type": "EVM"} + self.endpoint, + data={ + "address": self.public_key_test1, + "wallet_type": "EVM", + "message": message, + "signature": signature, + }, ) - self.assertEqual(response.status_code, HTTP_200_OK) + self.assertEqual(response.status_code, HTTP_201_CREATED) + + response = self.client.post( + self.endpoint, + data={ + "address": self.public_key_test1, + "wallet_type": "EVM", + "message": message, + "signature": signature, + }, + ) + self.assertEqual(response.status_code, HTTP_400_BAD_REQUEST) + + def test_get_wallet_list(self): + message = "test-message" + + hashed_message = encode_defunct(text=message) + account = Account.from_key(self.private_key_test1) + signed_message = account.sign_message(hashed_message) + signature = signed_message.signature.hex() + + response = self.client.post( + self.endpoint, + data={ + "address": self.public_key_test1, + "wallet_type": "EVM", + "message": message, + "signature": signature, + }, + ) + self.assertEqual(response.status_code, HTTP_201_CREATED) + response2 = self.client.get(self.endpoint) + self.assertEqual(response2.status_code, HTTP_200_OK) + self.assertEqual(len(response2.data), 1) + self.assertEqual(response2.data[0].get("address"), self.public_key_test1) -# class TestGetWalletAddress(APITestCase): -# def setUp(self) -> None: -# self.password = "test" -# self._address = "0x3E5e9111Ae8eB78Fe1CC3bb8915d5D461F3Ef9A9" -# self.endpoint_set = reverse('AUTHENTICATION:set-wallet-user') -# self.endpoint_get = reverse('AUTHENTICATION:get-wallet-user') -# self.user_profile = create_new_user() -# self.client.force_authenticate(user=self.user_profile.user) -# -# def test_get_existing_wallet_is_ok(self): -# response = self.client.post(self.endpoint_set, data={'address': self._address, 'wallet_type': "EVM"}) -# self.assertEqual(response.status_code, HTTP_200_OK) -# -# response = self.client.post(self.endpoint_get, data={'wallet_type': "EVM"}) -# self.assertEqual(response.status_code, HTTP_200_OK) -# -# def test_not_existing_wallet_should_fail_getting_profile(self): -# response = self.client.post(self.endpoint_get, data={'wallet_type': "EVM"}) -# self.assertEqual(response.status_code, HTTP_403_FORBIDDEN) - - -class TestGetWalletsView(APITestCase): + +class TestWalletView(APITestCase): def setUp(self) -> None: self.password = "test" self._address = "0x3E5e9111Ae8eB78Fe1CC3bb8915d5D461F3Ef9A9" - self.endpoint = reverse("AUTHENTICATION:get-wallets-user") self.user_profile = create_new_user() + create_new_wallet(self.user_profile, self._address, "EVM") + self.endpoint = reverse("AUTHENTICATION:wallets-user") self.client.force_authenticate(user=self.user_profile.user) def test_request_to_this_api_is_ok(self): response = self.client.get(self.endpoint) self.assertEqual(response.status_code, HTTP_200_OK) + # def test_change_primary_ture(self): + # response: Response = self.client.patch(self.endpoint, data={'primary': True}) + # self.assertEqual(response.status_code, HTTP_200_OK) + # self.assertEqual(response.data.get('primary'), True) + + # def test_access_to_another_user_wallet(self): + # _address = '0x3E5e9111Ae8eB78Fe1CC3bb8915d5D461F3Ef9A2' + # other_user = create_new_user(_address) + # wallet = create_new_wallet(other_user, _address, 'EVM') + # _endpoint = reverse('AUTHENTICATION:wallet-user', kwargs={'pk': wallet.pk}) + # response = self.client.get(_endpoint) + # self.assertEqual(response.status_code, HTTP_404_NOT_FOUND) + class TestGetProfileView(APITestCase): def setUp(self) -> None: diff --git a/authentication/urls.py b/authentication/urls.py index dbfd02e4..2f9f80e4 100644 --- a/authentication/urls.py +++ b/authentication/urls.py @@ -1,5 +1,14 @@ from django.urls import path -from authentication.views import * + +from authentication.views import ( + CheckUsernameView, + GetProfileView, + LoginView, + SetUsernameView, + SponsorView, + UserProfileCountView, + WalletListCreateView, +) app_name = "AUTHENTICATION" @@ -17,26 +26,18 @@ name="check-username", ), path( - "user/set-wallet/", - SetWalletAddressView.as_view(), - name="set-wallet-user", - ), - path( - "user/get-wallet/", - GetWalletAddressView.as_view(), - name="get-wallet-user", - ), - path( - "user/delete-wallet/", - DeleteWalletAddressView.as_view(), - name="delete-wallet-user", - ), - path( - "user/get-wallets/", - GetWalletsView.as_view(), - name="get-wallets-user", + "user/wallets/", + WalletListCreateView.as_view(), + name="wallets-user", ), + # path( + # "user/wallets//", + # WalletView.as_view(), + # name="wallet-user", + # ), path("user/info/", GetProfileView.as_view(), name="get-profile-user"), path("user/sponsor/", SponsorView.as_view(), name="sponsor-user"), - path("user/history-count/", UserHistoryCountView.as_view(), name="user-history-count") + path( + "user/history-count/", UserProfileCountView.as_view(), name="user-history-count" + ), ] diff --git a/authentication/views.py b/authentication/views.py index 9ca512f5..58afe73a 100644 --- a/authentication/views.py +++ b/authentication/views.py @@ -1,25 +1,32 @@ -import time from django.db import IntegrityError -from rest_framework.permissions import IsAuthenticated -from rest_framework.generics import CreateAPIView, RetrieveAPIView, ListAPIView -from authentication.models import UserProfile, Wallet +from django_filters.rest_framework import DjangoFilterBackend +from drf_yasg import openapi +from drf_yasg.utils import swagger_auto_schema from rest_framework.authtoken.models import Token -from rest_framework.authtoken.views import ObtainAuthToken +from rest_framework.generics import ( + CreateAPIView, + ListAPIView, + ListCreateAPIView, + RetrieveAPIView, +) +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView -from drf_yasg.utils import swagger_auto_schema + from authentication.helpers import ( BRIGHTID_SOULDBOUND_INTERFACE, - verify_signature_eth_scheme, is_username_valid_and_available, + verify_signature_eth_scheme, ) -from drf_yasg import openapi +from authentication.models import UserProfile, Wallet from authentication.serializers import ( - UsernameRequestSerializer, MessageResponseSerializer, ProfileSerializer, - WalletSerializer, UserHistoryCountSerializer, + UserHistoryCountSerializer, + UsernameRequestSerializer, + WalletSerializer, ) +from core.filters import IsOwnerFilterBackend class UserProfileCountView(ListAPIView): @@ -75,19 +82,21 @@ def post(self, request, *args, **kwargs): if BRIGHTID_SOULDBOUND_INTERFACE.sponsor(str(address)) is not True: return Response( { - "message": "We are in the process of sponsoring you. Please try again in five minutes." + "message": "We are in the process of sponsoring you. \ + Please try again in five minutes." }, status=403, ) else: return Response( { - "message": "We have requested to sponsor you on BrightID. Please try again in five minutes." + "message": "We have requested to sponsor you on BrightID\ + . Please try again in five minutes." }, status=409, ) - verified_signature = verify_signature_eth_scheme(address, signature) + verified_signature = verify_signature_eth_scheme(address, address, signature) if not verified_signature: return Response({"message": "Invalid signature"}, status=403) @@ -102,18 +111,21 @@ def post(self, request, *args, **kwargs): context_ids = [] - if is_meet_verified == False and is_aura_verified == False: + if is_meet_verified is False and is_aura_verified is False: if meet_context_ids == 3: # is not verified context_ids = address elif aura_context_ids == 4: # is not linked return Response( { - "message": "Something went wrong with the linking process. please link BrightID with Unitap.\nIf the problem persists, clear your browser cache and try again." + "message": "Something went wrong with the linking process. \ + please link BrightID with Unitap.\n" + "If the problem persists, clear your browser cache \ + and try again." }, status=403, ) - elif is_meet_verified == True or is_aura_verified == True: + elif is_meet_verified is True or is_aura_verified is True: if meet_context_ids is not None: context_ids = meet_context_ids elif aura_context_ids is not None: @@ -204,7 +216,8 @@ class CheckUsernameView(CreateAPIView): schema=MessageResponseSerializer(), ), 403: openapi.Response( - description="Username must be more than 2 characters, contain at least one letter, and only contain letters, digits and @/./+/-/_.", + description="Username must be more than 2 characters, contain at least" + " one letter, and only contain letters, digits and @/./+/-/_.", schema=MessageResponseSerializer(), ), }, @@ -245,101 +258,15 @@ def post(self, request, *args, **kwargs): return Response(request_serializer.errors, status=400) -class SetWalletAddressView(CreateAPIView): - permission_classes = [IsAuthenticated] - - def post(self, request, *args, **kwargs): - address = request.data.get("address", None) - wallet_type = request.data.get("wallet_type", None) - if not address or not wallet_type: - return Response({"message": "Invalid request"}, status=403) - - user_profile = request.user.profile - - try: - w = Wallet.objects.get(user_profile=user_profile, wallet_type=wallet_type) - w.address = address - w.save() - - return Response( - {"message": f"{wallet_type} wallet address updated"}, status=200 - ) - - except Wallet.DoesNotExist: - try: - Wallet.objects.create( - user_profile=user_profile, wallet_type=wallet_type, address=address - ) - return Response( - {"message": f"{wallet_type} wallet address set"}, status=200 - ) - # catch unique constraint error - except IntegrityError: - return Response( - { - "message": f"{wallet_type} wallet address is not unique. use another address" - }, - status=403, - ) - - -class GetWalletAddressView(RetrieveAPIView): - permission_classes = [IsAuthenticated] - - def get(self, request, *args, **kwargs): - wallet_type = request.data.get("wallet_type", None) - if not wallet_type: - return Response({"message": "Invalid request"}, status=403) - - # get user profile - user_profile = request.user.profile - - try: - # check if wallet already exists - wallet = Wallet.objects.get( - user_profile=user_profile, wallet_type=wallet_type - ) - return Response({"address": wallet.address}, status=200) - - except Wallet.DoesNotExist: - return Response( - {"message": f"{wallet_type} wallet address not set"}, status=403 - ) - - -class DeleteWalletAddressView(RetrieveAPIView): - permission_classes = [IsAuthenticated] - - def get(self, request, *args, **kwargs): - wallet_type = request.data.get("wallet_type", None) - if not wallet_type: - return Response({"message": "Invalid request"}, status=403) - - # get user profile - user_profile = request.user.profile - - try: - # check if wallet already exists - wallet = Wallet.objects.get( - user_profile=user_profile, wallet_type=wallet_type - ) - wallet.delete() - return Response( - {"message": f"{wallet_type} wallet address deleted"}, status=200 - ) - - except Wallet.DoesNotExist: - return Response( - {"message": f"{wallet_type} wallet address not set"}, status=403 - ) - - -class GetWalletsView(ListAPIView): +class WalletListCreateView(ListCreateAPIView): + queryset = Wallet.objects.all() permission_classes = [IsAuthenticated] serializer_class = WalletSerializer + filter_backends = [IsOwnerFilterBackend, DjangoFilterBackend] + filterset_fields = ["wallet_type"] - def get_queryset(self): - return Wallet.objects.filter(user_profile=self.request.user.profile) + def perform_create(self, serializer): + serializer.save(user_profile=self.request.user.profile) class UserHistoryCountView(RetrieveAPIView): @@ -348,11 +275,16 @@ class UserHistoryCountView(RetrieveAPIView): def get_object(self): from faucet.models import ClaimReceipt + user_profile = self.request.user.profile data = { - 'gas_claim': user_profile.claims.filter(_status=ClaimReceipt.VERIFIED).count(), - 'token_claim': user_profile.tokentap_claims.filter(status=ClaimReceipt.VERIFIED).count(), - 'raffle_win': user_profile.raffle_entries.count() + "gas_claim": user_profile.claims.filter( + _status=ClaimReceipt.VERIFIED + ).count(), + "token_claim": user_profile.tokentap_claims.filter( + status=ClaimReceipt.VERIFIED + ).count(), + "raffle_win": user_profile.raffle_entries.count(), } return data @@ -361,18 +293,5 @@ class GetProfileView(RetrieveAPIView): permission_classes = [IsAuthenticated] serializer_class = ProfileSerializer - # def get(self, request, *args, **kwargs): - # user = request.user - - # token, bol = Token.objects.get_or_create(user=user) - # print("token", token) - - # # return Response({"token": token.key}, status=200) - # # return token and profile using profile serializer for profile - # return Response( - # {"token": token.key, "profile": ProfileSerializer(user.profile).data}, - # status=200, - # ) - def get_object(self): return self.request.user.profile diff --git a/core/utils.py b/core/utils.py index 3be210b6..0e4be0c5 100644 --- a/core/utils.py +++ b/core/utils.py @@ -3,10 +3,9 @@ from contextlib import contextmanager import pytz +from django.core.cache import cache from eth_account.messages import encode_defunct from web3 import Account, Web3 -from django.core.cache import cache -from web3 import Web3 from web3.contract.contract import Contract, ContractFunction from web3.logs import DISCARD, IGNORE, STRICT, WARN from web3.middleware import geth_poa_middleware @@ -178,11 +177,11 @@ def to_checksum_address(address: str): return Web3.to_checksum_address(address.lower()) @staticmethod - def hash_message(user, token, amount, nonce): + def hash_message(address, token, amount, nonce): message_hash = Web3().solidity_keccak( ["address", "address", "uint256", "uint32"], [ - Web3.to_checksum_address(user), + Web3.to_checksum_address(address), Web3.to_checksum_address(token), amount, nonce, diff --git a/core/validators.py b/core/validators.py new file mode 100644 index 00000000..45fb2339 --- /dev/null +++ b/core/validators.py @@ -0,0 +1,27 @@ +from django.core.exceptions import BadRequest +from solders.pubkey import Pubkey + +from core.utils import Web3Utils + +from .models import Chain, NetworkTypes + + +def address_validator(address, chain: Chain): + is_address_valid = False + if chain.chain_type == NetworkTypes.LIGHTNING: + return + elif chain.chain_type == NetworkTypes.EVM: + try: + Web3Utils.to_checksum_address(address) + return + except ValueError: + is_address_valid = False + elif chain.chain_type == NetworkTypes.SOLANA: + try: + pub_key = Pubkey.from_string(address) + is_address_valid = pub_key.is_on_curve() + except ValueError: + is_address_valid = False + + if not is_address_valid: + raise BadRequest(f"Address: {address} is not valid") diff --git a/faucet/faucet_manager/claim_manager.py b/faucet/faucet_manager/claim_manager.py index f0ee45de..685ea47b 100644 --- a/faucet/faucet_manager/claim_manager.py +++ b/faucet/faucet_manager/claim_manager.py @@ -2,6 +2,7 @@ import logging from abc import ABC +import rest_framework.exceptions from django.db import transaction from django.utils import timezone @@ -19,7 +20,7 @@ class ClaimManager(ABC): @abc.abstractmethod - def claim(self, amount, passive_address=None) -> ClaimReceipt: + def claim(self, amount, to_address=None) -> ClaimReceipt: pass @abc.abstractmethod @@ -35,14 +36,14 @@ def __init__(self, credit_strategy: CreditStrategy): def fund_manager(self): return EVMFundManager(self.credit_strategy.chain) - def claim(self, amount, passive_address=None): + def claim(self, amount, to_address=None): with transaction.atomic(): user_profile = UserProfile.objects.select_for_update().get( pk=self.credit_strategy.user_profile.pk ) self.assert_pre_claim_conditions(amount, user_profile) return self.create_pending_claim_receipt( - amount, passive_address + amount, to_address ) # all pending claims will be processed periodically def assert_pre_claim_conditions(self, amount, user_profile): @@ -54,14 +55,17 @@ def assert_pre_claim_conditions(self, amount, user_profile): _status=ClaimReceipt.PENDING, ).exists() - def create_pending_claim_receipt(self, amount, passive_address): + def create_pending_claim_receipt(self, amount, to_address): + if to_address is None: + raise rest_framework.exceptions.ParseError("wallet address is required") + return ClaimReceipt.objects.create( chain=self.credit_strategy.chain, user_profile=self.credit_strategy.user_profile, datetime=timezone.now(), amount=amount, _status=ClaimReceipt.PENDING, - passive_address=passive_address, + to_address=to_address, ) def get_credit_strategy(self) -> CreditStrategy: @@ -97,19 +101,19 @@ def assert_pre_claim_conditions(self, amount, user_profile): class LightningClaimManger(LimitedChainClaimManager): - def claim(self, amount, passive_address): + def claim(self, amount, to_address): try: lnpay_client = LNPayClient( self.credit_strategy.chain.rpc_url_private, self.credit_strategy.chain.wallet.main_key, self.credit_strategy.chain.fund_manager_address, ) - decoded_invoice = lnpay_client.decode_invoice(passive_address) + decoded_invoice = lnpay_client.decode_invoice(to_address) except Exception as e: logging.error(e) raise AssertionError("Could not decode the invoice") assert int(decoded_invoice["num_satoshis"]) == amount, "Invalid amount" - return super().claim(amount, passive_address) + return super().claim(amount, to_address) class ClaimManagerFactory: diff --git a/faucet/faucet_manager/lnpay_manager/utility_helpers.py b/faucet/faucet_manager/lnpay_manager/utility_helpers.py index 15412191..6d64031a 100644 --- a/faucet/faucet_manager/lnpay_manager/utility_helpers.py +++ b/faucet/faucet_manager/lnpay_manager/utility_helpers.py @@ -1,7 +1,7 @@ -import pprint -import requests import json +import requests + def get_request(location): from .lnpay_main import __ENDPOINT_URL__, __PUBLIC_API_KEY__, __VERSION__ @@ -53,5 +53,4 @@ def post_request(location, params): data = json.dumps(params) r = requests.post(url=endpoint, data=data, headers=headers) - print("salam", r.text) return r.json() diff --git a/faucet/migrations/0065_rename_passive_address_claimreceipt_to_address.py b/faucet/migrations/0065_rename_passive_address_claimreceipt_to_address.py new file mode 100644 index 00000000..353be86c --- /dev/null +++ b/faucet/migrations/0065_rename_passive_address_claimreceipt_to_address.py @@ -0,0 +1,30 @@ +# Generated by Django 4.0.4 on 2023-11-30 14:21 +from django.db.models import Q +from django.db import migrations + + +def set_to_address(apps, schema_editor): + # set the wallet address in claim receipts + CR = apps.get_model("faucet", "ClaimReceipt") + + for c in CR.objects.filter(Q(to_address__isnull=True) | Q(to_address="")): + try: + c.to_address = c.user_profile.wallets.get(wallet_type="EVM").address + c.save() + except Exception as e: + print("set to_address error", e) + + +class Migration(migrations.Migration): + dependencies = [ + ("faucet", "0064_merge_20231108_1331"), + ] + + operations = [ + migrations.RenameField( + model_name="claimreceipt", + old_name="passive_address", + new_name="to_address", + ), + migrations.RunPython(set_to_address), + ] diff --git a/faucet/migrations/0067_merge_20231226_1046.py b/faucet/migrations/0067_merge_20231226_1046.py new file mode 100644 index 00000000..b183cd27 --- /dev/null +++ b/faucet/migrations/0067_merge_20231226_1046.py @@ -0,0 +1,14 @@ +# Generated by Django 4.0.4 on 2023-12-26 10:46 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('faucet', '0065_rename_passive_address_claimreceipt_to_address'), + ('faucet', '0066_rename_is_one_time_chain_is_one_time_claim'), + ] + + operations = [ + ] diff --git a/faucet/models.py b/faucet/models.py index 61365d73..f1048a8d 100644 --- a/faucet/models.py +++ b/faucet/models.py @@ -158,7 +158,7 @@ class ClaimReceipt(models.Model): _status = models.CharField(max_length=30, choices=states, default=PENDING) - passive_address = models.CharField(max_length=512, null=True, blank=True) + to_address = models.CharField(max_length=512, null=True, blank=True) amount = BigNumField() datetime = models.DateTimeField() diff --git a/faucet/serializers.py b/faucet/serializers.py index 89433b76..dd9b6e85 100644 --- a/faucet/serializers.py +++ b/faucet/serializers.py @@ -2,34 +2,6 @@ from faucet.models import Chain, ClaimReceipt, DonationReceipt, GlobalSettings -# class UserSerializer(serializers.ModelSerializer): -# total_weekly_claims_remaining = serializers.SerializerMethodField() - -# class Meta: -# model = BrightUser -# fields = [ -# "pk", -# "context_id", -# "address", -# "verification_url", -# "verification_status", -# "total_weekly_claims_remaining", -# ] -# read_only_fields = ["context_id"] - -# def get_total_weekly_claims_remaining(self, instance): -# gs = GlobalSettings.objects.first() -# if gs is not None: -# return ( -# gs.weekly_chain_claim_limit -# - LimitedChainClaimManager.get_total_weekly_claims(instance) -# ) - -# def create(self, validated_data): -# address = validated_data["address"] -# bright_user = BrightUser.objects.get_or_create(address) -# return bright_user - class GlobalSettingsSerializer(serializers.ModelSerializer): class Meta: @@ -102,9 +74,6 @@ class Meta: class ChainSerializer(serializers.ModelSerializer): - # claimed = serializers.SerializerMethodField() - # unclaimed = serializers.SerializerMethodField() - class Meta: model = Chain fields = [ @@ -121,8 +90,6 @@ class Meta: "modal_url", "gas_image_url", "max_claim_amount", - # "claimed", - # "unclaimed", "total_claims", "total_claims_this_round", "tokentap_contract_address", @@ -133,22 +100,6 @@ class Meta: "is_one_time_claim", ] - # def get_claimed(self, chain) -> int: - # user = self.context["request"].user - - # if not user.is_authenticated: - # return "N/A" - # user_profile = user.profile - # return CreditStrategyFactory(chain, user_profile).get_strategy().get_claimed() - - # def get_unclaimed(self, chain) -> int: - # user = self.context["request"].user - - # if not user.is_authenticated: - # return "N/A" - # user_profile = user.profile - # return CreditStrategyFactory(chain, user_profile).get_strategy().get_unclaimed() - class ReceiptSerializer(serializers.ModelSerializer): chain = SmallChainSerializer() @@ -158,6 +109,7 @@ class Meta: fields = [ "pk", "tx_hash", + "to_address", "chain", "datetime", "amount", @@ -180,19 +132,38 @@ def _validate_chain(self, pk: str): try: chain: Chain = Chain.objects.get(pk=pk, chain_type="EVM") except Chain.DoesNotExist: - raise serializers.ValidationError({"chain": "chain is not EVM or does not exist."}) + raise serializers.ValidationError( + {"chain": "chain is not EVM or does not exist."} + ) return chain class Meta: model = DonationReceipt depth = 1 - fields = ["tx_hash", "chain", "datetime", "total_price", "value", "chain_pk", "status", "user_profile"] - read_only_fields = ["value", "datetime", "total_price", "chain", "status", "user_profile"] + fields = [ + "tx_hash", + "chain", + "datetime", + "total_price", + "value", + "chain_pk", + "status", + "user_profile", + ] + read_only_fields = [ + "value", + "datetime", + "total_price", + "chain", + "status", + "user_profile", + ] class LeaderboardSerializer(serializers.Serializer): username = serializers.CharField(max_length=150, read_only=True) sum_total_price = serializers.CharField(max_length=150, read_only=True) - interacted_chains = serializers.ListField(child=serializers.IntegerField(), read_only=True) - wallet = serializers.CharField(max_length=512, read_only=True) + interacted_chains = serializers.ListField( + child=serializers.IntegerField(), read_only=True + ) rank = serializers.IntegerField(read_only=True, required=False) diff --git a/faucet/tests.py b/faucet/tests.py index eae0d4fa..fb022555 100644 --- a/faucet/tests.py +++ b/faucet/tests.py @@ -12,7 +12,7 @@ from authentication.models import UserProfile, Wallet from brightIDfaucet.settings import DEBUG from faucet.constants import MEMCACHE_LIGHTNING_LOCK_KEY -from faucet.constraints import OptimismClaimingGasConstraint, OptimismDonationConstraint +from faucet.constraints import OptimismDonationConstraint from faucet.faucet_manager.claim_manager import ClaimManagerFactory, SimpleClaimManager from faucet.faucet_manager.credit_strategy import ( RoundCreditStrategy, @@ -47,11 +47,10 @@ LIGHTNING_FUND_MANAGER = os.environ.get("LIGHTNING_FUND_MANAGER") LIGHTNING_RPC_URL = "https://api.lnpay.co/v1/" -LIGHTNING_INVOICE = ( - "lnbc100n1pjxtceppp5q65xc3w8tnnmzkhqgg9c7h4a8hzplm0dppr944upwsq4q62sjeesdqu2askcmr9wssx7e3q2dshgmmnd" - "p5scqzzsxqyz5vqsp5hj2vzha0x4qvuyzrym6ryvxwnccn4kjwa57037dgcshl5ls4tves9qyyssqj24t4j2dkp2r29ptgxqz2etsk0qp8ggwmt" - "20czfu48h5akgme43zevg6x040scjzx3qgtp8mkcg2gurv0hy8d8xm3hhf8k68uefl9sqqqscuvz" -) +LIGHTNING_INVOICE = "lnbc100n1pjxtceppp5q65xc3w8tnnmzkhqgg9c7h4a8hzplm0dppr9\ +44upwsq4q62sjeesdqu2askcmr9wssx7e3q2dshgmmndp5scqzzsxqyz5vqsp5hj2vzha0x4qvuyz\ +rym6ryvxwnccn4kjwa57037dgcshl5ls4tves9qyyssqj24t4j2dkp2r29ptgxqz2etsk0qp8ggwm\ +t20czfu48h5akgme43zevg6x040scjzx3qgtp8mkcg2gurv0hy8d8xm3hhf8k68uefl9sqqqscuvz" def create_new_user( @@ -119,7 +118,9 @@ def create_lightning_chain(wallet) -> Chain: class TestWalletAccount(APITestCase): def setUp(self) -> None: self.key = test_wallet_key - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) def test_create_wallet(self): self.assertEqual(WalletAccount.objects.count(), 1) @@ -183,7 +184,9 @@ def test_create_wallet(self): class TestChainInfo(APITestCase): def setUp(self) -> None: - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) self.new_user = create_new_user() self.xdai = create_xDai_chain(self.wallet) self.idChain = create_idChain_chain(self.wallet) @@ -231,7 +234,9 @@ def test_list_chains(self): class TestClaim(APITestCase): def setUp(self) -> None: - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) self.new_user = create_new_user() self.verified_user = create_new_user() self.x_dai = create_xDai_chain(self.wallet) @@ -264,10 +269,14 @@ def test_x_dai_claimed_be_zero_eth_be_100(self): self.assertEqual(credit_strategy_xdai.get_claimed(), 0) self.assertEqual(credit_strategy_id_chain.get_claimed(), claim_amount) self.assertEqual(credit_strategy_xdai.get_unclaimed(), x_dai_max_claim) - self.assertEqual(credit_strategy_id_chain.get_unclaimed(), eidi_max_claim - claim_amount) + self.assertEqual( + credit_strategy_id_chain.get_unclaimed(), eidi_max_claim - claim_amount + ) def test_claim_manager_fail_if_claim_amount_exceeds_unclaimed(self): - claim_manager_x_dai = SimpleClaimManager(RoundCreditStrategy(self.x_dai, self.new_user)) + claim_manager_x_dai = SimpleClaimManager( + RoundCreditStrategy(self.x_dai, self.new_user) + ) try: claim_manager_x_dai.claim(x_dai_max_claim + 10) @@ -281,7 +290,9 @@ def test_claim_manager_fail_if_claim_amount_exceeds_unclaimed(self): ) def test_claim_unverified_user_should_fail(self): claim_amount = 100 - claim_manager_x_dai = SimpleClaimManager(RoundCreditStrategy(self.x_dai, self.new_user)) + claim_manager_x_dai = SimpleClaimManager( + RoundCreditStrategy(self.x_dai, self.new_user) + ) try: claim_manager_x_dai.claim(claim_amount) @@ -295,14 +306,18 @@ def test_claim_unverified_user_should_fail(self): ) def test_claim_manager_should_claim(self): claim_amount = 100 - claim_manager_x_dai = ClaimManagerFactory(self.x_dai, self.verified_user).get_manager() + claim_manager_x_dai = ClaimManagerFactory( + self.x_dai, self.verified_user + ).get_manager() credit_strategy_x_dai = claim_manager_x_dai.get_credit_strategy() - r = claim_manager_x_dai.claim(claim_amount) + r = claim_manager_x_dai.claim(claim_amount, "0x12345") r._status = ClaimReceipt.VERIFIED r.save() self.assertEqual(credit_strategy_x_dai.get_claimed(), claim_amount) - self.assertEqual(credit_strategy_x_dai.get_unclaimed(), x_dai_max_claim - claim_amount) + self.assertEqual( + credit_strategy_x_dai.get_unclaimed(), x_dai_max_claim - claim_amount + ) @patch( "faucet.faucet_manager.claim_manager.SimpleClaimManager.user_is_meet_verified", @@ -311,8 +326,10 @@ def test_claim_manager_should_claim(self): def test_only_one_pending_claim(self): claim_amount_1 = 100 claim_amount_2 = 50 - claim_manager_x_dai = ClaimManagerFactory(self.x_dai, self.verified_user).get_manager() - claim_manager_x_dai.claim(claim_amount_1) + claim_manager_x_dai = ClaimManagerFactory( + self.x_dai, self.verified_user + ).get_manager() + claim_manager_x_dai.claim(claim_amount_1, "0x12345") try: claim_manager_x_dai.claim(claim_amount_2) @@ -326,12 +343,14 @@ def test_only_one_pending_claim(self): def test_second_claim_after_first_verifies(self): claim_amount_1 = 100 claim_amount_2 = 50 - claim_manager_x_dai = ClaimManagerFactory(self.x_dai, self.verified_user).get_manager() - claim_1 = claim_manager_x_dai.claim(claim_amount_1) + claim_manager_x_dai = ClaimManagerFactory( + self.x_dai, self.verified_user + ).get_manager() + claim_1 = claim_manager_x_dai.claim(claim_amount_1, "0x12345") claim_1._status = ClaimReceipt.VERIFIED claim_1.save() try: - claim_manager_x_dai.claim(claim_amount_2) + claim_manager_x_dai.claim(claim_amount_2, "0x12345") except AssertionError: self.assertEqual(False, True) @@ -342,12 +361,14 @@ def test_second_claim_after_first_verifies(self): def test_second_claim_after_first_fails(self): claim_amount_1 = 100 claim_amount_2 = 50 - claim_manager_x_dai = ClaimManagerFactory(self.x_dai, self.verified_user).get_manager() - claim_1 = claim_manager_x_dai.claim(claim_amount_1) + claim_manager_x_dai = ClaimManagerFactory( + self.x_dai, self.verified_user + ).get_manager() + claim_1 = claim_manager_x_dai.claim(claim_amount_1, "0x12345") claim_1._status = ClaimReceipt.REJECTED claim_1.save() try: - claim_manager_x_dai.claim(claim_amount_2) + claim_manager_x_dai.claim(claim_amount_2, "0x12345") except AssertionError: self.assertEqual(True, False) @@ -359,16 +380,18 @@ def test_claim_should_fail_if_limit_reached(self): claim_amount_1 = 10 claim_amount_2 = 5 claim_amount_3 = 1 - claim_manager_x_dai = ClaimManagerFactory(self.x_dai, self.verified_user).get_manager() - claim_1 = claim_manager_x_dai.claim(claim_amount_1) + claim_manager_x_dai = ClaimManagerFactory( + self.x_dai, self.verified_user + ).get_manager() + claim_1 = claim_manager_x_dai.claim(claim_amount_1, "0x12345") claim_1._status = ClaimReceipt.VERIFIED claim_1.save() - claim_2 = claim_manager_x_dai.claim(claim_amount_2) + claim_2 = claim_manager_x_dai.claim(claim_amount_2, "0x12345") claim_2._status = ClaimReceipt.VERIFIED claim_2.save() try: - claim_manager_x_dai.claim(claim_amount_3) + claim_manager_x_dai.claim(claim_amount_3, "0x12345") except AssertionError: self.assertEqual(True, True) @@ -378,14 +401,20 @@ def test_claim_should_fail_if_limit_reached(self): ) @skipIf(not DEBUG, "only on debug") def test_simple_claim_manager_transfer(self): - manager = SimpleClaimManager(SimpleCreditStrategy(self.test_chain, self.verified_user)) - manager.claim(100) + manager = SimpleClaimManager( + SimpleCreditStrategy(self.test_chain, self.verified_user) + ) + manager.claim(100, "0x12345") class TestClaimAPI(APITestCase): def setUp(self) -> None: - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) - self.lightning_wallet = WalletAccount.objects.create(name="Test Lightning Wallet", private_key=LIGHTNING_WALLET) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) + self.lightning_wallet = WalletAccount.objects.create( + name="Test Lightning Wallet", private_key=LIGHTNING_WALLET + ) self.verified_user = create_new_user() self.x_dai = create_xDai_chain(self.wallet) self.idChain = create_idChain_chain(self.wallet) @@ -397,61 +426,63 @@ def setUp(self) -> None: GlobalSettings.objects.create(gastap_round_claim_limit=2) LightningConfig.objects.create( - period=86800, period_max_cap=100, current_round=int(int(time.time()) / 86800) * 86800 + period=86800, + period_max_cap=100, + current_round=int(int(time.time()) / 86800) * 86800, ) self.client.force_authenticate(user=self.verified_user.user) self.user_profile = self.verified_user - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (False, None), ) def test_claim_max_api_should_fail_if_not_verified(self): - endpoint = reverse( - "FAUCET:claim-max", - kwargs={"chain_pk": self.x_dai.pk}, - ) + endpoint = reverse("FAUCET:claim-max", kwargs={"chain_pk": self.x_dai.pk}) - response = self.client.post(endpoint) + response = self.client.post(endpoint, data={"address": "0x12345"}) self.assertEqual(response.status_code, 403) - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (True, None), ) def test_claim_max_api_should_claim_all(self): - endpoint = reverse( - "FAUCET:claim-max", - kwargs={"chain_pk": self.x_dai.pk}, - ) + endpoint = reverse("FAUCET:claim-max", kwargs={"chain_pk": self.x_dai.pk}) - response = self.client.post(endpoint) + response = self.client.post( + endpoint, data={"address": "0x90F8bf6A479f320ead074411a4B0e7944Ea8c9C1"} + ) claim_receipt = json.loads(response.content) self.assertEqual(response.status_code, 200) self.assertEqual(claim_receipt["amount"], self.x_dai.max_claim_amount) - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (True, None), ) def test_claim_max_twice_should_fail(self): - endpoint = reverse( - "FAUCET:claim-max", - kwargs={"chain_pk": self.x_dai.pk}, + endpoint = reverse("FAUCET:claim-max", kwargs={"chain_pk": self.x_dai.pk}) + response_1 = self.client.post( + endpoint, data={"address": "0x90F8bf6A479f320ead074411a4B0e7944Ea8c9C1"} ) - response_1 = self.client.post(endpoint) self.assertEqual(response_1.status_code, 200) try: self.client.post(endpoint) except CustomException: self.assertEqual(True, True) - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) + @patch( + "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", + lambda a, b, c: (True, None), + ) + def test_address_validator_evm(self): + endpoint = reverse("FAUCET:claim-max", kwargs={"chain_pk": self.x_dai.pk}) + response_1 = self.client.post(endpoint, data={"address": "0x132546"}) + self.assertEqual(response_1.status_code, 400) + @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (True, None), @@ -494,7 +525,6 @@ def test_get_last_claim_of_user(self): self.assertEqual(claim_data["txHash"], last_claim.tx_hash) self.assertEqual(claim_data["chain"]["pk"], last_claim.chain.pk) - @patch("faucet.views.ClaimMaxView.wallet_address_is_set", lambda a: (True, None)) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", lambda a, b, c: (True, None), @@ -546,7 +576,9 @@ def test_lightning_claim_max_cap_exceeded(self): config = lightning_fund_manager.config config.claimed_amount = 100 config.save() - is_exceeded = lightning_fund_manager._LightningFundManager__check_max_cap_exceeds(10) + is_exceeded = ( + lightning_fund_manager._LightningFundManager__check_max_cap_exceeds(10) + ) self.assertEqual(is_exceeded, True) with self.assertRaises(AssertionError): @@ -557,7 +589,9 @@ def test_lightning_claim_max_cap_exceeded(self): class TestWeeklyCreditStrategy(APITestCase): def setUp(self) -> None: - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) # self.verified_user = create_verified_user() self.test_chain = create_test_chain(self.wallet) @@ -618,7 +652,9 @@ def test_unclaimed(self): class TestConstraints(APITestCase): def setUp(self) -> None: - self.wallet = WalletAccount.objects.create(name="Test Wallet", private_key=test_wallet_key) + self.wallet = WalletAccount.objects.create( + name="Test Wallet", private_key=test_wallet_key + ) self.test_chain = create_test_chain(self.wallet) @@ -636,7 +672,9 @@ def setUp(self) -> None: explorer_api_key="6PGF5HBTT7DG9CQCQZK3MWR9146JAWQKAC", ) - self.user_profile = create_new_user("0x5A73E32a77E04Fb3285608B0AdEaa000B8e248F2") + self.user_profile = create_new_user( + "0x5A73E32a77E04Fb3285608B0AdEaa000B8e248F2" + ) self.wallet = Wallet.objects.create( user_profile=self.user_profile, wallet_type=NetworkTypes.EVM, @@ -647,19 +685,24 @@ def setUp(self) -> None: def test_optimism_donation_contraint(self): constraint = OptimismDonationConstraint(self.user_profile) self.assertFalse(constraint.is_observed()) - DonationReceipt.objects.create(user_profile=self.user_profile, tx_hash="0x0", chain=self.test_chain) + DonationReceipt.objects.create( + user_profile=self.user_profile, tx_hash="0x0", chain=self.test_chain + ) self.assertFalse(constraint.is_observed()) DonationReceipt.objects.create( - user_profile=self.user_profile, tx_hash="0x0", chain=self.optimism, status=ClaimReceipt.VERIFIED + user_profile=self.user_profile, + tx_hash="0x0", + chain=self.optimism, + status=ClaimReceipt.VERIFIED, ) self.assertTrue(constraint.is_observed()) - def test_optimism_claiming_gas_contraint(self): - constraint = OptimismClaimingGasConstraint(self.user_profile) - self.assertTrue(constraint.is_observed()) - self.wallet.address = "0xE3eEBaB360E367b4e200759F0D955D1140F27430" - self.wallet.save() - self.assertTrue(constraint.is_observed()) - self.wallet.address = "0xB9e291b68E584be657477289389B3a6DEED3E34C" - self.wallet.save() - self.assertFalse(constraint.is_observed()) + # def test_optimism_claiming_gas_contraint(self): + # constraint = OptimismClaimingGasConstraint(self.user_profile) + # self.assertTrue(constraint.is_observed()) + # self.wallet.address = "0xE3eEBaB360E367b4e200759F0D955D1140F27430" + # self.wallet.save() + # self.assertTrue(constraint.is_observed()) + # self.wallet.address = "0xB9e291b68E584be657477289389B3a6DEED3E34C" + # self.wallet.save() + # self.assertFalse(constraint.is_observed()) diff --git a/faucet/views.py b/faucet/views.py index f9b066d7..53eac32a 100644 --- a/faucet/views.py +++ b/faucet/views.py @@ -21,9 +21,10 @@ from rest_framework.response import Response from rest_framework.views import APIView -from authentication.models import UserProfile, Wallet +from authentication.models import UserProfile from core.filters import ChainFilterBackend, IsOwnerFilterBackend from core.paginations import StandardResultsSetPagination +from core.validators import address_validator from faucet.faucet_manager.claim_manager import ( ClaimManagerFactory, LimitedChainClaimManager, @@ -184,23 +185,14 @@ def check_user_is_verified(self, type="Meet"): # _is_verified = True if not _is_verified: # return Response({"message": "You are not BrighID verified"}, status=403) - raise CustomException("You are not BrighID verified") - - def wallet_address_is_set(self): - passive_address = self.request.data.get("address", None) - if passive_address is not None: - return True, passive_address - - chain = self.get_chain() - - try: - Wallet.objects.get( - user_profile=self.get_user(), wallet_type=chain.chain_type + raise rest_framework.exceptions.PermissionDenied( + "You are not BrighID verified" ) - return True, None - except Exception as e: - logging.error("wallet address not set", e) - raise CustomException("wallet address not set") + + def to_address_is_provided(self): + to_address = self.request.data.get("address", None) + if not to_address: + raise rest_framework.exceptions.ParseError("wallet address not set") def get_chain(self) -> Chain: chain_pk = self.kwargs.get("chain_pk", None) @@ -212,26 +204,29 @@ def get_chain(self) -> Chain: def get_claim_manager(self): return ClaimManagerFactory(self.get_chain(), self.get_user()).get_manager() - def claim_max(self, passive_address) -> ClaimReceipt: + def check_to_address_is_validate(self): + chain = self.get_chain() + to_address = self.request.data.get("address", None) + address_validator(to_address, chain) + + def claim_max(self, to_address) -> ClaimReceipt: manager = self.get_claim_manager() max_credit = manager.get_credit_strategy().get_unclaimed() try: assert max_credit > 0 - return manager.claim(max_credit, passive_address=passive_address) + return manager.claim(max_credit, to_address=to_address) except AssertionError as e: logging.error("no credit left for user", e) - raise CustomException("no credit left") + raise rest_framework.exceptions.PermissionDenied("no credit left") except ValueError as e: raise rest_framework.exceptions.APIException(e) def post(self, request, *args, **kwargs): - try: - self.check_user_is_verified() - s, passive_address = self.wallet_address_is_set() - except CustomException as e: - return Response({"message": str(e)}, status=403) + self.check_user_is_verified() + self.to_address_is_provided() + self.check_to_address_is_validate() - receipt = self.claim_max(passive_address) + receipt = self.claim_max(to_address=request.data.get("address")) return Response(ReceiptSerializer(instance=receipt).data) @@ -287,7 +282,6 @@ def get_object(self): ) user_obj["rank"] = user_rank user_obj["username"] = self.get_user().username - user_obj["wallet"] = self.get_user().wallets.all()[0].address interacted_chains = list( DonationReceipt.objects.filter(user_profile=self.get_user()) .filter(status=ClaimReceipt.VERIFIED) @@ -326,12 +320,7 @@ def list(self, request, *args, **kwargs): subquery_username = UserProfile.objects.filter( pk=OuterRef("user_profile") ).values("username") - subquery_wallet = Wallet.objects.filter( - user_profile=OuterRef("user_profile") - ).values("address") - queryset = queryset.annotate( - username=Subquery(subquery_username), wallet=Subquery(subquery_wallet) - ) + queryset = queryset.annotate(username=Subquery(subquery_username)) page = self.paginate_queryset(queryset) if page is not None: serializer = self.get_serializer(page, many=True) diff --git a/prizetap/admin.py b/prizetap/admin.py index 3a1eddad..109ef6aa 100644 --- a/prizetap/admin.py +++ b/prizetap/admin.py @@ -1,30 +1,24 @@ from django.contrib import admin -from prizetap.models import * + from core.admin import UserConstraintBaseAdmin +from prizetap.models import Constraint, LineaRaffleEntries, Raffle, RaffleEntry class RaffleAdmin(admin.ModelAdmin): list_display = ["pk", "name", "creator_name"] + class RaffleŁEntryAdmin(admin.ModelAdmin): list_display = [ - "pk", - "raffle", - "get_wallet", + "pk", + "raffle", + "user_wallet_address", "age", ] - @admin.display(ordering='user_profile__wallets', description='Wallet') - def get_wallet(self, obj): - return obj.user_profile.wallets.get(wallet_type=NetworkTypes.EVM).address class LineaRaffleEntriesAdmin(admin.ModelAdmin): - list_display = [ - "pk", - "wallet_address", - "is_winner" - ] - + list_display = ["pk", "wallet_address", "is_winner"] admin.site.register(Raffle, RaffleAdmin) diff --git a/prizetap/constraints.py b/prizetap/constraints.py index 2968f3cc..e4d4c6da 100644 --- a/prizetap/constraints.py +++ b/prizetap/constraints.py @@ -12,9 +12,16 @@ def __init__(self, user_profile: UserProfile) -> None: def is_observed(self, *args, **kwargs): chain = Chain.objects.get(chain_id=1) self.unitappass_client = UnitapPassClient(chain) - user_address: str = self.user_profile.wallets.get(wallet_type=chain.chain_type).address - user_address = self.unitappass_client.to_checksum_address(user_address.lower()) - return self.unitappass_client.is_holder(user_address) + + user_addresses = [ + self.unitappass_client.to_checksum_address(wallet.address.lower()) + for wallet in self.user_profile.wallets.filter(wallet_type=chain.chain_type) + ] + + for user_address in user_addresses: + if self.unitappass_client.is_holder(user_address): + return True + return False class NotHaveUnitapPass(HaveUnitapPass): diff --git a/prizetap/migrations/0003_raffle_permissions.py b/prizetap/migrations/0003_raffle_permissions.py index f78f3a80..ebb5ead4 100644 --- a/prizetap/migrations/0003_raffle_permissions.py +++ b/prizetap/migrations/0003_raffle_permissions.py @@ -4,10 +4,9 @@ class Migration(migrations.Migration): - dependencies = [ # ('permissions', '0004_oncepermonthverification'), - ('prizetap', '0002_raffle_is_prize_nft'), + ("prizetap", "0002_raffle_is_prize_nft"), ] operations = [ diff --git a/prizetap/migrations/0045_raffleentry_user_wallet_alter_raffleentry_raffle_and_more.py b/prizetap/migrations/0045_raffleentry_user_wallet_alter_raffleentry_raffle_and_more.py new file mode 100644 index 00000000..243126d1 --- /dev/null +++ b/prizetap/migrations/0045_raffleentry_user_wallet_alter_raffleentry_raffle_and_more.py @@ -0,0 +1,40 @@ +# Generated by Django 4.0.4 on 2023-11-30 15:49 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("authentication", "0023_wallet_created_at"), + ("prizetap", "0044_raffle_reversed_constraints"), + ] + + operations = [ + migrations.AddField( + model_name="raffleentry", + name="user_wallet", + field=models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="raffle_entries", + to="authentication.wallet", + ), + ), + migrations.AlterField( + model_name="raffleentry", + name="raffle", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, related_name="entries", to="prizetap.raffle" + ), + ), + migrations.AlterField( + model_name="raffleentry", + name="user_profile", + field=models.ForeignKey( + on_delete=django.db.models.deletion.PROTECT, + related_name="raffle_entries", + to="authentication.userprofile", + ), + ), + ] diff --git a/prizetap/migrations/0046_remove_raffleentry_user_wallet_and_more.py b/prizetap/migrations/0046_remove_raffleentry_user_wallet_and_more.py new file mode 100644 index 00000000..decd8982 --- /dev/null +++ b/prizetap/migrations/0046_remove_raffleentry_user_wallet_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.0.4 on 2023-11-30 16:50 + +from django.db import migrations, models + + +def set_user_wallet(apps, schema_editor): + RaffleEntry = apps.get_model("prizetap", "RaffleEntry") + + for entry in RaffleEntry.objects.all(): + try: + entry.user_wallet_address = entry.user_profile.wallets.get(wallet_type="EVM").address + entry.save() + except Exception as e: + print("Error setting user wallet for raffle entries", e) + + +class Migration(migrations.Migration): + dependencies = [ + ("prizetap", "0045_raffleentry_user_wallet_alter_raffleentry_raffle_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="raffleentry", + name="user_wallet", + ), + migrations.AddField( + model_name="raffleentry", + name="user_wallet_address", + field=models.CharField(blank=True, max_length=255, null=True), + ), + migrations.RunPython(set_user_wallet), + ] diff --git a/prizetap/migrations/0047_alter_raffleentry_user_wallet_address.py b/prizetap/migrations/0047_alter_raffleentry_user_wallet_address.py new file mode 100644 index 00000000..65f86e04 --- /dev/null +++ b/prizetap/migrations/0047_alter_raffleentry_user_wallet_address.py @@ -0,0 +1,18 @@ +# Generated by Django 4.0.4 on 2023-11-30 17:04 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('prizetap', '0046_remove_raffleentry_user_wallet_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='raffleentry', + name='user_wallet_address', + field=models.CharField(max_length=255), + ), + ] diff --git a/prizetap/migrations/0048_merge_20231226_1046.py b/prizetap/migrations/0048_merge_20231226_1046.py new file mode 100644 index 00000000..67112e28 --- /dev/null +++ b/prizetap/migrations/0048_merge_20231226_1046.py @@ -0,0 +1,14 @@ +# Generated by Django 4.0.4 on 2023-12-26 10:46 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('prizetap', '0045_alter_raffle_chain'), + ('prizetap', '0047_alter_raffleentry_user_wallet_address'), + ] + + operations = [ + ] diff --git a/prizetap/models.py b/prizetap/models.py index 28345c35..559e90b7 100644 --- a/prizetap/models.py +++ b/prizetap/models.py @@ -4,13 +4,11 @@ from django.utils.translation import gettext_lazy as _ from authentication.models import UserProfile -from core.models import BigNumField, Chain, NetworkTypes, UserConstraint +from core.models import BigNumField, Chain, UserConstraint from faucet.constraints import OptimismClaimingGasConstraint, OptimismDonationConstraint from .constraints import HaveUnitapPass, NotHaveUnitapPass -# Create your models here. - class Constraint(UserConstraint): constraints = UserConstraint.constraints + [ @@ -143,11 +141,13 @@ class Meta: unique_together = (("raffle", "user_profile"),) verbose_name_plural = "raffle entries" - raffle = models.ForeignKey(Raffle, on_delete=models.CASCADE, related_name="entries") + raffle = models.ForeignKey(Raffle, on_delete=models.PROTECT, related_name="entries") user_profile = models.ForeignKey( - UserProfile, on_delete=models.CASCADE, related_name="raffle_entries" + UserProfile, on_delete=models.PROTECT, related_name="raffle_entries" ) + user_wallet_address = models.CharField(max_length=255) + created_at = models.DateTimeField(auto_now_add=True, editable=True) multiplier = models.IntegerField(default=1) @@ -158,10 +158,6 @@ class Meta: def __str__(self): return f"{self.raffle} - {self.user_profile}" - @property - def user(self): - return self.user_profile.wallets.get(wallet_type=NetworkTypes.EVM).address - @property def age(self): return timezone.now() - self.created_at diff --git a/prizetap/serializers.py b/prizetap/serializers.py index 6fbb300d..366e5f09 100644 --- a/prizetap/serializers.py +++ b/prizetap/serializers.py @@ -42,7 +42,6 @@ class RaffleEntrySerializer(serializers.ModelSerializer): raffle = SimpleRaffleSerializer() user_profile = SimpleProfilerSerializer() chain = serializers.SerializerMethodField() - wallet = serializers.SerializerMethodField() class Meta: model = RaffleEntry @@ -51,7 +50,7 @@ class Meta: "chain", "raffle", "user_profile", - "wallet", + "user_wallet_address", "created_at", "multiplier", "tx_hash", @@ -62,7 +61,7 @@ class Meta: "chain", "raffle", "user_profile", - "wallet", + "user_wallet_address", "created_at", "multiplier", ] @@ -70,22 +69,16 @@ class Meta: def get_chain(self, entry: RaffleEntry): return entry.raffle.chain.chain_id - def get_wallet(self, entry: RaffleEntry): - return entry.user_profile.wallets.get( - wallet_type=entry.raffle.chain.chain_type - ).address - class WinnerEntrySerializer(serializers.ModelSerializer): user_profile = SimpleProfilerSerializer() - wallet = serializers.SerializerMethodField() class Meta: model = RaffleEntry fields = [ "pk", "user_profile", - "wallet", + "user_wallet_address", "created_at", "multiplier", "tx_hash", @@ -94,16 +87,11 @@ class Meta: read_only_fields = [ "pk", "user_profile", - "wallet", + "user_wallet_address", "created_at", "multiplier", ] - def get_wallet(self, entry: RaffleEntry): - return entry.user_profile.wallets.get( - wallet_type=entry.raffle.chain.chain_type - ).address - class CreateRaffleSerializer(serializers.ModelSerializer): class Meta: diff --git a/prizetap/tests.py b/prizetap/tests.py index ab5f0023..d61d8563 100644 --- a/prizetap/tests.py +++ b/prizetap/tests.py @@ -220,7 +220,8 @@ def test_raffle_enrollment_authentication(self): def test_raffle_enrollment_validation(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}) + reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) @@ -529,7 +530,8 @@ def setUp(self): def test_raffle_enrollment(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}) + reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 200) self.assertEqual(self.raffle.entries.count(), 1) @@ -547,7 +549,8 @@ def test_not_claimable_raffle_enrollment(self, is_claimable_mock: PropertyMock): is_claimable_mock.return_value = False self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}) + reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) @@ -679,7 +682,8 @@ def test_duplicate_claiming_prize_tx_failure(self): def test_get_raffle_entry(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}) + reverse("raflle-enrollment", kwargs={"pk": self.raffle.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) first_entry = self.raffle.entries.first() response = self.client.get( diff --git a/prizetap/validators.py b/prizetap/validators.py index 2b86ec33..94a934c1 100644 --- a/prizetap/validators.py +++ b/prizetap/validators.py @@ -22,9 +22,15 @@ def check_user_constraints(self): param_values = json.loads(self.raffle.constraint_params) except Exception: param_values = {} - reversed_constraints = self.raffle.reversed_constraints.split(",") if self.raffle.reversed_constraints else [] + reversed_constraints = ( + self.raffle.reversed_constraints.split(",") + if self.raffle.reversed_constraints + else [] + ) for c in self.raffle.constraints.all(): - constraint: ConstraintVerification = get_constraint(c.name)(self.user_profile) + constraint: ConstraintVerification = get_constraint(c.name)( + self.user_profile + ) constraint.response = c.response try: constraint.param_values = param_values[c.name] @@ -37,16 +43,16 @@ def check_user_constraints(self): if not constraint.is_observed(): raise PermissionDenied(constraint.response) - def check_user_has_wallet(self): - if not self.user_profile.wallets.filter(wallet_type=self.raffle.chain.chain_type).exists(): - raise PermissionDenied(f"You have not connected an {self.raffle.chain.chain_type} wallet to your account") + def check_user_owns_wallet(self, user_wallet_address): + if not self.user_profile.owns_wallet(user_wallet_address): + raise PermissionDenied("This wallet is not registered for this user") def is_valid(self, data): self.can_enroll_in_raffle() self.check_user_constraints() - self.check_user_has_wallet() + self.check_user_owns_wallet(data.get("user_wallet_address")) class SetRaffleEntryTxValidator: @@ -56,7 +62,9 @@ def __init__(self, *args, **kwargs): def is_owner_of_raffle_entry(self): if not self.raffle_entry.user_profile == self.user_profile: - raise PermissionDenied("You don't have permission to update this raffle entry") + raise PermissionDenied( + "You don't have permission to update this raffle entry" + ) def is_tx_empty(self): if self.raffle_entry.tx_hash: diff --git a/prizetap/views.py b/prizetap/views.py index f39298b5..57aff193 100644 --- a/prizetap/views.py +++ b/prizetap/views.py @@ -1,5 +1,6 @@ import json +import rest_framework.exceptions from django.shortcuts import get_object_or_404 from django.utils import timezone from rest_framework.generics import CreateAPIView, ListAPIView @@ -56,6 +57,11 @@ class RaffleEnrollmentView(CreateAPIView): def post(self, request, pk): user_profile = request.user.profile raffle = get_object_or_404(Raffle, pk=pk) + user_wallet_address = request.data.get("user_wallet_address", None) + if not user_wallet_address: + raise rest_framework.exceptions.ParseError( + "user_wallet_address is required" + ) validator = RaffleEnrollmentValidator(user_profile=user_profile, raffle=raffle) @@ -66,6 +72,7 @@ def post(self, request, pk): except RaffleEntry.DoesNotExist: raffle_entry = RaffleEntry.objects.create( user_profile=user_profile, + user_wallet_address=user_wallet_address, raffle=raffle, ) raffle_entry.save() diff --git a/tokenTap/helpers.py b/tokenTap/helpers.py index 581cd186..326d0a1e 100644 --- a/tokenTap/helpers.py +++ b/tokenTap/helpers.py @@ -18,8 +18,8 @@ def create_uint32_random_nonce(): return nonce -def hash_message(user, token, amount, nonce): - hashed_message = Web3Utils.hash_message(user, token, amount, nonce) +def hash_message(address, token, amount, nonce): + hashed_message = Web3Utils.hash_message(address, token, amount, nonce) return hashed_message @@ -28,7 +28,7 @@ def sign_hashed_message(hashed_message): return Web3Utils.sign_hashed_message(private_key, hashed_message) -def has_weekly_credit_left(user_profile): +def has_credit_left(user_profile): return ( TokenDistributionClaim.objects.filter( user_profile=user_profile, diff --git a/tokenTap/migrations/0023_tokendistributionclaim_user_wallet_address.py b/tokenTap/migrations/0023_tokendistributionclaim_user_wallet_address.py new file mode 100644 index 00000000..be70a4a0 --- /dev/null +++ b/tokenTap/migrations/0023_tokendistributionclaim_user_wallet_address.py @@ -0,0 +1,35 @@ +# Generated by Django 4.0.4 on 2023-12-02 15:01 + +from django.db import migrations, models + + +def set_user_wallet(apps, schema_editor): + tdc = apps.get_model("tokenTap", "TokenDistributionClaim") + + for entry in tdc.objects.all(): + try: + if entry.token_distribution.chain.chain_type == "EVM": + entry.user_wallet_address = entry.user_profile.wallets.get( + wallet_type="EVM" + ).address + entry.save() + elif entry.token_distribution.chain.chain_type == "LIGHTNING": + entry.user_wallet_address = entry.signature + entry.save() + except Exception as e: + print("Error setting user wallet for tdc", e) + + +class Migration(migrations.Migration): + dependencies = [ + ("tokenTap", "0022_constraint_explanation"), + ] + + operations = [ + migrations.AddField( + model_name="tokendistributionclaim", + name="user_wallet_address", + field=models.CharField(blank=True, max_length=255, null=True), + ), + migrations.RunPython(set_user_wallet), + ] diff --git a/tokenTap/migrations/0025_merge_20231226_1046.py b/tokenTap/migrations/0025_merge_20231226_1046.py new file mode 100644 index 00000000..2b5efeed --- /dev/null +++ b/tokenTap/migrations/0025_merge_20231226_1046.py @@ -0,0 +1,14 @@ +# Generated by Django 4.0.4 on 2023-12-26 10:46 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('tokenTap', '0023_tokendistributionclaim_user_wallet_address'), + ('tokenTap', '0024_tokendistribution_contract'), + ] + + operations = [ + ] diff --git a/tokenTap/models.py b/tokenTap/models.py index ffca4e3b..8816dcc7 100644 --- a/tokenTap/models.py +++ b/tokenTap/models.py @@ -3,7 +3,7 @@ from django.utils import timezone from authentication.models import UserProfile -from core.models import Chain, NetworkTypes, UserConstraint +from core.models import Chain, UserConstraint from faucet.constraints import OptimismHasClaimedGasInThisRound from faucet.models import ClaimReceipt @@ -105,6 +105,8 @@ class TokenDistributionClaim(models.Model): ) created_at = models.DateTimeField(auto_now_add=True, editable=True) + user_wallet_address = models.CharField(max_length=255, null=True, blank=True) + notes = models.TextField(null=True, blank=True) signature = models.CharField(max_length=1024, blank=True, null=True) @@ -119,10 +121,6 @@ class TokenDistributionClaim(models.Model): def __str__(self): return f"{self.token_distribution} - {self.user_profile}" - @property - def user(self): - return self.user_profile.wallets.get(wallet_type=NetworkTypes.EVM).address - @property def token(self): return self.token_distribution.token_address diff --git a/tokenTap/serializers.py b/tokenTap/serializers.py index 41fadf11..7c892e63 100644 --- a/tokenTap/serializers.py +++ b/tokenTap/serializers.py @@ -94,7 +94,7 @@ class Meta: class PayloadSerializer(serializers.ModelSerializer): class Meta: model = TokenDistributionClaim - fields = ["user", "token", "amount", "nonce", "signature"] + fields = ["user_wallet_address", "token", "amount", "nonce", "signature"] class TokenDistributionClaimSerializer(serializers.ModelSerializer): @@ -107,6 +107,7 @@ class Meta: "id", "token_distribution", "user_profile", + "user_wallet_address", "created_at", "payload", "status", diff --git a/tokenTap/tests.py b/tokenTap/tests.py index 69d41f53..dec9c487 100644 --- a/tokenTap/tests.py +++ b/tokenTap/tests.py @@ -287,6 +287,7 @@ def test_token_distribution_not_claimable_max_reached(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": ltd.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) @@ -313,6 +314,7 @@ def test_token_distribution_not_claimable_deadline_reached(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": ltd.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) @@ -331,37 +333,10 @@ def test_token_distribution_not_claimable_already_claimed(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) - # self.assertEqual( - # response.data["detail"], "You have already claimed this token this week" - # ) - - # @patch( - # "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", - # lambda a, b, c: (True, None), - # ) - # def test_token_distribution_not_claimable_already_claimed_month(self): - # tdc = TokenDistributionClaim.objects.create( - # user_profile=self.user_profile, - # token_distribution=self.td, - # # Claimed 2 weeks ago - # created_at=WeeklyCreditStrategy.get_first_day_of_the_month(), - # ) - # tdc.created_at = WeeklyCreditStrategy.get_first_day_of_the_month() - # tdc.save() - - # self.client.force_authenticate(user=self.user_profile.user) - # response = self.client.post( - # reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), - # ) - - # self.assertEqual(response.status_code, 403) - # # self.assertEqual( - # # response.data["detail"], - # # "You have already claimed this token this month" - # # ) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", @@ -370,7 +345,8 @@ def test_token_distribution_not_claimable_already_claimed(self): def test_token_distribution_not_claimable_false_permissions(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("token-distribution-claim", kwargs={"pk": self.td.pk}) + reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) @@ -382,12 +358,10 @@ def test_token_distribution_not_claimable_weekly_credit_limit_reached(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) - # self.assertEqual( - # response.data["detail"], "You have reached your weekly claim limit" - # ) @patch( "authentication.helpers.BrightIDSoulboundAPIInterface.get_verification_status", @@ -396,13 +370,14 @@ def test_token_distribution_not_claimable_weekly_credit_limit_reached(self): def test_token_distribution_not_claimable_no_wallet(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("token-distribution-claim", kwargs={"pk": self.td.pk}) + reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 403) self.assertEqual( response.data["detail"], - "You have not connected an EVM wallet to your account", + "This wallet is not registered for this user", ) @patch( @@ -417,7 +392,8 @@ def test_token_distribution_claimable(self): ) self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( - reverse("token-distribution-claim", kwargs={"pk": self.td.pk}) + reverse("token-distribution-claim", kwargs={"pk": self.td.pk}), + data={"user_wallet_address": "0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb"}, ) self.assertEqual(response.status_code, 200) @@ -435,7 +411,7 @@ def test_btc_lightning_claimable(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": self.btc_td.pk}), - data={"lightning_invoice": "test"}, + data={"user_wallet_address": "test"}, ) self.assertEqual(response.status_code, 200) @@ -458,7 +434,7 @@ def test_btc_lightning_claimable_claim_updates_after_6seconds(self): self.client.force_authenticate(user=self.user_profile.user) response = self.client.post( reverse("token-distribution-claim", kwargs={"pk": self.btc_td.pk}), - data={"lightning_invoice": "test"}, + data={"user_wallet_address": "test"}, ) self.assertEqual(response.status_code, 200) @@ -509,7 +485,7 @@ def test_sign_message(self): ) hash = hash_message( - user="0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb", + address="0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb", token="0xc1cbb2ab97260a8a7d4591045a9fb34ec14e87fb", amount=100000000000000000, nonce=create_uint32_random_nonce(), @@ -622,8 +598,7 @@ def test_missing_tx_hash(self): self.assertEqual(response.status_code, 400) assert "tx_hash is a required field" in str(response.content) - # Tests that an error is raised when the token distribution claim - # does not belong to the user profile + # Tests that an error is raised when the td claim does not belong to the user def test_claim_not_belonging_to_user_profile(self): other_user_profile = UserProfile.objects.get_or_create("other") claim = TokenDistributionClaim.objects.create( @@ -639,8 +614,7 @@ def test_claim_not_belonging_to_user_profile(self): response = self.client.post(url, data=data) assert response.status_code == 403 - # Tests that an error is raised when the token distribution claim - # status is already verified + # Tests that an error is raised when the tdclaim status is already verified def test_already_verified_claim(self): claim = TokenDistributionClaim.objects.create( token_distribution=TokenDistribution.objects.create( diff --git a/tokenTap/views.py b/tokenTap/views.py index d21a8cb0..7641ecd2 100644 --- a/tokenTap/views.py +++ b/tokenTap/views.py @@ -29,7 +29,7 @@ from .constants import CONTRACT_ADDRESSES from .helpers import ( create_uint32_random_nonce, - has_weekly_credit_left, + has_credit_left, hash_message, sign_hashed_message, ) @@ -65,17 +65,19 @@ def check_user_permissions(self, token_distribution, user_profile): if not constraint.is_observed(token_distribution=token_distribution): raise PermissionDenied(constraint.response) - def check_user_weekly_credit(self, user_profile): - if not has_weekly_credit_left(user_profile): + def check_user_credit(self, user_profile): + if not has_credit_left(user_profile): raise rest_framework.exceptions.PermissionDenied( "You have reached your weekly claim limit" ) - def check_user_has_wallet(self, user_profile): - if not user_profile.wallets.filter(wallet_type=NetworkTypes.EVM).exists(): - raise rest_framework.exceptions.PermissionDenied( - "You have not connected an EVM wallet to your account" - ) + def wallet_is_vaild(self, user_profile, user_wallet_address, token_distribution): + if token_distribution.chain.chain_type == NetworkTypes.LIGHTNING: + return # TODO - check if user_wallet_address is a valid lightning invoice + + elif token_distribution.chain.chain_type == NetworkTypes.EVM: + if not user_profile.owns_wallet(user_wallet_address): + raise PermissionDenied("This wallet is not registered for this user") @swagger_auto_schema( responses={ @@ -102,11 +104,15 @@ def check_user_has_wallet(self, user_profile): def post(self, request, *args, **kwargs): user_profile = request.user.profile token_distribution = TokenDistribution.objects.get(pk=self.kwargs["pk"]) - lightning_invoice = request.data.get("lightning_invoice", None) + user_wallet_address = request.data.get("user_wallet_address", None) + if user_wallet_address is None: + raise rest_framework.exceptions.ParseError( + "user_wallet_address is a required field" + ) self.check_token_distribution_is_claimable(token_distribution) - self.check_user_has_wallet(user_profile) + self.wallet_is_vaild(user_profile, user_wallet_address, token_distribution) self.check_user_permissions(token_distribution, user_profile) @@ -127,12 +133,12 @@ def post(self, request, *args, **kwargs): except TokenDistributionClaim.DoesNotExist: pass - self.check_user_weekly_credit(user_profile) + self.check_user_credit(user_profile) nonce = create_uint32_random_nonce() if token_distribution.chain.chain_type == NetworkTypes.EVM: hashed_message = hash_message( - user=user_profile.wallets.get(wallet_type=NetworkTypes.EVM).address, + address=user_wallet_address, token=token_distribution.token_address, amount=token_distribution.amount, nonce=nonce, @@ -145,13 +151,15 @@ def post(self, request, *args, **kwargs): nonce=nonce, signature=signature, token_distribution=token_distribution, + user_wallet_address=user_wallet_address, ) elif token_distribution.chain.chain_type == NetworkTypes.LIGHTNING: tdc = TokenDistributionClaim.objects.create( user_profile=user_profile, nonce=nonce, - signature=lightning_invoice, + signature=user_wallet_address, + user_wallet_address=user_wallet_address, token_distribution=token_distribution, ) ClaimReceipt.objects.create( @@ -160,7 +168,7 @@ def post(self, request, *args, **kwargs): datetime=timezone.now(), amount=token_distribution.amount, _status=ClaimReceipt.PENDING, - passive_address=lightning_invoice, + to_address=user_wallet_address, ) return Response(