Skip to content

Commit

Permalink
chore: code refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
orenlab committed Dec 29, 2024
1 parent 934bdf8 commit a4233bd
Showing 1 changed file with 88 additions and 53 deletions.
141 changes: 88 additions & 53 deletions pytmbot/middleware/rate_limit.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,126 @@
from collections import defaultdict
from contextlib import suppress
from datetime import datetime, timedelta
from typing import Any, Optional, List
from typing import Any, Optional, Final, TypeAlias, TypedDict

from telebot import TeleBot
from telebot.handler_backends import BaseMiddleware, CancelUpdate
from telebot.types import Message
from telebot.types import Message, User

from pytmbot.logs import bot_logger

# Type aliases for better readability
Timestamp: TypeAlias = datetime
UserID: TypeAlias = int


class RateLimitConfig(TypedDict):
"""Type definition for rate limit configuration."""
limit: int
period: timedelta


class RateLimit(BaseMiddleware):
"""
Middleware for rate limiting user requests to prevent DDoS attacks.
This middleware keeps track of user requests and limits the number of requests
each user can make within a specified time period.
Uses a sliding window approach to track and limit user requests within
a specified time period.
"""

def __init__(self, bot: TeleBot, limit: int, period: timedelta) -> None:
SUPPORTED_UPDATES: Final[list[str]] = ["message"]
WARNING_MESSAGE: Final[str] = "⚠️ You're sending messages too quickly. 🕒 Please slow down."

def __init__(self, bot: TeleBot, *, limit: int, period: timedelta) -> None:
"""
Initializes the rate limit middleware.
Initialize rate limit middleware.
Args:
bot (TeleBot): The bot object.
limit (int): Maximum number of requests allowed per user within the period.
period (timedelta): The time period during which requests are counted.
bot: The bot instance for sending messages
limit: Maximum number of requests allowed per user within the period
period: The time period during which requests are counted
Raises:
ValueError: If limit or period are invalid
"""
if limit <= 0:
raise ValueError("Request limit must be positive")
if period <= timedelta():
raise ValueError("Time period must be positive")

super().__init__()
self.bot = bot
self.update_types: List[str] = ["message"]
self.limit = limit
self.period = period
self.user_requests = defaultdict(list)
self.update_types = self.SUPPORTED_UPDATES
self._user_requests: defaultdict[UserID, list[Timestamp]] = defaultdict(list)

def _clean_old_requests(self, user_id: UserID, current_time: datetime) -> None:
"""Remove expired request timestamps for a user."""
requests = self._user_requests[user_id]
cutoff_time = current_time - self.period

while requests and requests[0] < cutoff_time:
requests.pop(0)

# Clean up empty user entries
if not requests:
with suppress(KeyError):
del self._user_requests[user_id]

def _is_rate_limited(self, user_id: UserID, current_time: datetime) -> bool:
"""Check if user has exceeded their rate limit."""
self._clean_old_requests(user_id, current_time)
return len(self._user_requests[user_id]) >= self.limit

def _handle_rate_limit(self, message: Message, user: User) -> CancelUpdate:
"""Handle rate limit exceeded scenario."""
bot_logger.warning(
"Rate limit exceeded",
extra={
"user_id": user.id,
"username": user.username or "unknown",
"limit": self.limit,
"period": str(self.period)
}
)

with suppress(Exception):
self.bot.send_message(
chat_id=message.chat.id,
text=self.WARNING_MESSAGE
)

return CancelUpdate()

@bot_logger.catch()
def pre_process(self, message: Message, data: Any) -> Optional[CancelUpdate]:
"""
Processes the incoming message and checks for rate limiting.
Process incoming message and enforce rate limiting.
Args:
message (Message): The incoming message from the user.
data (Any): Additional data from Telebot.
message: The incoming message to process
data: Additional processing data
Returns:
Optional[CancelUpdate]: An instance of CancelUpdate if the user exceeds the rate limit,
or None otherwise.
CancelUpdate if rate limit exceeded, None otherwise
Raises:
CancelUpdate: If user information is missing
"""
user = message.from_user
if not user:
bot_logger.error("User information missing in the message.")
if not (user := message.from_user):
bot_logger.error("Missing user information in message")
return CancelUpdate()

user_id = user.id
now = datetime.now()
user_requests = self.user_requests[user_id]
current_time = datetime.now()

# Remove timestamps older than the defined period
while user_requests and user_requests[0] < now - self.period:
user_requests.pop(0)
if self._is_rate_limited(user.id, current_time):
return self._handle_rate_limit(message, user)

# Check if the limit is exceeded
if len(user_requests) >= self.limit:
bot_logger.warning(
f"User {user.username or 'unknown'} (ID: {user_id}) exceeded rate limit."
)
self.bot.send_message(
chat_id=message.chat.id,
text="⚠️ You're sending messages too quickly. 🕒 Please slow down.",
)
return CancelUpdate()

user_requests.append(now)
self._user_requests[user.id].append(current_time)
return None

def post_process(
self, message: Message, data: Any, exception: Optional[Exception]
) -> None:
"""
Post-processes the incoming message.
This method can be implemented if necessary to handle any post-processing after
the main logic of the middleware has executed.
Args:
message (Message): The message object being processed.
data (Any): Additional data from Telebot.
exception (Optional[Exception]): Exception raised during processing, if any.
"""
# Implement if necessary or remove if not used.
pass
def post_process(self, message: Message, data: Any,
exception: Optional[Exception]) -> None:
"""Post-process message after main middleware execution."""
pass # Currently unused but kept for interface compliance

0 comments on commit a4233bd

Please sign in to comment.