diff --git a/requirements.txt b/requirements.txt index a0f4feb..cca0d85 100755 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ passlib==1.7.4 pydantic==2.4.2 email-validator==2.1.0 pydantic-settings==2.0.3 -aioredis==2.0.1 +redis==5.0.1 fastapi==0.103.2 uvicorn==0.23.2 gunicorn==21.2.0 diff --git a/src/core/config.py b/src/core/config.py index 281784b..e99c4ac 100755 --- a/src/core/config.py +++ b/src/core/config.py @@ -14,7 +14,7 @@ class Settings(BaseSettings): listen_port: int = 8000 allowed_hosts: list = const.DEFAULT_ALLOWED_HOSTS - redis_host: str = "redis" + redis_host: str = "localhost" redis_port: int = 6379 redis_db: int = 0 @@ -42,6 +42,10 @@ def pg_dsn(self): f"{self.postgres_port}/{self.postgres_db}" ) + @cached_property + def redis_dsn(self): + return f"redis://{self.redis_host}:{self.redis_port}" + class Config: case_sensitive = False env_file = ".env" diff --git a/src/db/postgres.py b/src/db/postgres.py index c60e4b7..433f865 100644 --- a/src/db/postgres.py +++ b/src/db/postgres.py @@ -5,12 +5,13 @@ from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.exc import SQLAlchemyError -from jose import jwt +from jose import jwt, JWTError from pydantic import UUID4 from sqlalchemy import delete from src.core.config import settings from src.v1.exceptions import ServiceError +from src.v1.auth.exceptions import InvalidTokenError from src.db.storages import Database, BaseStorage from src.v1.auth.helpers import decode_jwt from src.v1.users.models import UserRefreshTokens @@ -40,11 +41,15 @@ class PostgresRefreshTokenStorage(BaseStorage): @staticmethod async def create(db_session: AsyncSession, refresh_token: str, user_id: UUID4) -> UUID4: token_headers = jwt.get_unverified_header(refresh_token) - token_data = decode_jwt(refresh_token) + try: + token_payload = decode_jwt(refresh_token) + await __class__._verify_that_token_is_refresh(token_payload) + except JWTError: + raise InvalidTokenError() refresh_token = UserRefreshTokens( token=token_headers.get("jti"), user_id=user_id, - expire_at=datetime.fromtimestamp(token_data.get("exp")), + expire_at=datetime.fromtimestamp(token_payload.get("exp")), ) db_session.add(refresh_token) try: @@ -62,7 +67,12 @@ async def get(db_session: AsyncSession, token: str) -> UserRefreshTokens: @staticmethod async def delete(db_session: AsyncSession, token: str): - decode_jwt(token) + try: + token_payload = decode_jwt(token) + await __class__._verify_that_token_is_refresh(token_payload) + except JWTError: + raise InvalidTokenError() + token_headers = jwt.get_unverified_header(token) token_id = token_headers.get("jti") @@ -83,6 +93,11 @@ async def delete_all(db_session: AsyncSession, user_id: UUID4): except SQLAlchemyError: await db_session.rollback() + @staticmethod + async def _verify_that_token_is_refresh(token_payload: dict): + if len(token_payload.values()) > 1: + raise JWTError() + db_session = PostgresDatabase() refresh_tokens_storage = PostgresRefreshTokenStorage() diff --git a/src/db/redis.py b/src/db/redis.py index e69de29..64c872f 100644 --- a/src/db/redis.py +++ b/src/db/redis.py @@ -0,0 +1,51 @@ +from typing import Annotated + +from redis.asyncio import Redis, from_url +from pydantic import UUID4 +from fastapi import Depends + +from src.db.storages import BaseStorage +from src.core.config import settings + + +class RedisBlacklistUserSignatureStorage(BaseStorage): + + def __init__(self) -> None: + self.protocol: Redis = from_url( + settings.redis_dsn, + decode_responses=True, + db=settings.redis_db + ) + self.namespace: str = "auth_service" + + async def create(self, user_id: UUID4, signature: str): + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + await conn.set(signature_key, signature) + + async def get(self, user_id: UUID4) -> str: + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + return await conn.get(signature_key) + + async def delete(self, user_id: UUID4): + signature_key = f"{self.namespace}:{user_id}" + async with self.protocol.client() as conn: + return await conn.delete(signature_key) + + async def delete_all(self, user_id: UUID4, count_size: int = 10) -> int: + pattern = f"{self.namespace}:{user_id}:*" + cursor = b"0" + deleted_count = 0 + + async with self.protocol.client() as conn: + while cursor: + cursor, keys = await conn.scan(cursor, match=pattern, count=count_size) + deleted_count += await conn.unlink(*keys) + return deleted_count + + +redis_blackist_storage = RedisBlacklistUserSignatureStorage() +BlacklistSignatureStorage = Annotated[ + RedisBlacklistUserSignatureStorage, Depends(redis_blackist_storage) +] diff --git a/src/v1/auth/exceptions.py b/src/v1/auth/exceptions.py index d71c88c..234829e 100644 --- a/src/v1/auth/exceptions.py +++ b/src/v1/auth/exceptions.py @@ -10,6 +10,7 @@ class AuthExceptionCodes: USER_UNAUTHORIZED: int = 3002 PROVIDED_PASSWORD_INCORRECT: int = 3003 INVALID_PROVIDED_TOKEN: int = 3004 + TOKEN_NOT_FOUND: int = 3005 class UserAlreadyExistsError(HTTPException): @@ -58,3 +59,15 @@ def __init__( ) -> None: detail = {"code": AuthExceptionCodes.INVALID_PROVIDED_TOKEN, "message": message} super().__init__(status_code=status_code, detail=detail) + + +class TokenNotFoundError(HTTPException): + """Error raised whe token doesnt exists in DB.""" + + def __init__( + self, + status_code: int = status.HTTP_404_NOT_FOUND, + message: str = "Invalid token.", + ) -> None: + detail = {"code": AuthExceptionCodes.TOKEN_NOT_FOUND, "message": message} + super().__init__(status_code=status_code, detail=detail) diff --git a/src/v1/auth/helpers.py b/src/v1/auth/helpers.py index d329e75..04a0f87 100644 --- a/src/v1/auth/helpers.py +++ b/src/v1/auth/helpers.py @@ -5,9 +5,10 @@ from passlib.hash import pbkdf2_sha256 from pydantic import UUID4 +from src.db.redis import BlacklistSignatureStorage from src.core.config import settings from src.v1.auth.schemas import JWTTokens -from src.v1.auth.exceptions import InvalidTokenError +from src.v1.auth.exceptions import UnauthorizedError, InvalidTokenError # TODO(alexander.zharyuk): Improve generation. Maybe add some salt? @@ -56,12 +57,23 @@ def generate_jwt(payload: dict, access_jti: str, refresh_jti: UUID4) -> JWTToken def decode_jwt(token: str) -> dict: """Decode access / refresh tokens payload""" - try: - payload = jwt.decode( + return jwt.decode( token, key=settings.jwt_secret_key, algorithms=[settings.jwt_algorithm] ) + + +async def validate_jwt(blacklist_tokens_storage: BlacklistSignatureStorage, token: str): + """Validate that token is not in blacklists""" + try: + token_payload = decode_jwt(token) except JWTError: raise InvalidTokenError() - return payload + token_headers = jwt.get_unverified_header(token) + + user_id = token_payload.get("user_id") + token_signature = token_headers.get("jti") + + if token_signature == await blacklist_tokens_storage.get(user_id): + raise UnauthorizedError() diff --git a/src/v1/auth/routers.py b/src/v1/auth/routers.py index bea5f20..9a18bae 100755 --- a/src/v1/auth/routers.py +++ b/src/v1/auth/routers.py @@ -1,10 +1,11 @@ from fastapi import APIRouter, Depends, Request, Response, status from fastapi.security import APIKeyCookie -from src.db.postgres import DatabaseSession -from src.db.postgres import RefreshTokensStorage +from src.db.postgres import DatabaseSession, RefreshTokensStorage +from src.db.redis import BlacklistSignatureStorage from src.v1.auth.schemas import (TokensResponse, UserCreate, UserLogin, - UserResponse, LogoutResponse, UserLogout) + UserResponse, LogoutResponse, UserLogout, VerifyTokenResponse, + RefreshTokens) from src.v1.auth.service import AuthService from src.core.config import settings @@ -37,6 +38,40 @@ async def signin( return TokensResponse(data=tokens) +@router.post( + "/verify", + summary="Верификация переданного access_token", + response_model=VerifyTokenResponse +) +async def verify_token( + blacklist_signatures_storage: BlacklistSignatureStorage, + access_token: str | None = Depends(cookie_scheme) +) -> VerifyTokenResponse: + """Выход из всех сессий пользователя""" + return await AuthService.verify(access_token, blacklist_signatures_storage) + + +@router.post( + "/refresh", + summary="Выдача новой пары JWT-токенов", + response_model=TokensResponse +) +async def refresh_tokens( + db_session: DatabaseSession, + refresh_token_storage: RefreshTokensStorage, + response: Response, + data: RefreshTokens, +) -> TokensResponse: + """Выход из всех сессий пользователя""" + tokens = await AuthService.refresh_tokens( + db_session, + refresh_token_storage, + response, + data.refresh_token + ) + return TokensResponse(data=tokens) + + @router.post("/logout", summary="Выход из текущей сессии", response_model=LogoutResponse) async def logout( db_session: DatabaseSession, @@ -54,3 +89,25 @@ async def logout( ) return LogoutResponse() + +@router.post( + "/logout_all", + summary="Выход из всех сессий пользователя", + response_model=LogoutResponse +) +async def terminate_all_sessions( + db_session: DatabaseSession, + blacklist_signatures_storage: BlacklistSignatureStorage, + refresh_token_storage: RefreshTokensStorage, + response: Response, + access_token: str | None = Depends(cookie_scheme) +) -> LogoutResponse: + """Выход из всех сессий пользователя""" + await AuthService.terminate_all_sessions( + db_session, + blacklist_signatures_storage, + refresh_token_storage, + response, + access_token + ) + return LogoutResponse() \ No newline at end of file diff --git a/src/v1/auth/schemas.py b/src/v1/auth/schemas.py index ac02ff6..db59935 100644 --- a/src/v1/auth/schemas.py +++ b/src/v1/auth/schemas.py @@ -29,6 +29,10 @@ class UserLogout(BaseModel): refresh_token: str +class RefreshTokens(UserLogout): + ... + + class User(UserBase): id: UUID4 @@ -51,3 +55,11 @@ class TokensResponse(BaseResponseBody): class LogoutResponse(BaseResponseBody): data: dict = {"sucess": True} + + +class VerifyTokenResponse(BaseResponseBody): + data: dict = {"access": True} + + +class JWTPayload(BaseModel): + user_id: UUID4 diff --git a/src/v1/auth/service.py b/src/v1/auth/service.py index bc018ea..117712c 100755 --- a/src/v1/auth/service.py +++ b/src/v1/auth/service.py @@ -4,21 +4,27 @@ from fastapi import Request, Response from pydantic import BaseModel, UUID4 -from sqlalchemy import and_, or_, select, delete +from sqlalchemy import and_, or_, select, update from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from jose import jwt, JWTError from src.core.config import settings from src.db.postgres import RefreshTokensStorage -from src.v1.auth.exceptions import UserAlreadyExistsError +from src.v1.auth.exceptions import ( + UserAlreadyExistsError, UnauthorizedError, TokenNotFoundError, InvalidTokenError +) from src.v1.auth.helpers import ( generate_jwt, generate_user_signature, hash_password, verify_password, - decode_jwt + decode_jwt, + validate_jwt ) -from src.v1.auth.schemas import JWTTokens, User, UserCreate, UserLogin +from src.v1.auth.schemas import (JWTTokens, User, UserCreate, UserLogin, + VerifyTokenResponse, JWTPayload) +from src.db.redis import BlacklistSignatureStorage from src.v1.exceptions import ServiceError from src.v1.users.service import UserService from src.v1.users import models as users_models @@ -90,8 +96,9 @@ async def signin( await __class__._save_login_session_if_not_exists(db_session, exists_user, request) # TODO: Add role to JWT + jwt_payload = JWTPayload(user_id=exists_user.id).model_dump(mode="json") tokens = generate_jwt( - payload={"user_id": str(exists_user.id)}, + payload=jwt_payload, access_jti=exists_user.signature.signature, refresh_jti=uuid.uuid4(), ) @@ -112,12 +119,83 @@ async def logout( response.delete_cookie(settings.sessions_cookie_name) await refresh_token_storage.delete(db_session, refresh_token) + @staticmethod + async def verify(access_token: str, blacklist_signatures_storage: BlacklistSignatureStorage): + """Verfiy that provided access token is not blacklist""" + try: + await validate_jwt(blacklist_signatures_storage, access_token) + data={"access": True} + except UnauthorizedError: + data = {"access": False} + return VerifyTokenResponse(data=data) - async def verify(current_user=None): - ... + @staticmethod + async def terminate_all_sessions( + db_session: AsyncSession, + blacklist_signatures_storage: BlacklistSignatureStorage, + refresh_token_storage: RefreshTokensStorage, + response: Response, + access_token: str + ): + """Terminate all sessions""" + + await validate_jwt(blacklist_signatures_storage, access_token) + response.delete_cookie(settings.sessions_cookie_name) + try: + token_payload = decode_jwt(access_token) + except JWTError: + raise InvalidTokenError() + + token_headers = jwt.get_unverified_header(access_token) + + user_id = token_payload.get("user_id") + old_user_signature = token_headers.get("jti") + + await blacklist_signatures_storage.create(user_id, old_user_signature) + await __class__._update_user_signature(db_session, user_id, old_user_signature) + + await refresh_token_storage.delete_all(db_session, user_id=user_id) + + @staticmethod + async def refresh_tokens( + db_session: AsyncSession, + refresh_token_storage: RefreshTokensStorage, + response: Response, + refresh_token: str + ): + """Regenerate JWT pair of tokens""" + + try: + decode_jwt(refresh_token) + except JWTError: + raise InvalidTokenError() + + refresh_token_headers = jwt.get_unverified_header(refresh_token) + refresh_jti = refresh_token_headers.get("jti") + statement = select(users_models.UserRefreshTokens).where( + users_models.UserRefreshTokens.token == refresh_jti + ) + result = await db_session.execute(statement) + if (token := result.scalar()) is None: + raise TokenNotFoundError() + + user = await UserService.get_by_id(db_session, token.user_id) + await refresh_token_storage.delete(db_session, refresh_token) + + jwt_payload = JWTPayload(user_id=user.id).model_dump(mode="json") + tokens = generate_jwt( + payload=jwt_payload, + access_jti=user.signature.signature, + refresh_jti=str(uuid.uuid4()), + ) + await refresh_token_storage.create(db_session, tokens.refresh_token, user.id) + await __class__._set_user_cookie( + settings.sessions_cookie_name, + tokens.access_token, + response + ) + return tokens - async def terminate_all_sessions(current_user=None): - ... @staticmethod async def _save_login_session_if_not_exists( @@ -166,5 +244,25 @@ async def _set_user_cookie(cookie_key: str, cookie_value: str, response: Respons expires=settings.jwt_access_expire_time_in_seconds ) + @staticmethod + async def _update_user_signature( + db_session: AsyncSession, + user_id: UUID4, + old_user_signature: str + ): + """Update user signature when user terminate all sessions.""" + user = await UserService.get_by_id(db_session, user_id) + new_user_signature = generate_user_signature(user.username) + + statement = update(users_models.UserSignature)\ + .where(users_models.UserSignature.signature == old_user_signature)\ + .values({users_models.UserSignature.signature: new_user_signature}) + await db_session.execute(statement) + try: + await db_session.commit() + except SQLAlchemyError: + await db_session.rollback() + raise ServiceError() + AuthService = JWTAuthService() diff --git a/src/v1/users/service.py b/src/v1/users/service.py index 821518c..f7174a2 100755 --- a/src/v1/users/service.py +++ b/src/v1/users/service.py @@ -22,7 +22,7 @@ class UserService: async def get_by_email(db_session: AsyncSession, email: EmailStr) -> Type[User]: statement = select(User).where(User.email == email) result = await db_session.execute(statement) - if (exists_user := result.scalar_one()) is None: + if (exists_user := result.scalar()) is None: raise UserNotFoundError() return exists_user