Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(terminate_all_session): Terminate all user sessions #35

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Settings(BaseSettings):
listen_port: int = 8000
allowed_hosts: list = const.DEFAULT_ALLOWED_HOSTS

redis_host: str = "redis"
redis_host: str = "localhost"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Тут лучше оставить имя контейнера по умолчанию

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

имхо, когда будем собирать докер - это может усложнить работу, лучше 'redis'
то же самое с postgres, но я могу это поправить в ветке с докером

redis_port: int = 6379
redis_db: int = 0

Expand Down Expand Up @@ -42,6 +42,10 @@ def pg_dsn(self):
f"{self.postgres_port}/{self.postgres_db}"
)

@cached_property
def redis_dsn(self):
return f"redis://{self.redis_host}:{self.redis_port}"
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved

class Config:
case_sensitive = False
env_file = ".env"
Expand Down
47 changes: 47 additions & 0 deletions src/db/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Annotated

from aioredis import Redis, from_url
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved
from pydantic import UUID4
from fastapi import Depends

from src.db.storages import BaseStorage
from src.core.config import settings


class RedisBlacklistUserSignatureStorage(BaseStorage):

def __init__(self) -> None:
self.protocol: Redis = from_url(settings.redis_dsn, decode_responses=True)
self.namespace: str = "auth_service"

async def create(self, user_id: UUID4, signature: str):
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
await conn.set(signature_key, signature)

async def get(self, user_id: UUID4) -> str:
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
return await conn.get(signature_key)

async def delete(self, user_id: UUID4):
signature_key = f"{self.namespace}:{user_id}"
async with self.protocol.client() as conn:
return await conn.delete(signature_key)

async def delete_all(self, user_id: UUID4, count_size: int = 10) -> int:
pattern = f"{self.namespace}:{user_id}:*"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

возможно я не понимаю как клиент редиса работает, но если мы создаем только пары ключ-значение вида f"{self.namespace}:{user_id}": signature то зачем нам такой паттерн поиска
f"{self.namespace}:{user_id}:*"
?

  1. кажется что будет всего один такой ключ - у нас у каждого пользователя одна сигнатура
  2. ключ при сохранении на включает замыкающее ":" + какие-либо еще символы

cursor = b"0"
deleted_count = 0

async with self.protocol.client() as conn:
while cursor:
cursor, keys = await conn.scan(cursor, match=pattern, count=count_size)
deleted_count += await conn.unlink(*keys)
return deleted_count


redis_blackist_storage = RedisBlacklistUserSignatureStorage()
BlacklistSignatureStorage = Annotated[
RedisBlacklistUserSignatureStorage, Depends(redis_blackist_storage)
]
15 changes: 14 additions & 1 deletion src/v1/auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from passlib.hash import pbkdf2_sha256
from pydantic import UUID4

from src.db.redis import BlacklistSignatureStorage
from src.core.config import settings
from src.v1.auth.schemas import JWTTokens
from src.v1.auth.exceptions import InvalidTokenError
from src.v1.auth.exceptions import InvalidTokenError, UnauthorizedError


# TODO(alexander.zharyuk): Improve generation. Maybe add some salt?
Expand Down Expand Up @@ -65,3 +66,15 @@ def decode_jwt(token: str) -> dict:
except JWTError:
raise InvalidTokenError()
return payload


async def validate_jwt(blacklist_tokens_storage: BlacklistSignatureStorage, token: str):
"""Validate that token is not in blacklists"""
token_payload = decode_jwt(token)
token_headers = jwt.get_unverified_header(token)

user_id = token_payload.get("user_id")
token_signature = token_headers.get("jti")

if token_signature == await blacklist_tokens_storage.get(user_id):
raise UnauthorizedError()
26 changes: 24 additions & 2 deletions src/v1/auth/routers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from fastapi import APIRouter, Depends, Request, Response, status
from fastapi.security import APIKeyCookie

from src.db.postgres import DatabaseSession
from src.db.postgres import RefreshTokensStorage
from src.db.postgres import DatabaseSession, RefreshTokensStorage
from src.db.redis import BlacklistSignatureStorage
from src.v1.auth.schemas import (TokensResponse, UserCreate, UserLogin,
UserResponse, LogoutResponse, UserLogout)
from src.v1.auth.service import AuthService
Expand Down Expand Up @@ -54,3 +54,25 @@ async def logout(
)
return LogoutResponse()


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

заметил что в logout методе есть access_token из кук, но он никуда дальше не передается
по идее надо по нему верификацию делать как в terminate_all_sessions

@router.post(
"/logout_all",
summary="Выход из всех сессий пользователя",
response_model=LogoutResponse
)
async def terminate_all_sessions(
db_session: DatabaseSession,
blacklist_signatures_storage: BlacklistSignatureStorage,
refresh_token_storage: RefreshTokensStorage,
response: Response,
access_token: str | None = Depends(cookie_scheme)
) -> TokensResponse:
"""Выход из всех сессий пользователя"""
await AuthService.terminate_all_sessions(
db_session,
blacklist_signatures_storage,
refresh_token_storage,
response,
access_token
)
return LogoutResponse()
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved
53 changes: 47 additions & 6 deletions src/v1/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from datetime import datetime

from fastapi import Request, Response
from pydantic import BaseModel, UUID4
from sqlalchemy import and_, or_, select, delete
from pydantic import BaseModel
from sqlalchemy import and_, or_, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from jose import jwt

from src.core.config import settings
from src.db.postgres import RefreshTokensStorage
Expand All @@ -16,9 +17,11 @@
generate_user_signature,
hash_password,
verify_password,
decode_jwt
decode_jwt,
validate_jwt
)
from src.v1.auth.schemas import JWTTokens, User, UserCreate, UserLogin
from src.db.redis import BlacklistSignatureStorage
from src.v1.exceptions import ServiceError
from src.v1.users.service import UserService
from src.v1.users import models as users_models
Expand Down Expand Up @@ -91,7 +94,7 @@ async def signin(

# TODO: Add role to JWT
tokens = generate_jwt(
payload={"user_id": str(exists_user.id)},
payload={"user_id": str(exists_user.id), "username": exists_user.username},
AlexanderZharyuk marked this conversation as resolved.
Show resolved Hide resolved
access_jti=exists_user.signature.signature,
refresh_jti=uuid.uuid4(),
)
Expand All @@ -116,8 +119,30 @@ async def logout(
async def verify(current_user=None):
...

async def terminate_all_sessions(current_user=None):
...
@staticmethod
async def terminate_all_sessions(
db_session: AsyncSession,
blacklist_signatures_storage: BlacklistSignatureStorage,
refresh_token_storage: RefreshTokensStorage,
response: Response,
access_token: str
):
"""Terminate all sessions"""

await validate_jwt(blacklist_signatures_storage, access_token)
response.delete_cookie(settings.sessions_cookie_name)
token_payload = decode_jwt(access_token)
token_headers = jwt.get_unverified_header(access_token)

username = token_payload.get("username")
user_id = token_payload.get("user_id")
old_user_signature = token_headers.get("jti")
new_user_signature = generate_user_signature(username)

await blacklist_signatures_storage.create(user_id, old_user_signature)
await __class__._update_user_signature(db_session, old_user_signature, new_user_signature)

await refresh_token_storage.delete_all(db_session, user_id=user_id)

@staticmethod
async def _save_login_session_if_not_exists(
Expand Down Expand Up @@ -166,5 +191,21 @@ async def _set_user_cookie(cookie_key: str, cookie_value: str, response: Respons
expires=settings.jwt_access_expire_time_in_seconds
)

@staticmethod
async def _update_user_signature(
db_session: AsyncSession,
old_user_signature: str,
new_user_signature: str
):
"""Update user signature when user terminate all sessions."""
statement = update(users_models.UserSignature).where(
users_models.UserSignature.signature == old_user_signature
).values({users_models.UserSignature.signature: new_user_signature})
await db_session.execute(statement)
try:
await db_session.commit()
except SQLAlchemyError:
await db_session.rollback()
raise ServiceError()

AuthService = JWTAuthService()