Skip to content

Commit

Permalink
Merge branch 'refactor/002-refactore-backend' of github.com:abbastoof…
Browse files Browse the repository at this point in the history
…/transcendence into develop
  • Loading branch information
mtoof committed Jun 29, 2024
2 parents e1d581c + 53672ef commit 0f78b09
Show file tree
Hide file tree
Showing 24 changed files with 1,682 additions and 105 deletions.
2 changes: 1 addition & 1 deletion Backend/auth_service/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ RUN apk update && apk add --no-cache \
py3-pip=24.0-r2 \
postgresql16=16.3-r0 \
postgresql16-client=16.3-r0 \
bash=5.2.26-r0 curl=8.8.0-r0 openssl=3.3.1-r1 curl=8.8.0-r0 && \
bash=5.2.26-r0 curl=8.8.0-r0 openssl=3.3.1-r1 && \
mkdir /run/postgresql && \
chown -R postgres:postgres /run/postgresql
WORKDIR /app/
Expand Down
3 changes: 2 additions & 1 deletion Backend/auth_service/auth_service/auth_service/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent


APPEND_SLASH = False
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/5.0/howto/deployment/checklist/

Expand Down Expand Up @@ -63,6 +63,7 @@
"BLACKLIST_AFTER_ROTATION": True, # If True, the refresh token will be blacklisted after it is used to obtain a new access token. This means that if a refresh token is stolen, it can only be used once to obtain a new access token. This is useful if rotating refresh tokens is enabled, but can cause problems if a refresh token is shared between multiple clients.
"AUTH_HEADER_TYPES": ("Bearer",),
"AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",),
"TOKEN_OBTAIN_SERIALIZER": "user_auth.serializers.CustomTokenObtainPairSerializer"
}

REST_FRAMEWORK = {
Expand Down
11 changes: 1 addition & 10 deletions Backend/auth_service/auth_service/user_auth/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1 @@
from django.db import models


# Create your models here.
class UserTokens(models.Model):
username = models.CharField(unique=True)
token_data = models.JSONField(null=True)

def __str__(self):
return self.username
# from django.db import models
22 changes: 22 additions & 0 deletions Backend/auth_service/auth_service/user_auth/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
import asyncio

class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
@classmethod
def get_token(cls, user) -> dict:

"""
Get token method to generate tokens for the user.
This method overrides the get_token method of TokenObtainPairSerializer to generate tokens for the user.
It generates the tokens for the user and returns the tokens.
Args:
user: The user object.
Returns:
dict: The dictionary containing the tokens.
"""
token = super().get_token(user)
token["custom_claims"] = {"username": user.username, "password": user.password}
return token
4 changes: 2 additions & 2 deletions Backend/auth_service/auth_service/user_auth/urls.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# from django.contrib import admin
from django.urls import path
from rest_framework_simplejwt.views import TokenRefreshView
from user_auth.views import CustomTokenObtainPairView
from user_auth.views import CustomTokenObtainPairView, CustomTokenRefreshView

urlpatterns = [
path("api/token/", CustomTokenObtainPairView.as_view(), name="token_obtain_pair"),
path("api/token/refresh/", TokenRefreshView.as_view(), name="token_refresh"),
path("api/token/refresh/", CustomTokenRefreshView.as_view(), name="token_refresh"),
]
161 changes: 92 additions & 69 deletions Backend/auth_service/auth_service/user_auth/views.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,44 @@
import json
from datetime import datetime, timezone

from rest_framework import status
from auth_service import settings
from rest_framework.response import Response
from rest_framework_simplejwt.serializers import TokenObtainPairSerializer
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.views import TokenObtainPairView

from .models import UserTokens
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
from .rabbitmq_utils import consume_message, publish_message
from .serializers import CustomTokenObtainPairSerializer
import jwt

class CustomTokenObtainPairView(TokenObtainPairView):
"""
Custom token obtain pair view to generate tokens for the user.
This class inherits from TokenObtainPairView. It overrides the post method to generate tokens for the user.
Attributes:
serializer_class: The serializer class to use for the view.
Methods:
post: Post method to generate tokens for the user.
"""

class CustomTokenObtainPairSerializer(TokenObtainPairSerializer):
@classmethod
def get_token(cls, user):
token = super().get_token(user)
token["custom_claims"] = {"username": user.username, "password": user.password}
return token
serializer_class = CustomTokenObtainPairSerializer

def post(self, request, *args, **kwargs) -> Response:
"""
Post method to generate tokens for the user.
class CustomTokenObtainPairView(TokenObtainPairView):
serializer_class = CustomTokenObtainPairSerializer
This method overrides the post method of TokenObtainPairView to generate tokens for the user.
It sends a request to the user service to validate the user credentials
and get the user data. It then generates the tokens for the user and returns the tokens in the response.
Args:
request: The request object.
Returns:
Response: The response object containing the tokens.
"""

def post(self, request, *args, **kwargs):
username = request.data.get("username")
password = request.data.get("password")

Expand Down Expand Up @@ -51,70 +67,77 @@ def handle_response(ch, method, properties, body):
)
try:
user = type("User", (object,), user_data) # Mock user object with user_data
refresh = None
access = None

# Retrieve or create the UserTokens entry
user_token_entry, created = UserTokens.objects.get_or_create(
username=username
)

if not created:
# If the user entry already exists, check if tokens are still valid
token_data = user_token_entry.token_data
refresh_expire_time = token_data["refresh"]["exp"]
access_expire_time = token_data["access"]["exp"]
refresh = RefreshToken(token_data["refresh"]["token"])
access = token_data["access"]["token"]
current_time = datetime.now(timezone.utc)
if (
refresh_expire_time < current_time.timestamp()
or access_expire_time < current_time.timestamp()
):
if refresh_expire_time < current_time.timestamp():
refresh = RefreshToken.for_user(user)
access = refresh.access_token
elif access_expire_time < current_time.timestamp():
refresh = RefreshToken(token_data["refresh"]["token"])
access = refresh.access_token
user_token_entry.token_data = {
"refresh": {
"token": str(refresh),
"exp": int(refresh["exp"]), # Store expiration as integer
},
"access": {
"token": str(access),
"exp": int(access["exp"]), # Store expiration as integer
},
}
else:
# If the user entry was just created, store new tokens
refresh = RefreshToken.for_user(user)
access = refresh.access_token
user_token_entry.token_data = {
"refresh": {
"token": str(refresh),
"exp": int(refresh["exp"]), # Store expiration as integer
},
"access": {
"token": str(access),
"exp": int(access["exp"]), # Store expiration as integer
},
}

refresh = RefreshToken.for_user(user)
access = refresh.access_token
# Save the updated or newly created UserTokens entry
user_token_entry.save()

return Response(
{
"refresh": str(refresh),
"access": str(access),
},
status=status.HTTP_200_OK,
)

except Exception as e:
return Response(
{"error": "Could not generate tokens", "details": str(e)},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

class CustomTokenRefreshView(TokenRefreshView):
def post(self, request, *args, **kwargs) -> Response:
"""
Post method to generate new access token using refresh token.
This method overrides the post method of TokenRefreshView to generate a new access token using the refresh token.
It validates the refresh token and generates a new access token.
Args:
request: The request object.
Returns:
Response: The response object containing the new access token.
"""
bearer = request.headers.get("Authorization")
if not bearer or not bearer.startswith('Bearer '):
return Response(
{"error": "Refresh token is required"}, status=status.HTTP_400_BAD_REQUEST
)
try:
refresh_token = bearer.split(' ')[1]
is_valid_token = ValidateToken.validate_token(refresh_token)
if not is_valid_token:
return Response({"error": "Session has expired"}, status=status.HTTP_401_UNAUTHORIZED)
# If token is valid, generate a new access token
refresh = RefreshToken(refresh_token)
access_token = str(refresh.access_token)

return Response({"access": access_token}, status=status.HTTP_200_OK)

except Exception as err:
return Response({"error": "Could not generate tokens", "details": str(err)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)


class ValidateToken():
@staticmethod
def validate_token(refresh_token) -> bool:
"""
Validate the refresh token.
This method validates the refresh token by decoding the token and checking if it is expired.
If the token is expired or invalid, it returns False, otherwise it returns True.
Args:
refresh_token: The refresh token to validate.
Returns:
bool: True if the token is valid, False otherwise.
"""
try:
decoded_token = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"], options={"verify_signature": True})
return True
except jwt.ExpiredSignatureError:
print("expired token")
return False
except jwt.InvalidTokenError:
print("invalid token")
return False
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# Build paths inside the project like this: BASE_DIR / 'subdir'.
BASE_DIR = Path(__file__).resolve().parent.parent

APPEND_SLASH = False

# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/5.0/howto/deployment/checklist/
Expand Down
10 changes: 10 additions & 0 deletions Backend/user_management/user_management/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,14 @@


class User(AbstractUser):
"""
User class to define the user model.
This class inherits from AbstractUser. It defines the fields of the user model.
Attributes:
REQUIRED_FIELDS: The list of required fields for the user model.
Email: The email field is required for the user model.
"""
REQUIRED_FIELDS = ["email"]
46 changes: 44 additions & 2 deletions Backend/user_management/user_management/users/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@


class UserSerializer(serializers.ModelSerializer):
"""
UserSerializer class to define the user serializer.
This class defines the user serializer to serialize the user data.
Attributes:
email: The email field.
Meta: The meta class to define the model and fields for the serializer.
Methods:
create: Method to create a new user.
update: Method to update a user.
"""
email = serializers.EmailField(
validators=[UniqueValidator(queryset=User.objects.all())]
)
Expand All @@ -19,7 +32,19 @@ class Meta:

### Password should be strong password, minimum 8 characters, at least one uppercase letter, one lowercase letter, one number and one special character

def create(self, validate_data):
def create(self, validate_data) -> User:
"""
Method to create a new user.
This method creates a new user with the given data.
The password is validated using CustomPasswordValidator.
The password is hashed before saving the user object.
Args:
validate_data: The data to validate.
Returns:
User: The user object.
"""
try:
validate_password(validate_data["password"])
except ValidationError as err:
Expand All @@ -31,7 +56,24 @@ def create(self, validate_data):
instance.save()
return instance

def update(self, instance, validate_data):
def update(self, instance, validate_data) -> User:
"""
Method to update a user.
This method updates a user with the given data.
The password is hashed before saving the user object.
Args:
instance: The user object.
validate_data: The data to validate.
Returns:
User: The updated user object.
Raises:
serializers.ValidationError: If the password is the same as the current password.
"""
for attr, value in validate_data.items():
if attr == "password" and value is not None:
if instance.check_password(value):
Expand Down
Loading

0 comments on commit 0f78b09

Please sign in to comment.