Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(terminate_all_session): Terminate all user sessions #35

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ passlib==1.7.4
pydantic==2.4.2
email-validator==2.1.0
pydantic-settings==2.0.3
aioredis==2.0.1
redis==5.0.1
fastapi==0.103.2
uvicorn==0.23.2
gunicorn==21.2.0
Expand Down
6 changes: 5 additions & 1 deletion src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Settings(BaseSettings):
listen_port: int = 8000
allowed_hosts: list = const.DEFAULT_ALLOWED_HOSTS

redis_host: str = "redis"
redis_host: str = "localhost"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Тут лучше оставить имя контейнера по умолчанию

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

имхо, когда будем собирать докер - это может усложнить работу, лучше 'redis'
то же самое с postgres, но я могу это поправить в ветке с докером

redis_port: int = 6379
redis_db: int = 0

Expand Down Expand Up @@ -42,6 +42,10 @@ def pg_dsn(self):
f"{self.postgres_port}/{self.postgres_db}"
)

@cached_property
def redis_dsn(self):
return f"redis://{self.redis_host}:{self.redis_port}"
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved

class Config:
case_sensitive = False
env_file = ".env"
Expand Down
23 changes: 19 additions & 4 deletions src/db/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from fastapi import Depends
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.exc import SQLAlchemyError
from jose import jwt
from jose import jwt, JWTError
from pydantic import UUID4
from sqlalchemy import delete

from src.core.config import settings
from src.v1.exceptions import ServiceError
from src.v1.auth.exceptions import InvalidTokenError
from src.db.storages import Database, BaseStorage
from src.v1.auth.helpers import decode_jwt
from src.v1.users.models import UserRefreshTokens
Expand Down Expand Up @@ -40,11 +41,15 @@ class PostgresRefreshTokenStorage(BaseStorage):
@staticmethod
async def create(db_session: AsyncSession, refresh_token: str, user_id: UUID4) -> UUID4:
token_headers = jwt.get_unverified_header(refresh_token)
token_data = decode_jwt(refresh_token)
try:
token_payload = decode_jwt(refresh_token)
await __class__._verify_that_token_is_refresh(token_payload)
except JWTError:
raise InvalidTokenError()
refresh_token = UserRefreshTokens(
token=token_headers.get("jti"),
user_id=user_id,
expire_at=datetime.fromtimestamp(token_data.get("exp")),
expire_at=datetime.fromtimestamp(token_payload.get("exp")),
)
db_session.add(refresh_token)
try:
Expand All @@ -62,7 +67,12 @@ async def get(db_session: AsyncSession, token: str) -> UserRefreshTokens:

@staticmethod
async def delete(db_session: AsyncSession, token: str):
decode_jwt(token)
try:
token_payload = decode_jwt(token)
await __class__._verify_that_token_is_refresh(token_payload)
except JWTError:
raise InvalidTokenError()

token_headers = jwt.get_unverified_header(token)
token_id = token_headers.get("jti")

Expand All @@ -83,6 +93,11 @@ async def delete_all(db_session: AsyncSession, user_id: UUID4):
except SQLAlchemyError:
await db_session.rollback()

@staticmethod
async def _verify_that_token_is_refresh(token_payload: dict):
if len(token_payload.values()) > 1:
raise JWTError()


db_session = PostgresDatabase()
refresh_tokens_storage = PostgresRefreshTokenStorage()
Expand Down
51 changes: 51 additions & 0 deletions src/db/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Annotated

from redis.asyncio import Redis, from_url
from pydantic import UUID4
from fastapi import Depends

from src.db.storages import BaseStorage
from src.core.config import settings


class RedisBlacklistUserSignatureStorage(BaseStorage):

def __init__(self) -> None:
self.protocol: Redis = from_url(
settings.redis_dsn,
decode_responses=True,
db=settings.redis_db
)
self.namespace: str = "auth_service"

async def create(self, user_id: UUID4, signature: str):
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
await conn.set(signature_key, signature)

async def get(self, user_id: UUID4) -> str:
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
return await conn.get(signature_key)

async def delete(self, user_id: UUID4):
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
return await conn.delete(signature_key)

async def delete_all(self, user_id: UUID4, count_size: int = 10) -> int:
pattern = f"{self.namespace}:{user_id}:*"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

возможно я не понимаю как клиент редиса работает, но если мы создаем только пары ключ-значение вида f"{self.namespace}:{user_id}": signature то зачем нам такой паттерн поиска
f"{self.namespace}:{user_id}:*"
?

  1. кажется что будет всего один такой ключ - у нас у каждого пользователя одна сигнатура
  2. ключ при сохранении на включает замыкающее ":" + какие-либо еще символы

cursor = b"0"
deleted_count = 0

async with self.protocol.client() as conn:
while cursor:
cursor, keys = await conn.scan(cursor, match=pattern, count=count_size)
deleted_count += await conn.unlink(*keys)
return deleted_count


redis_blackist_storage = RedisBlacklistUserSignatureStorage()
BlacklistSignatureStorage = Annotated[
RedisBlacklistUserSignatureStorage, Depends(redis_blackist_storage)
]
13 changes: 13 additions & 0 deletions src/v1/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class AuthExceptionCodes:
USER_UNAUTHORIZED: int = 3002
PROVIDED_PASSWORD_INCORRECT: int = 3003
INVALID_PROVIDED_TOKEN: int = 3004
TOKEN_NOT_FOUND: int = 3005


class UserAlreadyExistsError(HTTPException):
Expand Down Expand Up @@ -58,3 +59,15 @@ def __init__(
) -> None:
detail = {"code": AuthExceptionCodes.INVALID_PROVIDED_TOKEN, "message": message}
super().__init__(status_code=status_code, detail=detail)


class TokenNotFoundError(HTTPException):
"""Error raised whe token doesnt exists in DB."""

def __init__(
self,
status_code: int = status.HTTP_404_NOT_FOUND,
message: str = "Invalid token.",
) -> None:
detail = {"code": AuthExceptionCodes.TOKEN_NOT_FOUND, "message": message}
super().__init__(status_code=status_code, detail=detail)
20 changes: 16 additions & 4 deletions src/v1/auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from passlib.hash import pbkdf2_sha256
from pydantic import UUID4

from src.db.redis import BlacklistSignatureStorage
from src.core.config import settings
from src.v1.auth.schemas import JWTTokens
from src.v1.auth.exceptions import InvalidTokenError
from src.v1.auth.exceptions import UnauthorizedError, InvalidTokenError


# TODO(alexander.zharyuk): Improve generation. Maybe add some salt?
Expand Down Expand Up @@ -56,12 +57,23 @@ def generate_jwt(payload: dict, access_jti: str, refresh_jti: UUID4) -> JWTToken

def decode_jwt(token: str) -> dict:
"""Decode access / refresh tokens payload"""
try:
payload = jwt.decode(
return jwt.decode(
token,
key=settings.jwt_secret_key,
algorithms=[settings.jwt_algorithm]
)


async def validate_jwt(blacklist_tokens_storage: BlacklistSignatureStorage, token: str):
"""Validate that token is not in blacklists"""
try:
token_payload = decode_jwt(token)
except JWTError:
raise InvalidTokenError()
return payload
token_headers = jwt.get_unverified_header(token)

user_id = token_payload.get("user_id")
token_signature = token_headers.get("jti")

if token_signature == await blacklist_tokens_storage.get(user_id):
raise UnauthorizedError()
26 changes: 24 additions & 2 deletions src/v1/auth/routers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, Request, Response, status
from fastapi.security import APIKeyCookie

from src.db.postgres import DatabaseSession
from src.db.postgres import RefreshTokensStorage
from src.db.postgres import DatabaseSession, RefreshTokensStorage
from src.db.redis import BlacklistSignatureStorage
from src.v1.auth.schemas import (TokensResponse, UserCreate, UserLogin,
UserResponse, LogoutResponse, UserLogout)
from src.v1.auth.service import AuthService
Expand Down Expand Up @@ -54,3 +54,25 @@ async def logout(
)
return LogoutResponse()


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

заметил что в logout методе есть access_token из кук, но он никуда дальше не передается
по идее надо по нему верификацию делать как в terminate_all_sessions

@router.post(
"/logout_all",
summary="Выход из всех сессий пользователя",
response_model=LogoutResponse
)
async def terminate_all_sessions(
db_session: DatabaseSession,
blacklist_signatures_storage: BlacklistSignatureStorage,
refresh_token_storage: RefreshTokensStorage,
response: Response,
access_token: str | None = Depends(cookie_scheme)
) -> TokensResponse:
"""Выход из всех сессий пользователя"""
await AuthService.terminate_all_sessions(
db_session,
blacklist_signatures_storage,
refresh_token_storage,
response,
access_token
)
return LogoutResponse()
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved
12 changes: 12 additions & 0 deletions src/v1/auth/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class UserLogout(BaseModel):
refresh_token: str


class RefreshTokens(UserLogout):
...


class User(UserBase):
id: UUID4

Expand All @@ -51,3 +55,11 @@ class TokensResponse(BaseResponseBody):

class LogoutResponse(BaseResponseBody):
data: dict = {"sucess": True}


class VerifyTokenResponse(BaseResponseBody):
data: dict = {"access": True}


class JWTPayload(BaseModel):
user_id: UUID4
Loading