diff --git a/backend/authentication/serializers.py b/backend/authentication/serializers.py index 60771277..7c4f1d92 100644 --- a/backend/authentication/serializers.py +++ b/backend/authentication/serializers.py @@ -1,20 +1,20 @@ from typing import Tuple -from django.contrib.auth.models import update_last_login + +from authentication.cas.client import client +from authentication.models import User +from authentication.signals import user_created, user_login from django.contrib.auth import login +from django.contrib.auth.models import update_last_login from rest_framework.serializers import ( CharField, EmailField, + HyperlinkedRelatedField, ModelSerializer, - ValidationError, Serializer, - HyperlinkedIdentityField, - HyperlinkedRelatedField, + ValidationError, ) -from rest_framework_simplejwt.tokens import RefreshToken, AccessToken from rest_framework_simplejwt.settings import api_settings -from authentication.signals import user_created, user_login -from authentication.models import User, Faculty -from authentication.cas.client import client +from rest_framework_simplejwt.tokens import AccessToken, RefreshToken class CASTokenObtainSerializer(Serializer): @@ -22,6 +22,7 @@ class CASTokenObtainSerializer(Serializer): This serializer takes the CAS ticket and tries to validate it. Upon successful validation, create a new user if it doesn't exist. """ + ticket = CharField(required=True, min_length=49, max_length=49) def validate(self, data): @@ -40,23 +41,15 @@ def validate(self, data): if "request" in self.context: login(self.context["request"], user) - user_login.send( - sender=self, user=user - ) + user_login.send(sender=self, user=user) if created: - user_created.send( - sender=self, attributes=attributes, user=user - ) + user_created.send(sender=self, attributes=attributes, user=user) # Return access tokens for the now logged-in user. return { - "access": str( - AccessToken.for_user(user) - ), - "refresh": str( - RefreshToken.for_user(user) - ), + "access": str(AccessToken.for_user(user)), + "refresh": str(RefreshToken.for_user(user)), } def _validate_ticket(self, ticket: str) -> dict: @@ -102,7 +95,7 @@ class UserSerializer(ModelSerializer): many=True, read_only=True, view_name="faculty-detail" ) - notifications = HyperlinkedIdentityField( + notifications = HyperlinkedRelatedField( view_name="notification-detail", read_only=True, ) diff --git a/backend/notifications/admin.py b/backend/notifications/admin.py deleted file mode 100644 index 4185d360..00000000 --- a/backend/notifications/admin.py +++ /dev/null @@ -1,3 +0,0 @@ -# from django.contrib import admin - -# Register your models here. diff --git a/backend/notifications/apps.py b/backend/notifications/apps.py index e81be476..3a084766 100644 --- a/backend/notifications/apps.py +++ b/backend/notifications/apps.py @@ -4,9 +4,3 @@ class NotificationsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" name = "notifications" - - -# TODO: Allow is_sent to be adjusted -# TODO: Signals to send notifications -# TODO: Send emails -# TODO: Think about the required api endpoints diff --git a/backend/notifications/models.py b/backend/notifications/models.py index 235723cf..d2827892 100644 --- a/backend/notifications/models.py +++ b/backend/notifications/models.py @@ -4,8 +4,10 @@ class NotificationTemplate(models.Model): id = models.AutoField(auto_created=True, primary_key=True) - title_key = models.CharField(max_length=255) - description_key = models.CharField(max_length=511) + title_key = models.CharField(max_length=255) # Key used to get translated title + description_key = models.CharField( + max_length=511 + ) # Key used to get translated description class Notification(models.Model): @@ -13,14 +15,10 @@ class Notification(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE) template_id = models.ForeignKey(NotificationTemplate, on_delete=models.CASCADE) created_at = models.DateTimeField(auto_now_add=True) - arguments = models.JSONField(default=dict) - is_read = models.BooleanField(default=False) - is_sent = models.BooleanField(default=False) - - def read(self): - self.is_read = True - self.save() - - def send(self): - self.is_sent = True - self.save() + arguments = models.JSONField(default=dict) # Arguments to be used in the template + is_read = models.BooleanField( + default=False + ) # Whether the notification has been read + is_sent = models.BooleanField( + default=False + ) # Whether the notification has been sent (email) diff --git a/backend/notifications/serializers.py b/backend/notifications/serializers.py index d5a1697f..4a24cd59 100644 --- a/backend/notifications/serializers.py +++ b/backend/notifications/serializers.py @@ -1,5 +1,4 @@ import re -from os import read from typing import Dict, List from authentication.models import User @@ -14,44 +13,26 @@ class Meta: fields = "__all__" -class UserHyperLinkedRelatedField(serializers.HyperlinkedRelatedField): - view_name = "user-detail" - queryset = User.objects.all() - - def to_internal_value(self, data): - try: - return self.queryset.get(pk=data) - except User.DoesNotExist: - self.fail("no_match") - - class NotificationSerializer(serializers.ModelSerializer): - user = UserHyperLinkedRelatedField() + # Hyper linked user field + user = serializers.HyperlinkedRelatedField( + view_name="user-detail", queryset=User.objects.all() + ) + # Translate template and arguments into a message message = serializers.SerializerMethodField() - class Meta: - model = Notification - fields = [ - "id", - "user", - "template_id", - "arguments", - "message", - "created_at", - "is_read", - "is_sent", - ] - - def _get_missing_keys(self, s: str, d: Dict[str, str]) -> List[str]: - required_keys = re.findall(r"%\((\w+)\)", s) - missing_keys = [key for key in required_keys if key not in d] + # Check if the required arguments are present + def _get_missing_keys(self, string: str, arguments: Dict[str, str]) -> List[str]: + required_keys: List[str] = re.findall(r"%\((\w+)\)", string) + missing_keys = [key for key in required_keys if key not in arguments] return missing_keys - def validate(self, data): - data = super().validate(data) + def validate(self, data: Dict[str, str]) -> Dict[str, str]: + data: Dict[str, str] = super().validate(data) + # Validate the arguments if "arguments" not in data: data["arguments"] = {} @@ -74,8 +55,22 @@ def validate(self, data): return data - def get_message(self, obj): + # Get the message from the template and arguments + def get_message(self, obj: Notification) -> Dict[str, str]: return { "title": _(obj.template_id.title_key), "description": _(obj.template_id.description_key) % obj.arguments, } + + class Meta: + model = Notification + fields = [ + "id", + "user", + "template_id", + "arguments", + "message", + "created_at", + "is_read", + "is_sent", + ] diff --git a/backend/notifications/signals.py b/backend/notifications/signals.py new file mode 100644 index 00000000..2ded382f --- /dev/null +++ b/backend/notifications/signals.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from enum import Enum +from typing import Dict + +from authentication.models import User +from django.dispatch import Signal, receiver +from django.urls import reverse +from notifications.serializers import NotificationSerializer + +notification_create = Signal() + + +@receiver(notification_create) +def notification_creation( + type: NotificationType, user: User, arguments: Dict[str, str], **kwargs +) -> bool: + serializer = NotificationSerializer( + data={ + "template_id": type.value, + "user": reverse("user-detail", kwargs={"pk": user.id}), + "arguments": arguments, + } + ) + + if not serializer.is_valid(): + return False + + serializer.save() + + return True + + +class NotificationType(Enum): + SCORE_ADDED = 1 # Arguments: {"score": int} + SCORE_UPDATED = 2 # Arguments: {"score": int} diff --git a/backend/notifications/urls.py b/backend/notifications/urls.py index cdd8a247..e80acd66 100644 --- a/backend/notifications/urls.py +++ b/backend/notifications/urls.py @@ -1,8 +1,6 @@ -from notifications.views import NotificationViewSet -from rest_framework.routers import DefaultRouter +from django.urls import path +from notifications.views import NotificationView -router = DefaultRouter() - -router.register(r"", NotificationViewSet, basename="notification") - -urlpatterns = router.urls +urlpatterns = [ + path("/", NotificationView.as_view(), name="notification-detail"), +] diff --git a/backend/notifications/views.py b/backend/notifications/views.py index 88bbbb54..5f4dd772 100644 --- a/backend/notifications/views.py +++ b/backend/notifications/views.py @@ -1,8 +1,38 @@ +from __future__ import annotations + +from typing import List + from notifications.models import Notification from notifications.serializers import NotificationSerializer -from rest_framework.viewsets import ModelViewSet +from rest_framework.permissions import BasePermission, IsAuthenticated +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.status import HTTP_200_OK +from rest_framework.views import APIView + + +# TODO: Give admin access to everything +class NotificationPermission(BasePermission): + # The user can only access their own notifications + # An admin can access all notifications + def has_permission(self, request: Request, view: NotificationView) -> bool: + return view.kwargs.get("user_id") == request.user.id or request.user.is_staff + + +class NotificationView(APIView): + permission_classes: List[BasePermission] = [IsAuthenticated, NotificationPermission] + + def get(self, request: Request, user_id: str) -> Response: + notifications = Notification.objects.filter(user=user_id) + serializer = NotificationSerializer( + notifications, many=True, context={"request": request} + ) + + return Response(serializer.data) + # Mark all notifications as read for the user + def post(self, request: Request, user_id: str) -> Response: + notifications = Notification.objects.filter(user=user_id) + notifications.update(is_read=True) -class NotificationViewSet(ModelViewSet): - queryset = Notification.objects.all() - serializer_class = NotificationSerializer + return Response(status=HTTP_200_OK) diff --git a/backend/ypovoli/urls.py b/backend/ypovoli/urls.py index 25e30a72..f3093cdc 100644 --- a/backend/ypovoli/urls.py +++ b/backend/ypovoli/urls.py @@ -14,10 +14,11 @@ 1. Import the include() function: from django.urls import include, path 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) """ -from django.urls import path, include -from rest_framework import permissions -from drf_yasg.views import get_schema_view + +from django.urls import include, path from drf_yasg import openapi +from drf_yasg.views import get_schema_view +from rest_framework import permissions schema_view = get_schema_view( openapi.Info( @@ -25,7 +26,9 @@ default_version="v1", ), public=True, - permission_classes=[permissions.AllowAny,], + permission_classes=[ + permissions.AllowAny, + ], ) @@ -34,8 +37,14 @@ path("", include("api.urls")), # Authentication endpoints. path("auth/", include("authentication.urls")), - path("notifications/", include("notifications.urls")), + path("notifications/", include("notifications.urls"), name="notifications"), # Swagger documentation. - path("swagger/", schema_view.with_ui("swagger", cache_timeout=0), name="schema-swagger-ui"), - path("swagger/", schema_view.without_ui(cache_timeout=0), name="schema-json"), + path( + "swagger/", + schema_view.with_ui("swagger", cache_timeout=0), + name="schema-swagger-ui", + ), + path( + "swagger/", schema_view.without_ui(cache_timeout=0), name="schema-json" + ), ]