-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
88 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,126 @@ | ||
from collections import defaultdict | ||
from contextlib import suppress | ||
from datetime import datetime, timedelta | ||
from typing import Any, Optional, List | ||
from typing import Any, Optional, Final, TypeAlias, TypedDict | ||
|
||
from telebot import TeleBot | ||
from telebot.handler_backends import BaseMiddleware, CancelUpdate | ||
from telebot.types import Message | ||
from telebot.types import Message, User | ||
|
||
from pytmbot.logs import bot_logger | ||
|
||
# Type aliases for better readability | ||
Timestamp: TypeAlias = datetime | ||
UserID: TypeAlias = int | ||
|
||
|
||
class RateLimitConfig(TypedDict): | ||
"""Type definition for rate limit configuration.""" | ||
limit: int | ||
period: timedelta | ||
|
||
|
||
class RateLimit(BaseMiddleware): | ||
""" | ||
Middleware for rate limiting user requests to prevent DDoS attacks. | ||
This middleware keeps track of user requests and limits the number of requests | ||
each user can make within a specified time period. | ||
Uses a sliding window approach to track and limit user requests within | ||
a specified time period. | ||
""" | ||
|
||
def __init__(self, bot: TeleBot, limit: int, period: timedelta) -> None: | ||
SUPPORTED_UPDATES: Final[list[str]] = ["message"] | ||
WARNING_MESSAGE: Final[str] = "⚠️ You're sending messages too quickly. 🕒 Please slow down." | ||
|
||
def __init__(self, bot: TeleBot, *, limit: int, period: timedelta) -> None: | ||
""" | ||
Initializes the rate limit middleware. | ||
Initialize rate limit middleware. | ||
Args: | ||
bot (TeleBot): The bot object. | ||
limit (int): Maximum number of requests allowed per user within the period. | ||
period (timedelta): The time period during which requests are counted. | ||
bot: The bot instance for sending messages | ||
limit: Maximum number of requests allowed per user within the period | ||
period: The time period during which requests are counted | ||
Raises: | ||
ValueError: If limit or period are invalid | ||
""" | ||
if limit <= 0: | ||
raise ValueError("Request limit must be positive") | ||
if period <= timedelta(): | ||
raise ValueError("Time period must be positive") | ||
|
||
super().__init__() | ||
self.bot = bot | ||
self.update_types: List[str] = ["message"] | ||
self.limit = limit | ||
self.period = period | ||
self.user_requests = defaultdict(list) | ||
self.update_types = self.SUPPORTED_UPDATES | ||
self._user_requests: defaultdict[UserID, list[Timestamp]] = defaultdict(list) | ||
|
||
def _clean_old_requests(self, user_id: UserID, current_time: datetime) -> None: | ||
"""Remove expired request timestamps for a user.""" | ||
requests = self._user_requests[user_id] | ||
cutoff_time = current_time - self.period | ||
|
||
while requests and requests[0] < cutoff_time: | ||
requests.pop(0) | ||
|
||
# Clean up empty user entries | ||
if not requests: | ||
with suppress(KeyError): | ||
del self._user_requests[user_id] | ||
|
||
def _is_rate_limited(self, user_id: UserID, current_time: datetime) -> bool: | ||
"""Check if user has exceeded their rate limit.""" | ||
self._clean_old_requests(user_id, current_time) | ||
return len(self._user_requests[user_id]) >= self.limit | ||
|
||
def _handle_rate_limit(self, message: Message, user: User) -> CancelUpdate: | ||
"""Handle rate limit exceeded scenario.""" | ||
bot_logger.warning( | ||
"Rate limit exceeded", | ||
extra={ | ||
"user_id": user.id, | ||
"username": user.username or "unknown", | ||
"limit": self.limit, | ||
"period": str(self.period) | ||
} | ||
) | ||
|
||
with suppress(Exception): | ||
self.bot.send_message( | ||
chat_id=message.chat.id, | ||
text=self.WARNING_MESSAGE | ||
) | ||
|
||
return CancelUpdate() | ||
|
||
@bot_logger.catch() | ||
def pre_process(self, message: Message, data: Any) -> Optional[CancelUpdate]: | ||
""" | ||
Processes the incoming message and checks for rate limiting. | ||
Process incoming message and enforce rate limiting. | ||
Args: | ||
message (Message): The incoming message from the user. | ||
data (Any): Additional data from Telebot. | ||
message: The incoming message to process | ||
data: Additional processing data | ||
Returns: | ||
Optional[CancelUpdate]: An instance of CancelUpdate if the user exceeds the rate limit, | ||
or None otherwise. | ||
CancelUpdate if rate limit exceeded, None otherwise | ||
Raises: | ||
CancelUpdate: If user information is missing | ||
""" | ||
user = message.from_user | ||
if not user: | ||
bot_logger.error("User information missing in the message.") | ||
if not (user := message.from_user): | ||
bot_logger.error("Missing user information in message") | ||
return CancelUpdate() | ||
|
||
user_id = user.id | ||
now = datetime.now() | ||
user_requests = self.user_requests[user_id] | ||
current_time = datetime.now() | ||
|
||
# Remove timestamps older than the defined period | ||
while user_requests and user_requests[0] < now - self.period: | ||
user_requests.pop(0) | ||
if self._is_rate_limited(user.id, current_time): | ||
return self._handle_rate_limit(message, user) | ||
|
||
# Check if the limit is exceeded | ||
if len(user_requests) >= self.limit: | ||
bot_logger.warning( | ||
f"User {user.username or 'unknown'} (ID: {user_id}) exceeded rate limit." | ||
) | ||
self.bot.send_message( | ||
chat_id=message.chat.id, | ||
text="⚠️ You're sending messages too quickly. 🕒 Please slow down.", | ||
) | ||
return CancelUpdate() | ||
|
||
user_requests.append(now) | ||
self._user_requests[user.id].append(current_time) | ||
return None | ||
|
||
def post_process( | ||
self, message: Message, data: Any, exception: Optional[Exception] | ||
) -> None: | ||
""" | ||
Post-processes the incoming message. | ||
This method can be implemented if necessary to handle any post-processing after | ||
the main logic of the middleware has executed. | ||
Args: | ||
message (Message): The message object being processed. | ||
data (Any): Additional data from Telebot. | ||
exception (Optional[Exception]): Exception raised during processing, if any. | ||
""" | ||
# Implement if necessary or remove if not used. | ||
pass | ||
def post_process(self, message: Message, data: Any, | ||
exception: Optional[Exception]) -> None: | ||
"""Post-process message after main middleware execution.""" | ||
pass # Currently unused but kept for interface compliance |