From fce1c3d1710dad4e4d2a7f683ac1188116b16372 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 23 Jun 2024 15:34:05 +0500 Subject: [PATCH 01/24] Added statesv2 --- telebot/__init__.py | 113 +++++++++++++++++--- telebot/custom_filters.py | 16 ++- telebot/storage/base_storage.py | 40 ++++++- telebot/storage/memory_storage.py | 168 ++++++++++++++++++++---------- 4 files changed, 263 insertions(+), 74 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index d2bcdc307..8fe414252 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -7,7 +7,7 @@ import threading import time import traceback -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, Dict # these imports are used to avoid circular import error import telebot.util @@ -168,7 +168,8 @@ def __init__( disable_notification: Optional[bool]=None, protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, - colorful_logs: Optional[bool]=False + colorful_logs: Optional[bool]=False, + token_check: Optional[bool]=True ): # update-related @@ -186,6 +187,11 @@ def __init__( self.webhook_listener = None self._user = None + # token check + if token_check: + self._user = self.get_me() + self.bot_id = self._user.id + # logs-related if colorful_logs: try: @@ -280,6 +286,8 @@ def __init__( self.threaded = threaded if self.threaded: self.worker_pool = util.ThreadPool(self, num_threads=num_threads) + + @property def user(self) -> types.User: @@ -6572,7 +6580,9 @@ def setup_middleware(self, middleware: BaseMiddleware): self.middlewares.append(middleware) - def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None) -> None: + def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> None: """ Sets a new state of a user. @@ -6591,14 +6601,29 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: None """ if chat_id is None: chat_id = user_id - self.current_states.set_state(chat_id, user_id, state) + if bot_id is None: + bot_id = self.bot_id + self.current_states.set_state( + chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def reset_data(self, user_id: int, chat_id: Optional[int]=None): + def reset_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: """ Reset data for a user in chat. @@ -6608,14 +6633,27 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: None """ if chat_id is None: chat_id = user_id - self.current_states.reset_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None: + def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: """ Delete the current state of a user. @@ -6629,10 +6667,14 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None: """ if chat_id is None: chat_id = user_id - self.current_states.delete_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Any]: + def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]: """ Returns context manager with data for a user in chat. @@ -6642,15 +6684,30 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[A :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional + :param bot_id: Bot's identifier + :type bot_id: int, optional + + :param business_connection_id: Business identifier + :type business_connection_id: str, optional + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: int, optional + :return: Context manager with data for a user in chat :rtype: Optional[Any] """ if chat_id is None: chat_id = user_id - return self.current_states.get_interactive_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id) - def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union[int, str, State]]: + def get_state(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. @@ -6661,15 +6718,31 @@ def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: state of a user :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State` """ if chat_id is None: chat_id = user_id - return self.current_states.get_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): + def add_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None, + **kwargs) -> None: """ Add data to states. @@ -6679,13 +6752,25 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :param kwargs: Data to add :return: None """ if chat_id is None: chat_id = user_id + if bot_id is None: + bot_id = self.bot_id for key, value in kwargs.items(): - self.current_states.set_data(chat_id, user_id, key, value) + self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) def register_next_step_handler_by_chat_id( diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index d4d4ffeca..41a7c780c 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -7,6 +7,7 @@ + class SimpleCustomFilter(ABC): """ Simple Custom Filter base class. @@ -417,8 +418,6 @@ def check(self, message, text): user_id = message.from_user.id message = message.message - - if isinstance(text, list): new_text = [] @@ -430,7 +429,11 @@ def check(self, message, text): text = text.name if message.chat.type in ['group', 'supergroup']: - group_state = self.bot.current_states.get_state(chat_id, user_id) + group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id, + message_thread_id=message.message_thread_id) + if group_state is None and not message.is_topic_message: # needed for general topic and group messages + group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id) + if group_state == text: return True elif type(text) is list and group_state in text: @@ -438,7 +441,12 @@ def check(self, message, text): else: - user_state = self.bot.current_states.get_state(chat_id, user_id) + user_state = self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=message.business_connection_id, + bot_id=self.bot._user.id + ) if user_state == text: return True elif type(text) is list and user_state in text: diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py index 92b31ba85..36545a77c 100644 --- a/telebot/storage/base_storage.py +++ b/telebot/storage/base_storage.py @@ -47,22 +47,56 @@ def get_interactive_data(self, chat_id, user_id): def save(self, chat_id, user_id, data): raise NotImplementedError + + def convert_params_to_key( + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str=None, + message_thread_id: int=None, + bot_id: int=None + ) -> str: + """ + Convert parameters to a key. + """ + params = [prefix] + if bot_id: + params.append(str(bot_id)) + if business_connection_id: + params.append(business_connection_id) + if message_thread_id: + params.append(str(message_thread_id)) + params.append(str(chat_id)) + params.append(str(user_id)) + return separator.join(params) + + + + class StateContext: """ Class for data. """ - def __init__(self , obj, chat_id, user_id) -> None: + def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ): self.obj = obj - self.data = copy.deepcopy(obj.get_data(chat_id, user_id)) + res = obj.get_data(chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id) + self.data = copy.deepcopy(res) self.chat_id = chat_id self.user_id = user_id + self.bot_id = bot_id + self.business_connection_id = business_connection_id + self.message_thread_id = message_thread_id + def __enter__(self): return self.data def __exit__(self, exc_type, exc_val, exc_tb): - return self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file + return self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id) \ No newline at end of file diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 7d71c7ccd..fbf9eebda 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -1,69 +1,131 @@ from telebot.storage.base_storage import StateStorageBase, StateContext - +from typing import Optional, Union class StateMemoryStorage(StateStorageBase): - def __init__(self) -> None: - super().__init__() - self.data = {} - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - - def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): + def __init__(self, + separator: Optional[str]=":", + prefix: Optional[str]="telebot" + ) -> None: + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + + self.data = {} # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + + def set_state( + self, chat_id: int, user_id: int, state: str, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + + ) -> bool: + if hasattr(state, "name"): state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + if self.data.get(_key) is None: + self.data[_key] = {"state": state, "data": {}} + else: + self.data[_key]["state"] = state + return True - def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - - return True + def get_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> Union[str, None]: - return False + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + if self.data.get(_key) is None: + return None + + return self.data[_key]["state"] - def get_state(self, chat_id, user_id): + def delete_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + if self.data.get(_key) is None: + return False + + del self.data[_key] + return True + + + def set_data( + self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: + + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] + if self.data.get(_key) is None: + return False + self.data[_key]["data"][key] = value + return True - return None - def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] + + def get_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> dict: - return None + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) - def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - return True - return False + return self.data.get(_key, {}).get("data", None) + + def reset_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) - def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = {} + return True + + def get_interactive_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> Optional[dict]: + return StateContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) + + def save( + self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, + message_thread_id, bot_id + ) - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = data + return True + + def __str__(self) -> str: + return f"" + + - def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data \ No newline at end of file From 7e5a044b104527489c84e5579396d24e7c2a8115 Mon Sep 17 00:00:00 2001 From: _run Date: Mon, 8 Jul 2024 17:11:37 +0500 Subject: [PATCH 02/24] Sync states v2 early version --- telebot/handler_backends.py | 43 +-------- telebot/states/__init__.py | 43 +++++++++ telebot/states/aio/__init__.py | 0 telebot/states/sync/__init__.py | 7 ++ telebot/states/sync/context.py | 153 ++++++++++++++++++++++++++++++ telebot/states/sync/middleware.py | 17 ++++ 6 files changed, 223 insertions(+), 40 deletions(-) create mode 100644 telebot/states/__init__.py create mode 100644 telebot/states/aio/__init__.py create mode 100644 telebot/states/sync/__init__.py create mode 100644 telebot/states/sync/context.py create mode 100644 telebot/states/sync/middleware.py diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index b95861e0b..12f8c89bf 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -9,6 +9,8 @@ except: redis_installed = False +# backward compatibility +from telebot.states import State, StatesGroup class HandlerBackend(object): """ @@ -160,45 +162,6 @@ def get_handlers(self, handler_group_id): return handlers -class State: - """ - Class representing a state. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init__(self) -> None: - self.name = None - def __str__(self) -> str: - return self.name - - -class StatesGroup: - """ - Class representing common states. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init_subclass__(cls) -> None: - state_list = [] - for name, value in cls.__dict__.items(): - if not name.startswith('__') and not callable(value) and isinstance(value, State): - # change value of that variable - value.name = ':'.join((cls.__name__, name)) - value.group = cls - state_list.append(value) - cls._state_list = state_list - - @classmethod - def state_list(self): - return self._state_list - - class BaseMiddleware: """ Base class for middleware. @@ -292,4 +255,4 @@ def start2(message): """ def __init__(self) -> None: - pass + pass \ No newline at end of file diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py new file mode 100644 index 000000000..0a45f17e1 --- /dev/null +++ b/telebot/states/__init__.py @@ -0,0 +1,43 @@ +""" +Contains classes for states and state groups. +""" + + +class State: + """ + Class representing a state. + + .. code-block:: python3 + + class MyStates(StatesGroup): + my_state = State() # returns my_state:State string. + """ + def __init__(self) -> None: + self.name: str = None + self.group: StatesGroup = None + def __str__(self) -> str: + return f"<{self.group.__name__}:{self.name}>" + + +class StatesGroup: + """ + Class representing common states. + + .. code-block:: python3 + + class MyStates(StatesGroup): + my_state = State() # returns my_state:State string. + """ + def __init_subclass__(cls) -> None: + state_list = [] + for name, value in cls.__dict__.items(): + if not name.startswith('__') and not callable(value) and isinstance(value, State): + # change value of that variable + value.name = ':'.join((cls.__name__, name)) + value.group = cls + state_list.append(value) + cls._state_list = state_list + + @classmethod + def state_list(self): + return self._state_list diff --git a/telebot/states/aio/__init__.py b/telebot/states/aio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/telebot/states/sync/__init__.py b/telebot/states/sync/__init__.py new file mode 100644 index 000000000..60f9ea2b7 --- /dev/null +++ b/telebot/states/sync/__init__.py @@ -0,0 +1,7 @@ +from .context import StateContext +from .middleware import StateMiddleware + +__all__ = [ + 'StateContext', + 'StateMiddleware', +] \ No newline at end of file diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py new file mode 100644 index 000000000..f0a395d01 --- /dev/null +++ b/telebot/states/sync/context.py @@ -0,0 +1,153 @@ +from telebot.states import State, StatesGroup +from telebot.types import CallbackQuery, Message +from telebot import TeleBot + +from typing import Union + + +class StateContext(): + """ + Class representing a state context. + + Passed through a middleware to provide easy way to set states. + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + def start_ex(message: types.Message, state_context: StateContext): + state_context.set(MyStates.name) + bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. + """ + + def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: + self.message: Union[Message, CallbackQuery] = message + self.bot: TeleBot = bot + self.bot_id = self.bot.bot_id + + def _resolve_context(self) -> Union[Message, CallbackQuery]: + chat_id = None + user_id = None + business_connection_id = self.message.business_connection_id + bot_id = self.bot_id + message_thread_id = None + + if isinstance(self.message, Message): + chat_id = self.message.chat.id + user_id = self.message.from_user.id + message_thread_id = self.message.message_thread_id if self.message.is_topic_message else None + elif isinstance(self.message, CallbackQuery): + chat_id = self.message.message.chat.id + user_id = self.message.from_user.id + message_thread_id = self.message.message.message_thread_id if self.message.message.is_topic_message else None + + return chat_id, user_id, business_connection_id, bot_id, message_thread_id + + def set(self, state: Union[State, str]) -> None: + """ + Set state for current user. + + :param state: State object or state name. + :type state: Union[State, str] + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + def start_ex(message: types.Message, state_context: StateContext): + state_context.set(MyStates.name) + bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + if isinstance(state, State): + state = state.name + return self.bot.set_state( + chat_id=chat_id, + user_id=user_id, + state=state, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def get(self) -> str: + """ + Get current state for current user. + + :return: Current state name. + :rtype: str + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + return self.bot.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def reset_data(self) -> None: + """ + Reset data for current user. + State will not be changed. + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + return self.bot.reset_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def delete(self) -> None: + """ + Deletes state and data for current user. + """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + return self.bot.delete_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def data(self) -> dict: + """ + Get data for current user. + + .. code-block:: python3 + + with state_context.data() as data: + print(data) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + return self.bot.retrieve_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def add_data(self, **kwargs) -> None: + """ + Add data for current user. + + :param kwargs: Data to add. + :type kwargs: dict + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + return self.bot.add_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + **kwargs + ) \ No newline at end of file diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py new file mode 100644 index 000000000..e1c74cc23 --- /dev/null +++ b/telebot/states/sync/middleware.py @@ -0,0 +1,17 @@ +from telebot.handler_backends import BaseMiddleware +from telebot import TeleBot +from telebot.states.sync.context import StateContext + + +class StateMiddleware(BaseMiddleware): + + def __init__(self, bot: TeleBot) -> None: + self.update_sensitive = False + self.update_types = ['message', 'edited_message', 'callback_query'] #TODO: support other types + self.bot: TeleBot = bot + + def pre_process(self, message, data): + data['state_context'] = StateContext(message, self.bot) + + def post_process(self, message, data, exception): + pass \ No newline at end of file From 4a7bc5d5006b664848feff84b5200c361b3e170b Mon Sep 17 00:00:00 2001 From: _run Date: Mon, 8 Jul 2024 21:47:38 +0500 Subject: [PATCH 03/24] all update types are supported for states(theoretically) --- telebot/custom_filters.py | 48 +++++++++---------------------- telebot/states/__init__.py | 41 +++++++++++++++++++++++++- telebot/states/sync/context.py | 37 +++++++----------------- telebot/states/sync/middleware.py | 6 ++-- tests/test_handler_backends.py | 2 +- 5 files changed, 69 insertions(+), 65 deletions(-) diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index 41a7c780c..1c5bd0b40 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -4,7 +4,7 @@ from telebot import types - +from telebot.states import resolve_context @@ -407,17 +407,7 @@ def check(self, message, text): """ if text == '*': return True - # needs to work with callbackquery - if isinstance(message, types.Message): - chat_id = message.chat.id - user_id = message.from_user.id - - if isinstance(message, types.CallbackQuery): - - chat_id = message.message.chat.id - user_id = message.from_user.id - message = message.message - + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) if isinstance(text, list): new_text = [] @@ -428,29 +418,17 @@ def check(self, message, text): elif isinstance(text, State): text = text.name - if message.chat.type in ['group', 'supergroup']: - group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id, - message_thread_id=message.message_thread_id) - if group_state is None and not message.is_topic_message: # needed for general topic and group messages - group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id) - - if group_state == text: - return True - elif type(text) is list and group_state in text: - return True - - - else: - user_state = self.bot.current_states.get_state( - chat_id=chat_id, - user_id=user_id, - business_connection_id=message.business_connection_id, - bot_id=self.bot._user.id - ) - if user_state == text: - return True - elif type(text) is list and user_state in text: - return True + user_state = self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True class IsDigitFilter(SimpleCustomFilter): diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py index 0a45f17e1..b8efe3bbe 100644 --- a/telebot/states/__init__.py +++ b/telebot/states/__init__.py @@ -1,7 +1,7 @@ """ Contains classes for states and state groups. """ - +from telebot import types class State: """ @@ -41,3 +41,42 @@ def __init_subclass__(cls) -> None: @classmethod def state_list(self): return self._state_list + +def resolve_context(message, bot_id: int) -> tuple: + # chat_id, user_id, business_connection_id, bot_id, message_thread_id + + # message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message + if isinstance(message, types.Message): + return (message.chat.id, message.from_user.id, message.business_connection_id, bot_id, + message.message_thread_id if message.is_topic_message else None) + elif isinstance(message, types.CallbackQuery): # callback_query + return (message.message.chat.id, message.from_user.id, message.message.business_connection_id, bot_id, + message.message.message_thread_id if message.message.is_topic_message else None) + elif isinstance(message, types.BusinessConnection): # business_connection + return (message.user_chat_id, message.user.id, message.id, bot_id, None) + elif isinstance(message, types.BusinessMessagesDeleted): # deleted_business_messages + return (message.chat.id, message.chat.id, message.business_connection_id, bot_id, None) + elif isinstance(message, types.MessageReactionUpdated): # message_reaction + return (message.chat.id, message.user.id, None, bot_id, None) + elif isinstance(message, types.MessageReactionCountUpdated): # message_reaction_count + return (message.chat.id, None, None, bot_id, None) + elif isinstance(message, types.InlineQuery): # inline_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ShippingQuery): # shipping_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PollAnswer): # poll_answer + return (None, message.user.id, None, bot_id, None) + elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatJoinRequest): # chat_join_request + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost + return (message.chat.id, message.source.user.id if message.source else None, None, bot_id, None) + elif isinstance(message, types.ChatBoostUpdated): # chat_boost + return (message.chat.id, message.boost.source.user.id if message.boost.source else None, None, bot_id, None) + else: + pass # not yet supported :( \ No newline at end of file diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py index f0a395d01..c8009007a 100644 --- a/telebot/states/sync/context.py +++ b/telebot/states/sync/context.py @@ -1,10 +1,12 @@ from telebot.states import State, StatesGroup from telebot.types import CallbackQuery, Message -from telebot import TeleBot +from telebot import TeleBot, types +from telebot.states import resolve_context from typing import Union + class StateContext(): """ Class representing a state context. @@ -25,24 +27,6 @@ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: self.bot: TeleBot = bot self.bot_id = self.bot.bot_id - def _resolve_context(self) -> Union[Message, CallbackQuery]: - chat_id = None - user_id = None - business_connection_id = self.message.business_connection_id - bot_id = self.bot_id - message_thread_id = None - - if isinstance(self.message, Message): - chat_id = self.message.chat.id - user_id = self.message.from_user.id - message_thread_id = self.message.message_thread_id if self.message.is_topic_message else None - elif isinstance(self.message, CallbackQuery): - chat_id = self.message.message.chat.id - user_id = self.message.from_user.id - message_thread_id = self.message.message.message_thread_id if self.message.message.is_topic_message else None - - return chat_id, user_id, business_connection_id, bot_id, message_thread_id - def set(self, state: Union[State, str]) -> None: """ Set state for current user. @@ -58,7 +42,7 @@ def start_ex(message: types.Message, state_context: StateContext): bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) if isinstance(state, State): state = state.name return self.bot.set_state( @@ -78,7 +62,7 @@ def get(self) -> str: :rtype: str """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.get_state( chat_id=chat_id, user_id=user_id, @@ -93,7 +77,7 @@ def reset_data(self) -> None: State will not be changed. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.reset_data( chat_id=chat_id, user_id=user_id, @@ -106,7 +90,7 @@ def delete(self) -> None: """ Deletes state and data for current user. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.delete_state( chat_id=chat_id, user_id=user_id, @@ -125,7 +109,7 @@ def data(self) -> dict: print(data) """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.retrieve_data( chat_id=chat_id, user_id=user_id, @@ -142,7 +126,7 @@ def add_data(self, **kwargs) -> None: :type kwargs: dict """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = self._resolve_context() + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) return self.bot.add_data( chat_id=chat_id, user_id=user_id, @@ -150,4 +134,5 @@ def add_data(self, **kwargs) -> None: bot_id=bot_id, message_thread_id=message_thread_id, **kwargs - ) \ No newline at end of file + ) + \ No newline at end of file diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py index e1c74cc23..b85f795da 100644 --- a/telebot/states/sync/middleware.py +++ b/telebot/states/sync/middleware.py @@ -1,17 +1,19 @@ from telebot.handler_backends import BaseMiddleware from telebot import TeleBot from telebot.states.sync.context import StateContext +from telebot.util import update_types +from telebot import types class StateMiddleware(BaseMiddleware): def __init__(self, bot: TeleBot) -> None: self.update_sensitive = False - self.update_types = ['message', 'edited_message', 'callback_query'] #TODO: support other types + self.update_types = update_types self.bot: TeleBot = bot def pre_process(self, message, data): data['state_context'] = StateContext(message, self.bot) def post_process(self, message, data, exception): - pass \ No newline at end of file + pass diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py index bb541bf8a..f57200c1c 100644 --- a/tests/test_handler_backends.py +++ b/tests/test_handler_backends.py @@ -19,7 +19,7 @@ @pytest.fixture() def telegram_bot(): - return telebot.TeleBot('', threaded=False) + return telebot.TeleBot('', threaded=False, token_check=False) @pytest.fixture From 676597cf6c8353b4569d17ddc29e235049f05cf7 Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 12 Jul 2024 17:16:10 +0500 Subject: [PATCH 04/24] added redis support(not fully tested) --- telebot/custom_filters.py | 4 + telebot/storage/redis_storage.py | 300 ++++++++++++++----------------- 2 files changed, 136 insertions(+), 168 deletions(-) diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index dbba91d9c..e53c669a1 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -407,6 +407,9 @@ def check(self, message, text): chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) + if chat_id is None: + chat_id = user_id # May change in future + if isinstance(text, list): new_text = [] for i in text: @@ -423,6 +426,7 @@ def check(self, message, text): bot_id=bot_id, message_thread_id=message_thread_id ) + if user_state == text: return True elif type(text) is list and user_state in text: diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 3fac57c46..9c52a3fa1 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -1,183 +1,147 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext import json - -redis_installed = True -try: - from redis import Redis, ConnectionPool - -except: - redis_installed = False +import redis +from telebot.storage.base_storage import StateStorageBase, StateContext +from typing import Optional, Union class StateRedisStorage(StateStorageBase): - """ - This class is for Redis storage. - This will work only for states. - To use it, just pass this class to: - TeleBot(storage=StateRedisStorage()) - """ - def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None): - super().__init__() + def __init__(self, host='localhost', port=6379, db=0, password=None, + prefix='telebot', + redis_url=None, + connection_pool: redis.ConnectionPool=None, + separator: Optional[str]=":", + ) -> None: + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + if redis_url: - self.redis = ConnectionPool.from_url(redis_url) + self.redis = redis.Redis.from_url(redis_url) + elif connection_pool: + self.redis = redis.Redis(connection_pool=connection_pool) else: - self.redis = ConnectionPool(host=host, port=port, db=db, password=password) - #self.con = Redis(connection_pool=self.redis) -> use this when necessary - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - self.prefix = prefix - if not redis_installed: - raise Exception("Redis is not installed. Install it via 'pip install redis'") + self.redis = redis.Redis(host=host, port=port, db=db, password=password) - def get_record(self, key): - """ - Function to get record from database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - result = connection.get(self.prefix+str(key)) - connection.close() - if result: return json.loads(result) - return - - def set_record(self, key, value): - """ - Function to set record to database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - connection.set(self.prefix+str(key), json.dumps(value)) - connection.close() + + def set_state( + self, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> bool: + if hasattr(state, "name"): + state = state.name + + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + def set_state_action(pipe): + pipe.multi() + #pipe.hset(_key, mapping={"state": state, "data": "{}"}) + pipe.hset(_key, "state", state) + + self.redis.transaction(set_state_action, _key) return True - def delete_record(self, key): - """ - Function to delete record from database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - connection.delete(self.prefix+str(key)) - connection.close() + def get_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> Union[str, None]: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + state_bytes = self.redis.hget(_key, "state") + return state_bytes.decode('utf-8') if state_bytes else None + + def delete_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + return self.redis.delete(_key) > 0 + + def set_data( + self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + def set_data_action(pipe): + pipe.multi() + data = pipe.hget(_key, "data") + data = data.execute()[0] + if data is None: + pipe.hset(_key, "data", json.dumps({key: value})) + else: + data = json.loads(data) + data[key] = value + pipe.hset(_key, "data", json.dumps(data)) + + self.redis.transaction(set_data_action, _key) return True - def set_state(self, chat_id, user_id, state): - """ - Set state for a particular user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if hasattr(state, 'name'): - state = state.name + def get_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> dict: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = self.redis.hget(_key, "data") + return json.loads(data) if data else {} + + def reset_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + def reset_data_action(pipe): + pipe.multi() + if pipe.exists(_key): + pipe.hset(_key, "data", "{}") + else: + return False + + self.redis.transaction(reset_data_action, _key) + return True - if response: - if user_id in response: - response[user_id]['state'] = state + def get_interactive_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> Optional[dict]: + return StateContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) + + def save( + self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + ) -> bool: + _key = self.convert_params_to_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, + message_thread_id, bot_id + ) + + def save_action(pipe): + pipe.multi() + if pipe.exists(_key): + pipe.hset(_key, "data", json.dumps(data)) else: - response[user_id] = {'state': state, 'data': {}} - else: - response = {user_id: {'state': state, 'data': {}}} - self.set_record(chat_id, response) + return False + self.redis.transaction(save_action, _key) return True - - def delete_state(self, chat_id, user_id): - """ - Delete state for a particular user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - del response[user_id] - if user_id == str(chat_id): - self.delete_record(chat_id) - return True - else: self.set_record(chat_id, response) - return True - return False - - - def get_value(self, chat_id, user_id, key): - """ - Get value for a data of a user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - if key in response[user_id]['data']: - return response[user_id]['data'][key] - return None - - def get_state(self, chat_id, user_id): - """ - Get state of a user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['state'] - - return None - - - def get_data(self, chat_id, user_id): - """ - Get data of particular user in a particular chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['data'] - return None - - - def reset_data(self, chat_id, user_id): - """ - Reset data of a user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = {} - self.set_record(chat_id, response) - return True - - - - - def set_data(self, chat_id, user_id, key, value): - """ - Set data without interactive data. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'][key] = value - self.set_record(chat_id, response) - return True - return False - - def get_interactive_data(self, chat_id, user_id): - """ - Get Data in interactive way. - You can use with() with this function. - """ - return StateContext(self, chat_id, user_id) - - def save(self, chat_id, user_id, data): - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = data - self.set_record(chat_id, response) - return True - + def __str__(self) -> str: + keys = self.redis.keys(f"{self.prefix}{self.separator}*") + data = {key.decode(): self.redis.hgetall(key) for key in keys} + return f"" From 5bd42715582f3ebee6edd1da0dccd89dca0b1981 Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 12 Jul 2024 17:28:14 +0500 Subject: [PATCH 05/24] Added redis dependency check --- telebot/storage/redis_storage.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 9c52a3fa1..d32ab8f8e 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -1,8 +1,13 @@ import json -import redis from telebot.storage.base_storage import StateStorageBase, StateContext from typing import Optional, Union +redis_installed = True +try: + import redis +except ImportError: + redis_installed = False + class StateRedisStorage(StateStorageBase): def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot', @@ -10,6 +15,10 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, connection_pool: redis.ConnectionPool=None, separator: Optional[str]=":", ) -> None: + + if not redis_installed: + raise ImportError("Redis is not installed. Please install it via pip install redis") + self.separator = separator self.prefix = prefix if not self.prefix: From a79fd77cba36073f7d70bcbff5240c0231ee8f3c Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 12 Jul 2024 17:29:56 +0500 Subject: [PATCH 06/24] fix test --- telebot/storage/redis_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index d32ab8f8e..39c75a49b 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -12,7 +12,7 @@ class StateRedisStorage(StateStorageBase): def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot', redis_url=None, - connection_pool: redis.ConnectionPool=None, + connection_pool: 'redis.ConnectionPool'=None, separator: Optional[str]=":", ) -> None: From c29bf0e5358da3b21c519007fea64038745cdd2a Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 12 Jul 2024 21:51:34 +0500 Subject: [PATCH 07/24] Added pickle and renamed method --- telebot/storage/base_storage.py | 2 +- telebot/storage/memory_storage.py | 14 +-- telebot/storage/pickle_storage.py | 195 ++++++++++++++++-------------- telebot/storage/redis_storage.py | 14 +-- 4 files changed, 121 insertions(+), 104 deletions(-) diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py index 36545a77c..2358e93c9 100644 --- a/telebot/storage/base_storage.py +++ b/telebot/storage/base_storage.py @@ -48,7 +48,7 @@ def get_interactive_data(self, chat_id, user_id): def save(self, chat_id, user_id, data): raise NotImplementedError - def convert_params_to_key( + def _get_key( self, chat_id: int, user_id: int, diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index fbf9eebda..547445d0c 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -21,7 +21,7 @@ def set_state( if hasattr(state, "name"): state = state.name - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -37,7 +37,7 @@ def get_state( message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> Union[str, None]: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -50,7 +50,7 @@ def delete_state( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -66,7 +66,7 @@ def set_data( business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -81,7 +81,7 @@ def get_data( message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> dict: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -92,7 +92,7 @@ def reset_data( message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -114,7 +114,7 @@ def save( self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 68c9fbed5..64ed7ce85 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -1,34 +1,26 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext import os - import pickle - +import threading +from typing import Optional, Union +from telebot.storage.base_storage import StateStorageBase, StateContext class StatePickleStorage(StateStorageBase): - def __init__(self, file_path="./.state-save/states.pkl") -> None: - super().__init__() + def __init__(self, file_path: str="./.state-save/states.pkl", + prefix='telebot', separator: Optional[str]=":") -> None: self.file_path = file_path + self.prefix = prefix + self.separator = separator + self.lock = threading.Lock() + self.create_dir() - self.data = self.read() - def convert_old_to_new(self): - """ - Use this function to convert old storage to new storage. - This function is for people who was using pickle storage - that was in version <=4.3.1. - """ - # old looks like: - # {1: {'state': 'start', 'data': {'name': 'John'}} - # we should update old version pickle to new. - # new looks like: - # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} - new_data = {} - for key, value in self.data.items(): - # this returns us id and dict with data and state - new_data[key] = {key: value} # convert this to new - # pass it to global data - self.data = new_data - self.update_data() # update data in file + def _read_from_file(self) -> dict: + with open(self.file_path, 'rb') as f: + return pickle.load(f) + + def _write_to_file(self, data: dict) -> None: + with open(self.file_path, 'wb') as f: + pickle.dump(data, f) def create_dir(self): """ @@ -40,77 +32,102 @@ def create_dir(self): with open(self.file_path,'wb') as file: pickle.dump({}, file) - def read(self): - file = open(self.file_path, 'rb') - data = pickle.load(file) - file.close() - return data - - def update_data(self): - file = open(self.file_path, 'wb+') - pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) - file.close() - - def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): - state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - self.update_data() - return True + def set_state(self, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - self.update_data() - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} - self.update_data() + data[_key]["state"] = state + self._write_to_file(data) return True - - def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - self.update_data() - return True - - return False - - def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] + def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[str, None]: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + return data.get(_key, {}).get("state") - return None - - def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - self.update_data() + def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + if _key in data: + del data[_key] + self._write_to_file(data) return True - return False + return False - def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - self.update_data() + def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + if _key not in data: + data[_key] = {"state": None, "data": state_data} + else: + data[_key]["data"][key] = value + self._write_to_file(data) + return True + + def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> dict: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + return data.get(_key, {}).get("data", {}) + + def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + if _key in data: + data[_key]["data"] = {} + self._write_to_file(data) return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + return False - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[dict]: + return StateContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) - def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data - self.update_data() + def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + with self.lock: + data = self._read_from_file() + data[_key]["data"] = data + self._write_to_file(data) + return True + + def __str__(self) -> str: + with self.lock: + with open(self.file_path, 'rb') as f: + data = pickle.load(f) + return f"" diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 39c75a49b..b16ea3016 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -40,7 +40,7 @@ def set_state( if hasattr(state, "name"): state = state.name - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -56,7 +56,7 @@ def get_state( self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> Union[str, None]: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) state_bytes = self.redis.hget(_key, "state") @@ -66,7 +66,7 @@ def delete_state( self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) return self.redis.delete(_key) > 0 @@ -76,7 +76,7 @@ def set_data( business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -98,7 +98,7 @@ def get_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> dict: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) data = self.redis.hget(_key, "data") @@ -108,7 +108,7 @@ def reset_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) @@ -135,7 +135,7 @@ def save( self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> bool: - _key = self.convert_params_to_key( + _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) From 15bced9b67b27b2c54a44f1f060ac552d141d4ca Mon Sep 17 00:00:00 2001 From: _run Date: Wed, 17 Jul 2024 16:31:24 +0500 Subject: [PATCH 08/24] Fixed bugs, renamed StateContext for retrieve_data to StateDataContext to avoid conflicts, added support of statesv2 to async(only memory storage) --- telebot/__init__.py | 9 +- telebot/async_telebot.py | 137 +++++++++++++++--- telebot/asyncio_filters.py | 44 +++--- telebot/asyncio_handler_backends.py | 41 +----- telebot/asyncio_storage/__init__.py | 4 +- telebot/asyncio_storage/base_storage.py | 55 ++++++- telebot/asyncio_storage/memory_storage.py | 169 +++++++++++++++------- telebot/asyncio_storage/pickle_storage.py | 4 +- telebot/asyncio_storage/redis_storage.py | 4 +- telebot/states/__init__.py | 2 +- telebot/states/aio/__init__.py | 7 + telebot/states/aio/context.py | 138 ++++++++++++++++++ telebot/states/aio/middleware.py | 19 +++ telebot/storage/__init__.py | 4 +- telebot/storage/base_storage.py | 2 +- telebot/storage/memory_storage.py | 4 +- telebot/storage/pickle_storage.py | 4 +- telebot/storage/redis_storage.py | 4 +- 18 files changed, 485 insertions(+), 166 deletions(-) create mode 100644 telebot/states/aio/context.py create mode 100644 telebot/states/aio/middleware.py diff --git a/telebot/__init__.py b/telebot/__init__.py index 31370d3e3..08ae9fc6f 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -1180,6 +1180,9 @@ def polling(self, non_stop: Optional[bool]=False, skip_pending: Optional[bool]=F if restart_on_change: self._setup_change_detector(path_to_watch) + if not self._user: + self._user = self.get_me() + logger.info('Starting your bot with username: [@%s]', self.user.username) if self.threaded: @@ -6678,7 +6681,7 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option chat_id = user_id if bot_id is None: bot_id = self.bot_id - self.current_states.set_state( + return self.current_states.set_state( chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id) @@ -6710,7 +6713,7 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None, chat_id = user_id if bot_id is None: bot_id = self.bot_id - self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + return self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id) @@ -6731,7 +6734,7 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_conne chat_id = user_id if bot_id is None: bot_id = self.bot_id - self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + return self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id) diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index d8dd91e05..2d443cb5b 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -4,7 +4,7 @@ import logging import re import traceback -from typing import Any, Awaitable, Callable, List, Optional, Union +from typing import Any, Awaitable, Callable, List, Optional, Union, Dict import sys # this imports are used to avoid circular import error @@ -18,7 +18,7 @@ from inspect import signature, iscoroutinefunction -from telebot import util, types, asyncio_helper +from telebot import util, types, asyncio_helper, apihelper # have to use sync import asyncio from telebot import asyncio_filters @@ -117,6 +117,8 @@ class AsyncTeleBot: :param colorful_logs: Outputs colorful logs :type colorful_logs: :obj:`bool`, optional + :param token_check: Check token on start + :type token_check: :obj:`bool`, optional, defaults to True """ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[int]=None, @@ -126,7 +128,8 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ disable_notification: Optional[bool]=None, protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, - colorful_logs: Optional[bool]=False) -> None: + colorful_logs: Optional[bool]=False, + token_check: Optional[bool]=True) -> None: # update-related self.token = token @@ -183,6 +186,14 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ self.middlewares = [] self._user = None # set during polling + self.bot_id = None + + if token_check: + result = apihelper.get_me(token) + self._user = types.User.de_json(result) + self.bot_id = self._user.id + + @property def user(self): @@ -424,7 +435,8 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: # show warning logger.warning("Setting non_stop to False will stop polling on API and system exceptions.") - self._user = await self.get_me() + if not self._user: + self._user = await self.get_me() logger.info('Starting your bot with username: [@%s]', self.user.username) @@ -7831,7 +7843,10 @@ async def get_forum_topic_icon_stickers(self) -> List[types.Sticker]: """ return await asyncio_helper.get_forum_topic_icon_stickers(self.token) - async def set_state(self, user_id: int, state: Union[State, int, str], chat_id: Optional[int]=None): + + async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> None: """ Sets a new state of a user. @@ -7850,13 +7865,29 @@ async def set_state(self, user_id: int, state: Union[State, int, str], chat_id: :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: None """ - if not chat_id: + if chat_id is None: chat_id = user_id - await self.current_states.set_state(chat_id, user_id, state) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.set_state( + chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - async def reset_data(self, user_id: int, chat_id: Optional[int]=None): + + async def reset_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: """ Reset data for a user in chat. @@ -7866,13 +7897,27 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: None """ if chat_id is None: chat_id = user_id - await self.current_states.reset_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - async def delete_state(self, user_id: int, chat_id: Optional[int]=None): + async def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: """ Delete the current state of a user. @@ -7884,11 +7929,16 @@ async def delete_state(self, user_id: int, chat_id: Optional[int]=None): :return: None """ - if not chat_id: + if chat_id is None: chat_id = user_id - await self.current_states.delete_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - def retrieve_data(self, user_id: int, chat_id: Optional[int]=None): + def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]: """ Returns context manager with data for a user in chat. @@ -7898,14 +7948,30 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional + :param bot_id: Bot's identifier + :type bot_id: int, optional + + :param business_connection_id: Business identifier + :type business_connection_id: str, optional + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: int, optional + :return: Context manager with data for a user in chat :rtype: Optional[Any] """ - if not chat_id: + if chat_id is None: chat_id = user_id - return self.current_states.get_interactive_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id) + - async def get_state(self, user_id, chat_id: Optional[int]=None): + async def get_state(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. @@ -7916,14 +7982,31 @@ async def get_state(self, user_id, chat_id: Optional[int]=None): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: state of a user :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State` """ - if not chat_id: + if chat_id is None: chat_id = user_id - return await self.current_states.get_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): + async def add_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None, + **kwargs) -> None: """ Add data to states. @@ -7933,10 +8016,22 @@ async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :param kwargs: Data to add :return: None """ - if not chat_id: + if chat_id is None: chat_id = user_id + if bot_id is None: + bot_id = self.bot_id for key, value in kwargs.items(): - await self.current_states.set_data(chat_id, user_id, key, value) + await self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index da794b77b..389ce1643 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -3,6 +3,7 @@ from telebot.asyncio_handler_backends import State from telebot import types +from telebot.states import resolve_context class SimpleCustomFilter(ABC): @@ -397,18 +398,11 @@ async def check(self, message, text): :meta private: """ if text == '*': return True + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) - # needs to work with callbackquery - if isinstance(message, types.Message): - chat_id = message.chat.id - user_id = message.from_user.id - - if isinstance(message, types.CallbackQuery): - - chat_id = message.message.chat.id - user_id = message.from_user.id - message = message.message - + if chat_id is None: + chat_id = user_id # May change in future if isinstance(text, list): new_text = [] @@ -418,21 +412,19 @@ async def check(self, message, text): text = new_text elif isinstance(text, State): text = text.name - - if message.chat.type in ['group', 'supergroup']: - group_state = await self.bot.current_states.get_state(chat_id, user_id) - if group_state == text: - return True - elif type(text) is list and group_state in text: - return True - - - else: - user_state = await self.bot.current_states.get_state(chat_id, user_id) - if user_state == text: - return True - elif type(text) is list and user_state in text: - return True + + user_state = await self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True class IsDigitFilter(SimpleCustomFilter): diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 0861a9893..6c96cc2d7 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -1,6 +1,7 @@ """ File with all middleware classes, states. """ +from telebot.states import State, StatesGroup class BaseMiddleware: @@ -48,46 +49,6 @@ async def post_process(self, message, data, exception): raise NotImplementedError -class State: - """ - Class representing a state. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init__(self) -> None: - self.name = None - - def __str__(self) -> str: - return self.name - - -class StatesGroup: - """ - Class representing common states. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init_subclass__(cls) -> None: - state_list = [] - for name, value in cls.__dict__.items(): - if not name.startswith('__') and not callable(value) and isinstance(value, State): - # change value of that variable - value.name = ':'.join((cls.__name__, name)) - value.group = cls - state_list.append(value) - cls._state_list = state_list - - @classmethod - def state_list(self): - return self._state_list - - class SkipHandler: """ Class for skipping handlers. diff --git a/telebot/asyncio_storage/__init__.py b/telebot/asyncio_storage/__init__.py index 892f0af94..1f9d51650 100644 --- a/telebot/asyncio_storage/__init__.py +++ b/telebot/asyncio_storage/__init__.py @@ -1,13 +1,13 @@ from telebot.asyncio_storage.memory_storage import StateMemoryStorage from telebot.asyncio_storage.redis_storage import StateRedisStorage from telebot.asyncio_storage.pickle_storage import StatePickleStorage -from telebot.asyncio_storage.base_storage import StateContext,StateStorageBase +from telebot.asyncio_storage.base_storage import StateDataContext, StateStorageBase __all__ = [ - 'StateStorageBase', 'StateContext', + 'StateStorageBase', 'StateDataContext', 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' ] \ No newline at end of file diff --git a/telebot/asyncio_storage/base_storage.py b/telebot/asyncio_storage/base_storage.py index 38615c4c2..6d06e7bfc 100644 --- a/telebot/asyncio_storage/base_storage.py +++ b/telebot/asyncio_storage/base_storage.py @@ -1,6 +1,5 @@ import copy - class StateStorageBase: def __init__(self) -> None: pass @@ -42,27 +41,67 @@ async def reset_data(self, chat_id, user_id): async def get_state(self, chat_id, user_id): raise NotImplementedError - + + def get_interactive_data(self, chat_id, user_id): + """ + Should be sync, but should provide a context manager + with __aenter__ and __aexit__ methods. + """ + raise NotImplementedError + async def save(self, chat_id, user_id, data): raise NotImplementedError + + def _get_key( + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str=None, + message_thread_id: int=None, + bot_id: int=None + ) -> str: + """ + Convert parameters to a key. + """ + params = [prefix] + if bot_id: + params.append(str(bot_id)) + if business_connection_id: + params.append(business_connection_id) + if message_thread_id: + params.append(str(message_thread_id)) + params.append(str(chat_id)) + params.append(str(user_id)) + + return separator.join(params) + + + + -class StateContext: +class StateDataContext: """ Class for data. """ - - def __init__(self, obj, chat_id, user_id): + def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ): self.obj = obj self.data = None self.chat_id = chat_id self.user_id = user_id + self.bot_id = bot_id + self.business_connection_id = business_connection_id + self.message_thread_id = message_thread_id + - async def __aenter__(self): - self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id)) + data = await self.obj.get_data(chat_id=self.chat_id, user_id=self.user_id, business_connection_id=self.business_connection_id, + message_thread_id=self.message_thread_id, bot_id=self.bot_id) + self.data = copy.deepcopy(data) return self.data async def __aexit__(self, exc_type, exc_val, exc_tb): - return await self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file + return await self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id) \ No newline at end of file diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index 45c2ad914..e65ed74d4 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -1,66 +1,131 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext +from typing import Optional, Union class StateMemoryStorage(StateStorageBase): - def __init__(self) -> None: - self.data = {} - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - - async def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): + def __init__(self, + separator: Optional[str]=":", + prefix: Optional[str]="telebot" + ) -> None: + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + + self.data = {} # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + + async def set_state( + self, chat_id: int, user_id: int, state: str, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + + ) -> bool: + if hasattr(state, "name"): state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + if self.data.get(_key) is None: + self.data[_key] = {"state": state, "data": {}} + else: + self.data[_key]["state"] = state + return True - async def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - - return True + async def get_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> Union[str, None]: - return False + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + if self.data.get(_key) is None: + return None + + return self.data[_key]["state"] - async def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] + async def delete_state( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + if self.data.get(_key) is None: + return False + + del self.data[_key] + return True + + + async def set_data( + self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: + + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + + if self.data.get(_key) is None: + return False + self.data[_key]["data"][key] = value + return True - return None - async def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] + + async def get_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> dict: - return None + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) - async def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - return True - return False + return self.data.get(_key, {}).get("data", None) + + async def reset_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) - async def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = {} + return True + + def get_interactive_data( + self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> Optional[dict]: + return StateDataContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) + + async def save( + self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + ) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, + message_thread_id, bot_id + ) - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = data + return True + + def __str__(self) -> str: + return f"" + + - async def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data \ No newline at end of file diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index cf446d85b..fcffbb289 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,4 +1,4 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext import os import pickle @@ -103,7 +103,7 @@ async def set_data(self, chat_id, user_id, key, value): raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + return StateDataContext(self, chat_id, user_id) async def save(self, chat_id, user_id, data): self.data[chat_id][user_id]['data'] = data diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 84db253e5..86fde3716 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,4 +1,4 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext import json redis_installed = True @@ -167,7 +167,7 @@ def get_interactive_data(self, chat_id, user_id): Get Data in interactive way. You can use with() with this function. """ - return StateContext(self, chat_id, user_id) + return StateDataContext(self, chat_id, user_id) async def save(self, chat_id, user_id, data): response = await self.get_record(chat_id) diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py index b8efe3bbe..015f135e4 100644 --- a/telebot/states/__init__.py +++ b/telebot/states/__init__.py @@ -16,7 +16,7 @@ def __init__(self) -> None: self.name: str = None self.group: StatesGroup = None def __str__(self) -> str: - return f"<{self.group.__name__}:{self.name}>" + return f"<{self.name}>" class StatesGroup: diff --git a/telebot/states/aio/__init__.py b/telebot/states/aio/__init__.py index e69de29bb..60f9ea2b7 100644 --- a/telebot/states/aio/__init__.py +++ b/telebot/states/aio/__init__.py @@ -0,0 +1,7 @@ +from .context import StateContext +from .middleware import StateMiddleware + +__all__ = [ + 'StateContext', + 'StateMiddleware', +] \ No newline at end of file diff --git a/telebot/states/aio/context.py b/telebot/states/aio/context.py new file mode 100644 index 000000000..431eba93c --- /dev/null +++ b/telebot/states/aio/context.py @@ -0,0 +1,138 @@ +from telebot.states import State, StatesGroup +from telebot.types import CallbackQuery, Message +from telebot.async_telebot import AsyncTeleBot +from telebot.states import resolve_context + +from typing import Union + + + +class StateContext(): + """ + Class representing a state context. + + Passed through a middleware to provide easy way to set states. + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + async def start_ex(message: types.Message, state_context: StateContext): + await state_context.set(MyStates.name) + await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. + """ + + def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: + self.message: Union[Message, CallbackQuery] = message + self.bot: AsyncTeleBot = bot + self.bot_id = self.bot.bot_id + + async def set(self, state: Union[State, str]) -> None: + """ + Set state for current user. + + :param state: State object or state name. + :type state: Union[State, str] + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + async def start_ex(message: types.Message, state_context: StateContext): + await state_context.set(MyStates.name) + await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + if isinstance(state, State): + state = state.name + return await self.bot.set_state( + chat_id=chat_id, + user_id=user_id, + state=state, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + async def get(self) -> str: + """ + Get current state for current user. + + :return: Current state name. + :rtype: str + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return await self.bot.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + async def reset_data(self) -> None: + """ + Reset data for current user. + State will not be changed. + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return await self.bot.reset_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + async def delete(self) -> None: + """ + Deletes state and data for current user. + """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return await self.bot.delete_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def data(self) -> dict: + """ + Get data for current user. + + .. code-block:: python3 + + with state_context.data() as data: + print(data) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.retrieve_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + async def add_data(self, **kwargs) -> None: + """ + Add data for current user. + + :param kwargs: Data to add. + :type kwargs: dict + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return await self.bot.add_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + **kwargs + ) + \ No newline at end of file diff --git a/telebot/states/aio/middleware.py b/telebot/states/aio/middleware.py new file mode 100644 index 000000000..66283dc82 --- /dev/null +++ b/telebot/states/aio/middleware.py @@ -0,0 +1,19 @@ +from telebot.asyncio_handler_backends import BaseMiddleware +from telebot.async_telebot import AsyncTeleBot +from telebot.states.sync.context import StateContext +from telebot.util import update_types +from telebot import types + + +class StateMiddleware(BaseMiddleware): + + def __init__(self, bot: AsyncTeleBot) -> None: + self.update_sensitive = False + self.update_types = update_types + self.bot: AsyncTeleBot = bot + + async def pre_process(self, message, data): + data['state_context'] = StateContext(message, self.bot) + + async def post_process(self, message, data, exception): + pass diff --git a/telebot/storage/__init__.py b/telebot/storage/__init__.py index 59e2b058c..954c1b31d 100644 --- a/telebot/storage/__init__.py +++ b/telebot/storage/__init__.py @@ -1,13 +1,13 @@ from telebot.storage.memory_storage import StateMemoryStorage from telebot.storage.redis_storage import StateRedisStorage from telebot.storage.pickle_storage import StatePickleStorage -from telebot.storage.base_storage import StateContext,StateStorageBase +from telebot.storage.base_storage import StateDataContext,StateStorageBase __all__ = [ - 'StateStorageBase', 'StateContext', + 'StateStorageBase', 'StateDataContext', 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' ] \ No newline at end of file diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py index 2358e93c9..25a6c6b04 100644 --- a/telebot/storage/base_storage.py +++ b/telebot/storage/base_storage.py @@ -78,7 +78,7 @@ def _get_key( -class StateContext: +class StateDataContext: """ Class for data. """ diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 547445d0c..82142430a 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -1,4 +1,4 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext +from telebot.storage.base_storage import StateStorageBase, StateDataContext from typing import Optional, Union class StateMemoryStorage(StateStorageBase): @@ -105,7 +105,7 @@ def get_interactive_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None ) -> Optional[dict]: - return StateContext( + return StateDataContext( self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id, bot_id=bot_id ) diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 64ed7ce85..917491814 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -2,7 +2,7 @@ import pickle import threading from typing import Optional, Union -from telebot.storage.base_storage import StateStorageBase, StateContext +from telebot.storage.base_storage import StateStorageBase, StateDataContext class StatePickleStorage(StateStorageBase): def __init__(self, file_path: str="./.state-save/states.pkl", @@ -110,7 +110,7 @@ def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optiona def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[dict]: - return StateContext( + return StateDataContext( self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id, bot_id=bot_id ) diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index b16ea3016..0aa7c43ea 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -1,5 +1,5 @@ import json -from telebot.storage.base_storage import StateStorageBase, StateContext +from telebot.storage.base_storage import StateStorageBase, StateDataContext from typing import Optional, Union redis_installed = True @@ -126,7 +126,7 @@ def get_interactive_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, bot_id: Optional[int] = None ) -> Optional[dict]: - return StateContext( + return StateDataContext( self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id, bot_id=bot_id ) From ff6485d66570e497038f6a1cc9fa6ac10ee05f18 Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 19 Jul 2024 21:26:33 +0500 Subject: [PATCH 09/24] Async version fully supported, partially tested. --- telebot/asyncio_storage/pickle_storage.py | 206 +++++++++------- telebot/asyncio_storage/redis_storage.py | 274 +++++++++------------- telebot/storage/pickle_storage.py | 5 +- telebot/storage/redis_storage.py | 4 +- 4 files changed, 237 insertions(+), 252 deletions(-) diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index fcffbb289..0c7da7eb1 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,28 +1,43 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext -import os +try: + import aiofiles +except ImportError: + aiofiles_installed = False + +import os import pickle +import asyncio +from typing import Optional, Union, Callable, Any +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + +def with_lock(func: Callable) -> Callable: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + return wrapper class StatePickleStorage(StateStorageBase): - def __init__(self, file_path="./.state-save/states.pkl") -> None: + def __init__(self, file_path: str = "./.state-save/states.pkl", + prefix='telebot', separator: Optional[str] = ":") -> None: + + if not aiofiles_installed: + raise ImportError("Please install aiofiles using `pip install aiofiles`") + self.file_path = file_path + self.prefix = prefix + self.separator = separator + self.lock = asyncio.Lock() self.create_dir() - self.data = self.read() - - async def convert_old_to_new(self): - # old looks like: - # {1: {'state': 'start', 'data': {'name': 'John'}} - # we should update old version pickle to new. - # new looks like: - # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} - new_data = {} - for key, value in self.data.items(): - # this returns us id and dict with data and state - new_data[key] = {key: value} # convert this to new - # pass it to global data - self.data = new_data - self.update_data() # update data in file + + async def _read_from_file(self) -> dict: + async with aiofiles.open(self.file_path, 'rb') as f: + data = await f.read() + return pickle.loads(data) + + async def _write_to_file(self, data: dict) -> None: + async with aiofiles.open(self.file_path, 'wb') as f: + await f.write(pickle.dumps(data)) def create_dir(self): """ @@ -34,77 +49,100 @@ def create_dir(self): with open(self.file_path,'wb') as file: pickle.dump({}, file) - def read(self): - file = open(self.file_path, 'rb') - data = pickle.load(file) - file.close() - return data - - def update_data(self): - file = open(self.file_path, 'wb+') - pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) - file.close() - - async def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): - state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - self.update_data() - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - self.update_data() - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} - self.update_data() + + @with_lock + async def set_state(self, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + await self._write_to_file(data) return True - - async def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - self.update_data() - return True + @with_lock + async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + return data.get(_key, {}).get("state") + + @with_lock + async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key in data: + del data[_key] + await self._write_to_file(data) + return True return False - - async def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - async def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - async def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - self.update_data() - return True + @with_lock + async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + if _key not in data: + data[_key] = {"state": None, "data": state_data} + else: + data[_key]["data"][key] = value + await self._write_to_file(data) + return True + + @with_lock + async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + return data.get(_key, {}).get("data", {}) + + @with_lock + async def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + if _key in data: + data[_key]["data"] = {} + await self._write_to_file(data) + return True return False - async def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - self.update_data() - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + return StateDataContext( + self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id + ) - def get_interactive_data(self, chat_id, user_id): - return StateDataContext(self, chat_id, user_id) + @with_lock + async def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key( + chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + ) + data = await self._read_from_file() + data[_key]["data"] = data + await self._write_to_file(data) + return True - async def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data - self.update_data() \ No newline at end of file + def __str__(self) -> str: + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 86fde3716..a6b19780d 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,179 +1,131 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext -import json -redis_installed = True -is_actual_aioredis = False try: - import aioredis - is_actual_aioredis = True + import redis + from redis.asyncio import Redis, ConnectionPool except ImportError: - try: - from redis import asyncio as aioredis - except ImportError: - redis_installed = False + redis_installed = False + +import json +from typing import Optional, Union, Callable, Coroutine +import asyncio + +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + + +def async_with_lock(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + return wrapper +def async_with_pipeline(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.redis.pipeline() as pipe: + pipe.multi() + result = await func(self, pipe, *args, **kwargs) + await pipe.execute() + return result + return wrapper class StateRedisStorage(StateStorageBase): - """ - This class is for Redis storage. - This will work only for states. - To use it, just pass this class to: - TeleBot(storage=StateRedisStorage()) - """ - def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None): + def __init__(self, host='localhost', port=6379, db=0, password=None, + prefix='telebot', + redis_url=None, + connection_pool: 'ConnectionPool'=None, + separator: Optional[str] = ":", + ) -> None: + if not redis_installed: - raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"') + raise ImportError("Please install redis using `pip install redis`") + + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") - if is_actual_aioredis: - aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0])) - if aioredis_version < (2,): - raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0') if redis_url: - self.redis = aioredis.Redis.from_url(redis_url) + self.redis = redis.asyncio.from_url(redis_url) + elif connection_pool: + self.redis = Redis(connection_pool=connection_pool) else: - self.redis = aioredis.Redis(host=host, port=port, db=db, password=password) + self.redis = Redis(host=host, port=port, db=db, password=password) + + self.lock = asyncio.Lock() + + @async_with_lock + @async_with_pipeline + async def set_state(self, pipe, chat_id: int, user_id: int, state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + if hasattr(state, "name"): + state = state.name - self.prefix = prefix - #self.con = Redis(connection_pool=self.redis) -> use this when necessary - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - async def get_record(self, key): - """ - Function to get record from database. - It has nothing to do with states. - Made for backward compatibility - """ - result = await self.redis.get(self.prefix+str(key)) - if result: return json.loads(result) - return - - async def set_record(self, key, value): - """ - Function to set record to database. - It has nothing to do with states. - Made for backward compatibility - """ - - await self.redis.set(self.prefix+str(key), json.dumps(value)) + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + await pipe.hset(_key, "state", state) return True - async def delete_record(self, key): - """ - Function to delete record from database. - It has nothing to do with states. - Made for backward compatibility - """ - await self.redis.delete(self.prefix+str(key)) + async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + state_bytes = await self.redis.hget(_key, "state") + return state_bytes.decode('utf-8') if state_bytes else None + + async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + result = await self.redis.delete(_key) + return result > 0 + + @async_with_lock + @async_with_pipeline + async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + data = await pipe.hget(_key, "data") + data = await pipe.execute() + data = data[0] + if data is None: + await pipe.hset(_key, "data", json.dumps({key: value})) + else: + data = json.loads(data) + data[key] = value + await pipe.hset(_key, "data", json.dumps(data)) return True - async def set_state(self, chat_id, user_id, state): - """ - Set state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if hasattr(state, 'name'): - state = state.name - if response: - if user_id in response: - response[user_id]['state'] = state - else: - response[user_id] = {'state': state, 'data': {}} + async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + data = await self.redis.hget(_key, "data") + return json.loads(data) if data else {} + + @async_with_lock + @async_with_pipeline + async def reset_data(self, pipe, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + if await pipe.exists(_key): + await pipe.hset(_key, "data", "{}") else: - response = {user_id: {'state': state, 'data': {}}} - await self.set_record(chat_id, response) + return False + return True + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + return StateDataContext(self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, + message_thread_id=message_thread_id, bot_id=bot_id) + + @async_with_lock + @async_with_pipeline + async def save(self, pipe, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + if await pipe.exists(_key): + await pipe.hset(_key, "data", json.dumps(data)) + else: + return False return True - - async def delete_state(self, chat_id, user_id): - """ - Delete state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - del response[user_id] - if user_id == str(chat_id): - await self.delete_record(chat_id) - return True - else: await self.set_record(chat_id, response) - return True - return False - - async def get_value(self, chat_id, user_id, key): - """ - Get value for a data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - if key in response[user_id]['data']: - return response[user_id]['data'][key] - return None - - async def get_state(self, chat_id, user_id): - """ - Get state of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['state'] - - return None - - async def get_data(self, chat_id, user_id): - """ - Get data of particular user in a particular chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['data'] - return None - - async def reset_data(self, chat_id, user_id): - """ - Reset data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = {} - await self.set_record(chat_id, response) - return True - - async def set_data(self, chat_id, user_id, key, value): - """ - Set data without interactive data. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'][key] = value - await self.set_record(chat_id, response) - return True - return False - - def get_interactive_data(self, chat_id, user_id): - """ - Get Data in interactive way. - You can use with() with this function. - """ - return StateDataContext(self, chat_id, user_id) - - async def save(self, chat_id, user_id, data): - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = data - await self.set_record(chat_id, response) - return True + + def __str__(self) -> str: + # include some connection info + return f"StateRedisStorage({self.redis})" diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 917491814..32a9653c7 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -127,7 +127,4 @@ def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: O return True def __str__(self) -> str: - with self.lock: - with open(self.file_path, 'rb') as f: - data = pickle.load(f) - return f"" + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 0aa7c43ea..d41da05c8 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -151,6 +151,4 @@ def save_action(pipe): return True def __str__(self) -> str: - keys = self.redis.keys(f"{self.prefix}{self.separator}*") - data = {key.decode(): self.redis.hgetall(key) for key in keys} - return f"" + return f"StateRedisStorage({self.redis})" From af79db6ad3528d01be167bec02cc6799dfe67526 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 21 Jul 2024 20:15:57 +0500 Subject: [PATCH 10/24] rewrote pickle sync with lock decorators, few changes to storages to return empty dict on get_data & raise runtimeerror on key not existing --- telebot/asyncio_storage/memory_storage.py | 4 +- telebot/asyncio_storage/pickle_storage.py | 2 +- telebot/asyncio_storage/redis_storage.py | 2 +- telebot/storage/memory_storage.py | 4 +- telebot/storage/pickle_storage.py | 124 ++++++++++++---------- telebot/storage/redis_storage.py | 2 +- 6 files changed, 72 insertions(+), 66 deletions(-) diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index e65ed74d4..661cc35e9 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -71,7 +71,7 @@ async def set_data( ) if self.data.get(_key) is None: - return False + raise RuntimeError(f"MemoryStorage: key {_key} does not exist.") self.data[_key]["data"][key] = value return True @@ -85,7 +85,7 @@ async def get_data( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - return self.data.get(_key, {}).get("data", None) + return self.data.get(_key, {}).get("data", {}) async def reset_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index 0c7da7eb1..9a8c9eead 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -98,7 +98,7 @@ async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, state_data = data.get(_key, {}) state_data["data"][key] = value if _key not in data: - data[_key] = {"state": None, "data": state_data} + raise RuntimeError(f"StatePickleStorage: key {_key} does not exist.") else: data[_key]["data"][key] = value await self._write_to_file(data) diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index a6b19780d..b07bd4159 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -86,7 +86,7 @@ async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Unio data = await pipe.execute() data = data[0] if data is None: - await pipe.hset(_key, "data", json.dumps({key: value})) + raise RuntimeError(f"StateRedisStorage: key {_key} does not exist.") else: data = json.loads(data) data[key] = value diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 82142430a..11acdc119 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -71,7 +71,7 @@ def set_data( ) if self.data.get(_key) is None: - return False + raise RuntimeError(f"StateMemoryStorage: key {_key} does not exist.") self.data[_key]["data"][key] = value return True @@ -85,7 +85,7 @@ def get_data( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - return self.data.get(_key, {}).get("data", None) + return self.data.get(_key, {}).get("data", {}) def reset_data( self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 32a9653c7..b449c17f4 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -1,12 +1,18 @@ import os import pickle import threading -from typing import Optional, Union +from typing import Optional, Union, Callable from telebot.storage.base_storage import StateStorageBase, StateDataContext +def with_lock(func: Callable) -> Callable: + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper + class StatePickleStorage(StateStorageBase): - def __init__(self, file_path: str="./.state-save/states.pkl", - prefix='telebot', separator: Optional[str]=":") -> None: + def __init__(self, file_path: str = "./.state-save/states.pkl", + prefix='telebot', separator: Optional[str] = ":") -> None: self.file_path = file_path self.prefix = prefix self.separator = separator @@ -32,98 +38,98 @@ def create_dir(self): with open(self.file_path,'wb') as file: pickle.dump({}, file) + @with_lock def set_state(self, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key not in data: - data[_key] = {"state": state, "data": {}} - else: - data[_key]["state"] = state - self._write_to_file(data) + data = self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + self._write_to_file(data) return True - def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[str, None]: + @with_lock + def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - return data.get(_key, {}).get("state") + data = self._read_from_file() + return data.get(_key, {}).get("state") - def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key in data: - del data[_key] - self._write_to_file(data) - return True - return False + data = self._read_from_file() + if _key in data: + del data[_key] + self._write_to_file(data) + return True + return False + @with_lock def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: + business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - state_data = data.get(_key, {}) - state_data["data"][key] = value - if _key not in data: - data[_key] = {"state": None, "data": state_data} - else: - data[_key]["data"][key] = value - self._write_to_file(data) + data = self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + + if _key not in data: + raise RuntimeError(f"PickleStorage: key {_key} does not exist.") + + self._write_to_file(data) return True - def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> dict: + @with_lock + def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - return data.get(_key, {}).get("data", {}) + data = self._read_from_file() + return data.get(_key, {}).get("data", {}) - def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - if _key in data: - data[_key]["data"] = {} - self._write_to_file(data) - return True - return False + data = self._read_from_file() + if _key in data: + data[_key]["data"] = {} + self._write_to_file(data) + return True + return False - def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[dict]: + def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: return StateDataContext( self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, message_thread_id=message_thread_id, bot_id=bot_id ) - def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: + @with_lock + def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: _key = self._get_key( chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id ) - with self.lock: - data = self._read_from_file() - data[_key]["data"] = data - self._write_to_file(data) + data = self._read_from_file() + data[_key]["data"] = data + self._write_to_file(data) return True def __str__(self) -> str: diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index d41da05c8..f21d50fe1 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -85,7 +85,7 @@ def set_data_action(pipe): data = pipe.hget(_key, "data") data = data.execute()[0] if data is None: - pipe.hset(_key, "data", json.dumps({key: value})) + raise RuntimeError(f"RedisStorage: key {_key} does not exist.") else: data = json.loads(data) data[key] = value From 1adca1375b6901b0963fd020f697825032b551d5 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 21 Jul 2024 22:27:08 +0500 Subject: [PATCH 11/24] Improved docstrings, fixed bugs, allow accessing statecontext via state name --- telebot/__init__.py | 52 ++++++++++++++++------- telebot/async_telebot.py | 43 +++++++++++++------ telebot/asyncio_storage/memory_storage.py | 17 ++++++++ telebot/asyncio_storage/pickle_storage.py | 24 +++++++++++ telebot/asyncio_storage/redis_storage.py | 34 +++++++++++++++ telebot/states/aio/context.py | 23 ++++++---- telebot/states/aio/middleware.py | 4 +- telebot/states/sync/context.py | 25 ++++++----- telebot/states/sync/middleware.py | 4 +- telebot/storage/memory_storage.py | 16 +++++++ telebot/storage/pickle_storage.py | 24 +++++++++++ telebot/storage/redis_storage.py | 33 ++++++++++++++ 12 files changed, 248 insertions(+), 51 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 08ae9fc6f..07ae33181 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -6645,9 +6645,9 @@ def setup_middleware(self, middleware: BaseMiddleware): self.middlewares.append(middleware) - def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None, + def set_state(self, user_id: int, state: Union[str, State], chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> None: + bot_id: Optional[int]=None) -> bool: """ Sets a new state of a user. @@ -6657,16 +6657,24 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. + .. versionchanged:: 4.22.0 + + Added additional parameters to support topics, business connections, and message threads. + + .. seealso:: + + For more details, visit the `custom_states.py example `_. + :param user_id: User's identifier :type user_id: :obj:`int` - :param state: new state. can be string, integer, or :class:`telebot.types.State` + :param state: new state. can be string, or :class:`telebot.types.State` :type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -6675,7 +6683,8 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option :param message_thread_id: Identifier of the message thread :type message_thread_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id @@ -6688,9 +6697,9 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option def reset_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ - Reset data for a user in chat. + Reset data for a user in chat: sets the 'data' fieldi to an empty dictionary. :param user_id: User's identifier :type user_id: :obj:`int` @@ -6698,7 +6707,7 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None, :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -6707,7 +6716,8 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None, :param message_thread_id: Identifier of the message thread :type message_thread_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id @@ -6718,9 +6728,13 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None, def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ - Delete the current state of a user. + Fully deletes the storage record of a user in chat. + + .. warning:: + + This does NOT set state to None, but deletes the object from storage. :param user_id: User's identifier :type user_id: :obj:`int` @@ -6728,7 +6742,8 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_conne :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id @@ -6749,7 +6764,7 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_conn :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: int, optional :param business_connection_id: Business identifier @@ -6772,18 +6787,23 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_conn def get_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. + .. warning:: + + Even if you are using :class:`telebot.types.State`, this method will return a string. + When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name + :param user_id: User's identifier :type user_id: :obj:`int` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -6817,7 +6837,7 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 2d443cb5b..508aa6cb4 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -7846,7 +7846,7 @@ async def get_forum_topic_icon_stickers(self) -> List[types.Sticker]: async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> None: + bot_id: Optional[int]=None) -> bool: """ Sets a new state of a user. @@ -7856,16 +7856,24 @@ async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. + .. versionchanged:: 4.22.0 + + Added additional parameters to support topics, business connections, and message threads. + + .. seealso:: + + For more details, visit the `custom_states.py example `_. + :param user_id: User's identifier :type user_id: :obj:`int` - :param state: new state. can be string, integer, or :class:`telebot.types.State` + :param state: new state. can be string, or :class:`telebot.types.State` :type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -7874,7 +7882,8 @@ async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: :param message_thread_id: Identifier of the message thread :type message_thread_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id @@ -7887,7 +7896,7 @@ async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: async def reset_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ Reset data for a user in chat. @@ -7897,7 +7906,7 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None, :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -7906,7 +7915,8 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None, :param message_thread_id: Identifier of the message thread :type message_thread_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id @@ -7917,7 +7927,7 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None, async def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ Delete the current state of a user. @@ -7948,7 +7958,7 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_conn :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: int, optional :param business_connection_id: Business identifier @@ -7958,7 +7968,7 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_conn :type message_thread_id: int, optional :return: Context manager with data for a user in chat - :rtype: Optional[Any] + :rtype: :obj:`dict` """ if chat_id is None: chat_id = user_id @@ -7971,18 +7981,23 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_conn async def get_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]: + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. + .. warning:: + + Even if you are using :class:`telebot.types.State`, this method will return a string. + When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name + :param user_id: User's identifier :type user_id: :obj:`int` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier @@ -7992,7 +8007,7 @@ async def get_state(self, user_id: int, chat_id: Optional[int]=None, :type message_thread_id: :obj:`int` :return: state of a user - :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State` + :rtype: :obj:`str` """ if chat_id is None: chat_id = user_id @@ -8016,7 +8031,7 @@ async def add_data(self, user_id: int, chat_id: Optional[int]=None, :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :param bot_id: Bot's identifier + :param bot_id: Bot's identifier, defaults to current bot id :type bot_id: :obj:`int` :param business_connection_id: Business identifier diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index 661cc35e9..87449cace 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -2,6 +2,23 @@ from typing import Optional, Union class StateMemoryStorage(StateStorageBase): + """ + Memory storage for states. + + Stores states in memory as a dictionary. + + .. code-block:: python3 + + storage = StateMemoryStorage() + bot = AsyncTeleBot(token, storage=storage) + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + """ + def __init__(self, separator: Optional[str]=":", prefix: Optional[str]="telebot" diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index 9a8c9eead..0506d4703 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,4 +1,5 @@ +aiofiles_installed = True try: import aiofiles except ImportError: @@ -18,6 +19,29 @@ async def wrapper(self, *args, **kwargs): return wrapper class StatePickleStorage(StateStorageBase): + """ + State storage based on pickle file. + + .. warning:: + + This storage is not recommended for production use. + Data may be corrupted. If you face a case where states do not work as expected, + try to use another storage. + + .. code-block:: python3 + + storage = StatePickleStorage() + bot = AsyncTeleBot(token, storage=storage) + + :param file_path: Path to file where states will be stored. + :type file_path: str + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ def __init__(self, file_path: str = "./.state-save/states.pkl", prefix='telebot', separator: Optional[str] = ":") -> None: diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index b07bd4159..dea214daa 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,4 +1,5 @@ +redis_installed = True try: import redis from redis.asyncio import Redis, ConnectionPool @@ -28,6 +29,39 @@ async def wrapper(self, *args, **kwargs): return wrapper class StateRedisStorage(StateStorageBase): + """ + State storage based on Redis. + + .. code-block:: python3 + + storage = StateRedisStorage(...) + bot = AsyncTeleBot(token, storage=storage) + + :param host: Redis host, default is "localhost". + :type host: str + + :param port: Redis port, default is 6379. + :type port: int + + :param db: Redis database, default is 0. + :type db: int + + :param password: Redis password, default is None. + :type password: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param redis_url: Redis URL, default is None. + :type redis_url: Optional[str] + + :param connection_pool: Redis connection pool, default is None. + :type connection_pool: Optional[ConnectionPool] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + """ def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot', redis_url=None, diff --git a/telebot/states/aio/context.py b/telebot/states/aio/context.py index 431eba93c..956ca6837 100644 --- a/telebot/states/aio/context.py +++ b/telebot/states/aio/context.py @@ -27,7 +27,7 @@ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: self.bot: AsyncTeleBot = bot self.bot_id = self.bot.bot_id - async def set(self, state: Union[State, str]) -> None: + async def set(self, state: Union[State, str]) -> bool: """ Set state for current user. @@ -71,14 +71,16 @@ async def get(self) -> str: message_thread_id=message_thread_id ) - async def reset_data(self) -> None: - """ - Reset data for current user. - State will not be changed. + async def delete(self) -> bool: """ + Deletes state and data for current user. + .. warning:: + + This method deletes state and associated data for current user. + """ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) - return await self.bot.reset_data( + return await self.bot.delete_state( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, @@ -86,12 +88,14 @@ async def reset_data(self) -> None: message_thread_id=message_thread_id ) - async def delete(self) -> None: + async def reset_data(self) -> bool: """ - Deletes state and data for current user. + Reset data for current user. + State will not be changed. """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) - return await self.bot.delete_state( + return await self.bot.reset_data( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, @@ -107,6 +111,7 @@ def data(self) -> dict: with state_context.data() as data: print(data) + data['name'] = 'John' """ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) diff --git a/telebot/states/aio/middleware.py b/telebot/states/aio/middleware.py index 66283dc82..546b29067 100644 --- a/telebot/states/aio/middleware.py +++ b/telebot/states/aio/middleware.py @@ -13,7 +13,9 @@ def __init__(self, bot: AsyncTeleBot) -> None: self.bot: AsyncTeleBot = bot async def pre_process(self, message, data): - data['state_context'] = StateContext(message, self.bot) + state_context = StateContext(message, self.bot) + data['state_context'] = state_context + data['state'] = state_context # 2 ways to access state context async def post_process(self, message, data, exception): pass diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py index c8009007a..c0611f6bb 100644 --- a/telebot/states/sync/context.py +++ b/telebot/states/sync/context.py @@ -27,7 +27,7 @@ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: self.bot: TeleBot = bot self.bot_id = self.bot.bot_id - def set(self, state: Union[State, str]) -> None: + def set(self, state: Union[State, str]) -> bool: """ Set state for current user. @@ -71,27 +71,31 @@ def get(self) -> str: message_thread_id=message_thread_id ) - def reset_data(self) -> None: - """ - Reset data for current user. - State will not be changed. + def delete(self) -> bool: """ + Deletes state and data for current user. + .. warning:: + + This method deletes state and associated data for current user. + """ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) - return self.bot.reset_data( + return self.bot.delete_state( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, bot_id=bot_id, message_thread_id=message_thread_id ) - - def delete(self) -> None: + + def reset_data(self) -> bool: """ - Deletes state and data for current user. + Reset data for current user. + State will not be changed. """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) - return self.bot.delete_state( + return self.bot.reset_data( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, @@ -107,6 +111,7 @@ def data(self) -> dict: with state_context.data() as data: print(data) + data['name'] = 'John' """ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py index b85f795da..4a252158b 100644 --- a/telebot/states/sync/middleware.py +++ b/telebot/states/sync/middleware.py @@ -13,7 +13,9 @@ def __init__(self, bot: TeleBot) -> None: self.bot: TeleBot = bot def pre_process(self, message, data): - data['state_context'] = StateContext(message, self.bot) + state_context = StateContext(message, self.bot) + data['state_context'] = state_context + data['state'] = state_context # 2 ways to access state context def post_process(self, message, data, exception): pass diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 11acdc119..dc88bcebc 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -2,6 +2,22 @@ from typing import Optional, Union class StateMemoryStorage(StateStorageBase): + """ + Memory storage for states. + + Stores states in memory as a dictionary. + + .. code-block:: python3 + + storage = StateMemoryStorage() + bot = TeleBot(token, storage=storage) + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + """ def __init__(self, separator: Optional[str]=":", prefix: Optional[str]="telebot" diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index b449c17f4..fc8dc3289 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -11,6 +11,30 @@ def wrapper(self, *args, **kwargs): return wrapper class StatePickleStorage(StateStorageBase): + """ + State storage based on pickle file. + + .. warning:: + + This storage is not recommended for production use. + Data may be corrupted. If you face a case where states do not work as expected, + try to use another storage. + + .. code-block:: python3 + + storage = StatePickleStorage() + bot = TeleBot(token, storage=storage) + + :param file_path: Path to file where states will be stored. + :type file_path: str + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ + def __init__(self, file_path: str = "./.state-save/states.pkl", prefix='telebot', separator: Optional[str] = ":") -> None: self.file_path = file_path diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index f21d50fe1..146ec2f63 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -9,6 +9,39 @@ redis_installed = False class StateRedisStorage(StateStorageBase): + """ + State storage based on Redis. + + .. code-block:: python3 + + storage = StateRedisStorage(...) + bot = TeleBot(token, storage=storage) + + :param host: Redis host, default is "localhost". + :type host: str + + :param port: Redis port, default is 6379. + :type port: int + + :param db: Redis database, default is 0. + :type db: int + + :param password: Redis password, default is None. + :type password: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param redis_url: Redis URL, default is None. + :type redis_url: Optional[str] + + :param connection_pool: Redis connection pool, default is None. + :type connection_pool: Optional[ConnectionPool] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + """ def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot', redis_url=None, From adae7a9ba4b137a3ed46db48a8a52cf0b089b692 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 21 Jul 2024 22:33:08 +0500 Subject: [PATCH 12/24] Removed aioredis from optional dependencies, breaking change made for states="*" --- pyproject.toml | 1 - telebot/asyncio_filters.py | 6 +++++- telebot/custom_filters.py | 6 +++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0095c3f21..d5289361b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ Issues = "https://github.com/eternnoir/pyTelegramBotAPI/issues" json = ["ujson"] PIL = ["Pillow"] redis = ["redis>=3.4.1"] -aioredis = ["aioredis"] aiohttp = ["aiohttp"] fastapi = ["fastapi"] uvicorn = ["uvicorn"] diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index 389ce1643..f8212f934 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -397,7 +397,6 @@ async def check(self, message, text): """ :meta private: """ - if text == '*': return True chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) @@ -421,10 +420,15 @@ async def check(self, message, text): message_thread_id=message_thread_id ) + # CHANGED BEHAVIOUR + if text == "*" and user_state is not None: + return True + if user_state == text: return True elif type(text) is list and user_state in text: return True + return False class IsDigitFilter(SimpleCustomFilter): diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index e53c669a1..2e91e55ea 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -403,7 +403,6 @@ def check(self, message, text): """ :meta private: """ - if text == '*': return True chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) @@ -427,10 +426,15 @@ def check(self, message, text): message_thread_id=message_thread_id ) + # CHANGED BEHAVIOUR + if text == "*" and user_state is not None: + return True + if user_state == text: return True elif type(text) is list and user_state in text: return True + return False class IsDigitFilter(SimpleCustomFilter): """ From d33db8532f1505d240dc53c79c9eeb5995a7fb95 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 21 Jul 2024 22:55:57 +0500 Subject: [PATCH 13/24] Updated examples --- .../asynchronous_telebot/custom_states.py | 169 +++++++++------- examples/custom_states.py | 186 +++++++++--------- 2 files changed, 191 insertions(+), 164 deletions(-) diff --git a/examples/asynchronous_telebot/custom_states.py b/examples/asynchronous_telebot/custom_states.py index aefad9809..eefc6ebbb 100644 --- a/examples/asynchronous_telebot/custom_states.py +++ b/examples/asynchronous_telebot/custom_states.py @@ -1,91 +1,112 @@ -from telebot import asyncio_filters -from telebot.async_telebot import AsyncTeleBot - -# list of storages, you can use any storage +from telebot import async_telebot +from telebot import asyncio_filters, types +from telebot.states import State, StatesGroup +from telebot.states.aio.context import StateContext from telebot.asyncio_storage import StateMemoryStorage -# new feature for states. -from telebot.asyncio_handler_backends import State, StatesGroup - -# default state storage is statememorystorage -bot = AsyncTeleBot('TOKEN', state_storage=StateMemoryStorage()) - +# Initialize the bot +state_storage = StateMemoryStorage() # don't use this in production; switch to redis +bot = async_telebot.AsyncTeleBot("TOKEN", state_storage=state_storage) -# Just create different statesgroup +# Define states class MyStates(StatesGroup): - name = State() # statesgroup should contain states - surname = State() + name = State() age = State() + color = State() + hobby = State() - - -# set_state -> sets a new state -# delete_state -> delets state if exists -# get_state -> returns state if exists - - +# Start command handler @bot.message_handler(commands=['start']) -async def start_ex(message): - """ - Start command. Here we are starting state - """ - await bot.set_state(message.from_user.id, MyStates.name, message.chat.id) - await bot.send_message(message.chat.id, 'Hi, write me a name') - - - -@bot.message_handler(state="*", commands='cancel') -async def any_state(message): - """ - Cancel state - """ - await bot.send_message(message.chat.id, "Your state was cancelled.") - await bot.delete_state(message.from_user.id, message.chat.id) +async def start_ex(message: types.Message, state: StateContext): + await state.set(MyStates.name) + await bot.send_message(message.chat.id, 'Hello! What is your first name?', reply_to_message_id=message.message_id) + +# Cancel command handler +@bot.message_handler(state="*", commands=['cancel']) +async def any_state(message: types.Message, state: StateContext): + await state.delete() + await bot.send_message(message.chat.id, 'Your information has been cleared. Type /start to begin again.', reply_to_message_id=message.message_id) +# Handler for name input @bot.message_handler(state=MyStates.name) -async def name_get(message): - """ - State 1. Will process when user's state is MyStates.name. - """ - await bot.send_message(message.chat.id, f'Now write me a surname') - await bot.set_state(message.from_user.id, MyStates.surname, message.chat.id) - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['name'] = message.text - - -@bot.message_handler(state=MyStates.surname) -async def ask_age(message): - """ - State 2. Will process when user's state is MyStates.surname. - """ - await bot.send_message(message.chat.id, "What is your age?") - await bot.set_state(message.from_user.id, MyStates.age, message.chat.id) - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['surname'] = message.text - -# result +async def name_get(message: types.Message, state: StateContext): + await state.set(MyStates.age) + await bot.send_message(message.chat.id, "How old are you?", reply_to_message_id=message.message_id) + await state.add_data(name=message.text) + +# Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) -async def ready_for_answer(message): - """ - State 3. Will process when user's state is MyStates.age. - """ - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - await bot.send_message(message.chat.id, "Ready, take a look:\nName: {name}\nSurname: {surname}\nAge: {age}".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html") - await bot.delete_state(message.from_user.id, message.chat.id) - -#incorrect number +async def ask_color(message: types.Message, state: StateContext): + await state.set(MyStates.color) + await state.add_data(age=message.text) + + # Define reply keyboard for color selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"] + buttons = [types.KeyboardButton(color) for color in colors] + keyboard.add(*buttons) + + await bot.send_message(message.chat.id, "What is your favorite color? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + +# Handler for color input +@bot.message_handler(state=MyStates.color) +async def ask_hobby(message: types.Message, state: StateContext): + await state.set(MyStates.hobby) + await state.add_data(color=message.text) + + # Define reply keyboard for hobby selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + hobbies = ["Reading", "Traveling", "Gaming", "Cooking"] + buttons = [types.KeyboardButton(hobby) for hobby in hobbies] + keyboard.add(*buttons) + + await bot.send_message(message.chat.id, "What is one of your hobbies? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + +# Handler for hobby input; use filters to ease validation +@bot.message_handler(state=MyStates.hobby, text=['Reading', 'Traveling', 'Gaming', 'Cooking']) +async def finish(message: types.Message, state: StateContext): + async with state.data() as data: + name = data.get('name') + age = data.get('age') + color = data.get('color') + hobby = message.text # Get the hobby from the message text + + # Provide a fun fact based on color + color_facts = { + "Red": "Red is often associated with excitement and passion.", + "Green": "Green is the color of nature and tranquility.", + "Blue": "Blue is known for its calming and serene effects.", + "Yellow": "Yellow is a cheerful color often associated with happiness.", + "Purple": "Purple signifies royalty and luxury.", + "Orange": "Orange is a vibrant color that stimulates enthusiasm.", + "Other": "Colors have various meanings depending on context." + } + color_fact = color_facts.get(color, "Colors have diverse meanings, and yours is unique!") + + msg = (f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}") + + await bot.send_message(message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id) + await state.delete() + +# Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) -async def age_incorrect(message): - """ - Will process for wrong input when state is MyState.age - """ - await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') - -# register filters +async def age_incorrect(message: types.Message): + await bot.send_message(message.chat.id, 'Please enter a valid number for age.', reply_to_message_id=message.message_id) +# Add custom filters bot.add_custom_filter(asyncio_filters.StateFilter(bot)) bot.add_custom_filter(asyncio_filters.IsDigitFilter()) +bot.add_custom_filter(asyncio_filters.TextMatchFilter()) +# necessary for state parameter in handlers. +from telebot.states.aio.middleware import StateMiddleware +bot.setup_middleware(StateMiddleware(bot)) +# Start polling import asyncio -asyncio.run(bot.polling()) \ No newline at end of file +asyncio.run(bot.polling()) diff --git a/examples/custom_states.py b/examples/custom_states.py index a02761286..a488664b2 100644 --- a/examples/custom_states.py +++ b/examples/custom_states.py @@ -1,106 +1,112 @@ -import telebot # telebot - -from telebot import custom_filters -from telebot.handler_backends import State, StatesGroup #States - -# States storage +import telebot +from telebot import custom_filters, types +from telebot.states import State, StatesGroup +from telebot.states.sync.context import StateContext from telebot.storage import StateMemoryStorage +# Initialize the bot +state_storage = StateMemoryStorage() # don't use this in production; switch to redis +bot = telebot.TeleBot("TOKEN", state_storage=state_storage, + use_class_middlewares=True) -# Starting from version 4.4.0+, we support storages. -# StateRedisStorage -> Redis-based storage. -# StatePickleStorage -> Pickle-based storage. -# For redis, you will need to install redis. -# Pass host, db, password, or anything else, -# if you need to change config for redis. -# Pickle requires path. Default path is in folder .state-saves. -# If you were using older version of pytba for pickle, -# you need to migrate from old pickle to new by using -# StatePickleStorage().convert_old_to_new() - - - -# Now, you can pass storage to bot. -state_storage = StateMemoryStorage() # you can init here another storage - -bot = telebot.TeleBot("TOKEN", -state_storage=state_storage) - - -# States group. +# Define states class MyStates(StatesGroup): - # Just name variables differently - name = State() # creating instances of State class is enough from now - surname = State() + name = State() age = State() + color = State() + hobby = State() - - - +# Start command handler @bot.message_handler(commands=['start']) -def start_ex(message): - """ - Start command. Here we are starting state - """ - bot.set_state(message.from_user.id, MyStates.name, message.chat.id) - bot.send_message(message.chat.id, 'Hi, write me a name') - - -# Any state +def start_ex(message: types.Message, state: StateContext): + state.set(MyStates.name) + bot.send_message(message.chat.id, 'Hello! What is your first name?', reply_to_message_id=message.message_id) + +# Cancel command handler @bot.message_handler(state="*", commands=['cancel']) -def any_state(message): - """ - Cancel state - """ - bot.send_message(message.chat.id, "Your state was cancelled.") - bot.delete_state(message.from_user.id, message.chat.id) +def any_state(message: types.Message, state: StateContext): + state.delete() + bot.send_message(message.chat.id, 'Your information has been cleared. Type /start to begin again.', reply_to_message_id=message.message_id) +# Handler for name input @bot.message_handler(state=MyStates.name) -def name_get(message): - """ - State 1. Will process when user's state is MyStates.name. - """ - bot.send_message(message.chat.id, 'Now write me a surname') - bot.set_state(message.from_user.id, MyStates.surname, message.chat.id) - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['name'] = message.text - - -@bot.message_handler(state=MyStates.surname) -def ask_age(message): - """ - State 2. Will process when user's state is MyStates.surname. - """ - bot.send_message(message.chat.id, "What is your age?") - bot.set_state(message.from_user.id, MyStates.age, message.chat.id) - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['surname'] = message.text - -# result +def name_get(message: types.Message, state: StateContext): + state.set(MyStates.age) + bot.send_message(message.chat.id, "How old are you?", reply_to_message_id=message.message_id) + state.add_data(name=message.text) + +# Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) -def ready_for_answer(message): - """ - State 3. Will process when user's state is MyStates.age. - """ - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - msg = ("Ready, take a look:\n" - f"Name: {data['name']}\n" - f"Surname: {data['surname']}\n" - f"Age: {message.text}") - bot.send_message(message.chat.id, msg, parse_mode="html") - bot.delete_state(message.from_user.id, message.chat.id) - -#incorrect number +def ask_color(message: types.Message, state: StateContext): + state.set(MyStates.color) + state.add_data(age=message.text) + + # Define reply keyboard for color selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"] + buttons = [types.KeyboardButton(color) for color in colors] + keyboard.add(*buttons) + + bot.send_message(message.chat.id, "What is your favorite color? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + +# Handler for color input +@bot.message_handler(state=MyStates.color) +def ask_hobby(message: types.Message, state: StateContext): + state.set(MyStates.hobby) + state.add_data(color=message.text) + + # Define reply keyboard for hobby selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + hobbies = ["Reading", "Traveling", "Gaming", "Cooking"] + buttons = [types.KeyboardButton(hobby) for hobby in hobbies] + keyboard.add(*buttons) + + bot.send_message(message.chat.id, "What is one of your hobbies? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + +# Handler for hobby input +@bot.message_handler(state=MyStates.hobby, text=['Reading', 'Traveling', 'Gaming', 'Cooking']) +def finish(message: types.Message, state: StateContext): + with state.data() as data: + name = data.get('name') + age = data.get('age') + color = data.get('color') + hobby = message.text # Get the hobby from the message text + + # Provide a fun fact based on color + color_facts = { + "Red": "Red is often associated with excitement and passion.", + "Green": "Green is the color of nature and tranquility.", + "Blue": "Blue is known for its calming and serene effects.", + "Yellow": "Yellow is a cheerful color often associated with happiness.", + "Purple": "Purple signifies royalty and luxury.", + "Orange": "Orange is a vibrant color that stimulates enthusiasm.", + "Other": "Colors have various meanings depending on context." + } + color_fact = color_facts.get(color, "Colors have diverse meanings, and yours is unique!") + + msg = (f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}") + + bot.send_message(message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id) + state.delete() + +# Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) -def age_incorrect(message): - """ - Wrong response for MyStates.age - """ - bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') - -# register filters +def age_incorrect(message: types.Message): + bot.send_message(message.chat.id, 'Please enter a valid number for age.', reply_to_message_id=message.message_id) +# Add custom filters bot.add_custom_filter(custom_filters.StateFilter(bot)) bot.add_custom_filter(custom_filters.IsDigitFilter()) +bot.add_custom_filter(custom_filters.TextMatchFilter()) + +# necessary for state parameter in handlers. +from telebot.states.sync.middleware import StateMiddleware +bot.setup_middleware(StateMiddleware(bot)) -bot.infinity_polling(skip_pending=True) +# Start polling +bot.infinity_polling() From 536ffa21f419ce91323ebe0723dc8bc26f70d145 Mon Sep 17 00:00:00 2001 From: _run Date: Thu, 25 Jul 2024 18:09:19 +0500 Subject: [PATCH 14/24] Added migrate_format to migrate from old format of states to new. Only redis as it should be used for production. --- telebot/asyncio_storage/redis_storage.py | 41 ++++++++++++++++++++++++ telebot/storage/redis_storage.py | 41 ++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index dea214daa..5c0c1bb0a 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -159,6 +159,47 @@ async def save(self, pipe, chat_id: int, user_id: int, data: dict, business_conn else: return False return True + + def migrate_format(self, bot_id: int, + prefix: Optional[str]="telebot_"): + """ + Migrate from old to new format of keys. + Run this function once to migrate all redis existing keys to new format. + + Starting from version 4.22.0, the format of keys has been changed: + :value + - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + - New format: + {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}} + + This function will help you to migrate from the old format to the new one in order to avoid data loss. + + :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response. + :type bot_id: int + + :param prefix: Prefix for keys, default is "telebot_"(old default value) + :type prefix: Optional[str] + """ + keys = self.redis.keys(f"{prefix}*") + + for key in keys: + old_key = key.decode('utf-8') + # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + value = self.redis.get(old_key) + value = json.loads(value) + + chat_id = old_key[len(prefix):] + user_id = list(value.keys())[0] + state = value[user_id]['state'] + state_data = value[user_id]['data'] + + # set new format + new_key = self._get_key(int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id) + self.redis.hset(new_key, "state", state) + self.redis.hset(new_key, "data", json.dumps(state_data)) + + # delete old key + self.redis.delete(old_key) def __str__(self) -> str: # include some connection info diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 146ec2f63..3a81ac033 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -182,6 +182,47 @@ def save_action(pipe): self.redis.transaction(save_action, _key) return True + + def migrate_format(self, bot_id: int, + prefix: Optional[str]="telebot_"): + """ + Migrate from old to new format of keys. + Run this function once to migrate all redis existing keys to new format. + + Starting from version 4.22.0, the format of keys has been changed: + :value + - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + - New format: + {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}} + + This function will help you to migrate from the old format to the new one in order to avoid data loss. + + :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response. + :type bot_id: int + + :param prefix: Prefix for keys, default is "telebot_"(old default value) + :type prefix: Optional[str] + """ + keys = self.redis.keys(f"{prefix}*") + + for key in keys: + old_key = key.decode('utf-8') + # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + value = self.redis.get(old_key) + value = json.loads(value) + + chat_id = old_key[len(prefix):] + user_id = list(value.keys())[0] + state = value[user_id]['state'] + state_data = value[user_id]['data'] + + # set new format + new_key = self._get_key(int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id) + self.redis.hset(new_key, "state", state) + self.redis.hset(new_key, "data", json.dumps(state_data)) + + # delete old key + self.redis.delete(old_key) def __str__(self) -> str: return f"StateRedisStorage({self.redis})" From d2485bf42827795a20f40cd3c3ff96e52df86f74 Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 26 Jul 2024 11:57:52 +0500 Subject: [PATCH 15/24] Another approach to bot id --- telebot/__init__.py | 22 +++++++++++++--------- telebot/async_telebot.py | 20 +++++++++++--------- telebot/util.py | 18 +++++++++++++++++- tests/test_handler_backends.py | 2 +- 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 07ae33181..d0ab64d5a 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -152,9 +152,14 @@ class TeleBot: :param allow_sending_without_reply: Default value for allow_sending_without_reply, defaults to None :type allow_sending_without_reply: :obj:`bool`, optional - :param colorful_logs: Outputs colorful logs :type colorful_logs: :obj:`bool`, optional + + :param validate_token: Validate token, defaults to True; + :type validate_token: :obj:`bool`, optional + + :raises ImportError: If coloredlogs module is not installed and colorful_logs is True + :raises ValueError: If token is invalid """ def __init__( @@ -169,7 +174,7 @@ def __init__( protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, colorful_logs: Optional[bool]=False, - token_check: Optional[bool]=True + validate_token: Optional[bool]=True ): # update-related @@ -186,11 +191,12 @@ def __init__( self.allow_sending_without_reply = allow_sending_without_reply self.webhook_listener = None self._user = None + self.bot_id: int = None - # token check - if token_check: - self._user = self.get_me() - self.bot_id = self._user.id + if validate_token: + util.validate_token(self.token) + + self.bot_id = util.extract_bot_id(self.token) # subject to change in future, unspecified # logs-related if colorful_logs: @@ -286,9 +292,7 @@ def __init__( self.threaded = threaded if self.threaded: self.worker_pool = util.ThreadPool(self, num_threads=num_threads) - - - + @property def user(self) -> types.User: """ diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 508aa6cb4..d60a4a090 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -117,8 +117,11 @@ class AsyncTeleBot: :param colorful_logs: Outputs colorful logs :type colorful_logs: :obj:`bool`, optional - :param token_check: Check token on start - :type token_check: :obj:`bool`, optional, defaults to True + :param validate_token: Validate token, defaults to True; + :type validate_token: :obj:`bool`, optional + + :raises ImportError: If coloredlogs module is not installed and colorful_logs is True + :raises ValueError: If token is invalid """ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[int]=None, @@ -129,7 +132,7 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, colorful_logs: Optional[bool]=False, - token_check: Optional[bool]=True) -> None: + validate_token: Optional[bool]=True) -> None: # update-related self.token = token @@ -186,15 +189,14 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ self.middlewares = [] self._user = None # set during polling - self.bot_id = None + self.bot_id: int = None - if token_check: - result = apihelper.get_me(token) - self._user = types.User.de_json(result) - self.bot_id = self._user.id + if validate_token: + util.validate_token(self.token) + + self.bot_id: int = util.extract_bot_id(self.token) # subject to change, unspecified - @property def user(self): return self._user diff --git a/telebot/util.py b/telebot/util.py index 295c0d1aa..c8ef526c5 100644 --- a/telebot/util.py +++ b/telebot/util.py @@ -686,6 +686,22 @@ def validate_web_app_data(token: str, raw_init_data: str): return hmac.new(secret_key.digest(), data_check_string.encode(), sha256).hexdigest() == init_data_hash +def validate_token(token) -> bool: + if any(char.isspace() for char in token): + raise ValueError('Token must not contain spaces') + + if ':' not in token: + raise ValueError('Token must contain a colon') + + if len(token.split(':')) != 2: + raise ValueError('Token must contain exactly 2 parts separated by a colon') + + return True + +def extract_bot_id(token) -> str: + return int(token.split(':')[0]) + + __all__ = ( "content_type_media", "content_type_service", "update_types", "WorkerThread", "AsyncTask", "CustomRequestResponse", @@ -696,5 +712,5 @@ def validate_web_app_data(token: str, raw_init_data: str): "split_string", "smart_split", "escape", "user_link", "quick_markup", "antiflood", "parse_web_app_data", "validate_web_app_data", "or_set", "or_clear", "orify", "OrEvent", "per_thread", - "webhook_google_functions" + "webhook_google_functions", "validate_token", "extract_bot_id" ) diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py index f57200c1c..ad6ccaaad 100644 --- a/tests/test_handler_backends.py +++ b/tests/test_handler_backends.py @@ -19,7 +19,7 @@ @pytest.fixture() def telegram_bot(): - return telebot.TeleBot('', threaded=False, token_check=False) + return telebot.TeleBot('', threaded=False, validate_token=False) @pytest.fixture From dd0dfa98aa9a977827f62b078232941f981e372a Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 26 Jul 2024 11:59:26 +0500 Subject: [PATCH 16/24] fix tests --- tests/test_handler_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py index ad6ccaaad..b74d6fad0 100644 --- a/tests/test_handler_backends.py +++ b/tests/test_handler_backends.py @@ -19,7 +19,7 @@ @pytest.fixture() def telegram_bot(): - return telebot.TeleBot('', threaded=False, validate_token=False) + return telebot.TeleBot('1234:test', threaded=False) @pytest.fixture From 7f99176c034cd5e356f009dd2664a03cf67edf25 Mon Sep 17 00:00:00 2001 From: _run Date: Fri, 26 Jul 2024 13:00:39 +0500 Subject: [PATCH 17/24] fixed redis data bug on set_state --- telebot/asyncio_storage/redis_storage.py | 7 +++++++ telebot/storage/redis_storage.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 5c0c1bb0a..69b18d5a1 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -95,7 +95,14 @@ async def set_state(self, pipe, chat_id: int, user_id: int, state: str, state = state.name _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + pipe.hget(_key, "data") + result = await pipe.execute() + data = result[0] + if data is None: + pipe.hset(_key, "data", json.dumps({})) + await pipe.hset(_key, "state", state) + return True async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 3a81ac033..1d1413fa2 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -79,9 +79,17 @@ def set_state( def set_state_action(pipe): pipe.multi() - #pipe.hset(_key, mapping={"state": state, "data": "{}"}) + + data = pipe.hget(_key, "data") + result = pipe.execute() + data = result[0] + if data is None: + # If data is None, set it to an empty dictionary + data = {} + pipe.hset(_key, "data", json.dumps(data)) + pipe.hset(_key, "state", state) - + self.redis.transaction(set_state_action, _key) return True From 3fec98c070f91b098766bef4968b283e26ce1556 Mon Sep 17 00:00:00 2001 From: _run Date: Sat, 27 Jul 2024 15:07:45 +0500 Subject: [PATCH 18/24] Code cleanup --- telebot/asyncio_storage/__init__.py | 12 +- telebot/asyncio_storage/base_storage.py | 66 ++++--- telebot/asyncio_storage/memory_storage.py | 183 +++++++++++++----- telebot/asyncio_storage/pickle_storage.py | 174 +++++++++++++---- telebot/asyncio_storage/redis_storage.py | 226 +++++++++++++++++----- telebot/states/__init__.py | 96 ++++++--- telebot/states/aio/__init__.py | 6 +- telebot/states/aio/context.py | 56 +++--- telebot/states/aio/middleware.py | 4 +- telebot/storage/__init__.py | 14 +- telebot/storage/base_storage.py | 66 ++++--- telebot/storage/memory_storage.py | 184 +++++++++++++----- telebot/storage/pickle_storage.py | 173 +++++++++++++---- telebot/storage/redis_storage.py | 186 +++++++++++++----- 14 files changed, 1054 insertions(+), 392 deletions(-) diff --git a/telebot/asyncio_storage/__init__.py b/telebot/asyncio_storage/__init__.py index 1f9d51650..82a8c817a 100644 --- a/telebot/asyncio_storage/__init__.py +++ b/telebot/asyncio_storage/__init__.py @@ -4,10 +4,10 @@ from telebot.asyncio_storage.base_storage import StateDataContext, StateStorageBase - - - __all__ = [ - 'StateStorageBase', 'StateDataContext', - 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' -] \ No newline at end of file + "StateStorageBase", + "StateDataContext", + "StateMemoryStorage", + "StateRedisStorage", + "StatePickleStorage", +] diff --git a/telebot/asyncio_storage/base_storage.py b/telebot/asyncio_storage/base_storage.py index 6d06e7bfc..20d0b0587 100644 --- a/telebot/asyncio_storage/base_storage.py +++ b/telebot/asyncio_storage/base_storage.py @@ -1,5 +1,6 @@ import copy + class StateStorageBase: def __init__(self) -> None: pass @@ -15,30 +16,30 @@ async def get_data(self, chat_id, user_id): Get data for a user in a particular chat. """ raise NotImplementedError - + async def set_state(self, chat_id, user_id, state): """ Set state for a particular user. - ! Note that you should create a - record if it does not exist, and + ! Note that you should create a + record if it does not exist, and if a record with state already exists, you need to update a record. """ raise NotImplementedError - + async def delete_state(self, chat_id, user_id): """ Delete state for a particular user. """ raise NotImplementedError - + async def reset_data(self, chat_id, user_id): """ Reset data for a particular user in a chat. """ raise NotImplementedError - + async def get_state(self, chat_id, user_id): raise NotImplementedError @@ -51,16 +52,16 @@ def get_interactive_data(self, chat_id, user_id): async def save(self, chat_id, user_id, data): raise NotImplementedError - + def _get_key( - self, - chat_id: int, - user_id: int, - prefix: str, - separator: str, - business_connection_id: str=None, - message_thread_id: int=None, - bot_id: int=None + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str = None, + message_thread_id: int = None, + bot_id: int = None, ) -> str: """ Convert parameters to a key. @@ -78,15 +79,20 @@ def _get_key( return separator.join(params) - - - - class StateDataContext: """ Class for data. """ - def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ): + + def __init__( + self, + obj, + chat_id, + user_id, + business_connection_id=None, + message_thread_id=None, + bot_id=None, + ): self.obj = obj self.data = None self.chat_id = chat_id @@ -95,13 +101,23 @@ def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_ self.business_connection_id = business_connection_id self.message_thread_id = message_thread_id - - async def __aenter__(self): - data = await self.obj.get_data(chat_id=self.chat_id, user_id=self.user_id, business_connection_id=self.business_connection_id, - message_thread_id=self.message_thread_id, bot_id=self.bot_id) + data = await self.obj.get_data( + chat_id=self.chat_id, + user_id=self.user_id, + business_connection_id=self.business_connection_id, + message_thread_id=self.message_thread_id, + bot_id=self.bot_id, + ) self.data = copy.deepcopy(data) return self.data async def __aexit__(self, exc_type, exc_val, exc_tb): - return await self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id) \ No newline at end of file + return await self.obj.save( + self.chat_id, + self.user_id, + self.data, + self.business_connection_id, + self.message_thread_id, + self.bot_id, + ) diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index 87449cace..e17d5fe8c 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -1,6 +1,7 @@ from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext from typing import Optional, Union + class StateMemoryStorage(StateStorageBase): """ Memory storage for states. @@ -19,72 +20,114 @@ class StateMemoryStorage(StateStorageBase): :type prefix: Optional[str] """ - def __init__(self, - separator: Optional[str]=":", - prefix: Optional[str]="telebot" - ) -> None: + def __init__( + self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot" + ) -> None: self.separator = separator self.prefix = prefix if not self.prefix: raise ValueError("Prefix cannot be empty") - - self.data = {} # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id - async def set_state( - self, chat_id: int, user_id: int, state: str, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self.data = ( + {} + ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + async def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: if hasattr(state, "name"): state = state.name _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: self.data[_key] = {"state": state, "data": {}} else: self.data[_key]["state"] = state - + return True - + async def get_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Union[str, None]: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return None - + return self.data[_key]["state"] - + async def delete_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) - + if self.data.get(_key) is None: return False - + del self.data[_key] return True - - + async def set_data( - self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: - + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: @@ -92,57 +135,91 @@ async def set_data( self.data[_key]["data"][key] = value return True - async def get_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> dict: - + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) return self.data.get(_key, {}).get("data", {}) - + async def reset_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: - + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return False self.data[_key]["data"] = {} return True - + def get_interactive_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Optional[dict]: return StateDataContext( - self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, ) - + async def save( - self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, - message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return False self.data[_key]["data"] = data return True - + def __str__(self) -> str: return f"" - - - diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index 0506d4703..723672034 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,4 +1,3 @@ - aiofiles_installed = True try: import aiofiles @@ -12,12 +11,15 @@ from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + def with_lock(func: Callable) -> Callable: async def wrapper(self, *args, **kwargs): async with self.lock: return await func(self, *args, **kwargs) + return wrapper + class StatePickleStorage(StateStorageBase): """ State storage based on pickle file. @@ -27,7 +29,7 @@ class StatePickleStorage(StateStorageBase): This storage is not recommended for production use. Data may be corrupted. If you face a case where states do not work as expected, try to use another storage. - + .. code-block:: python3 storage = StatePickleStorage() @@ -42,9 +44,14 @@ class StatePickleStorage(StateStorageBase): :param separator: Separator for keys, default is ":". :type separator: Optional[str] """ - def __init__(self, file_path: str = "./.state-save/states.pkl", - prefix='telebot', separator: Optional[str] = ":") -> None: - + + def __init__( + self, + file_path: str = "./.state-save/states.pkl", + prefix="telebot", + separator: Optional[str] = ":", + ) -> None: + if not aiofiles_installed: raise ImportError("Please install aiofiles using `pip install aiofiles`") @@ -55,12 +62,12 @@ def __init__(self, file_path: str = "./.state-save/states.pkl", self.create_dir() async def _read_from_file(self) -> dict: - async with aiofiles.open(self.file_path, 'rb') as f: + async with aiofiles.open(self.file_path, "rb") as f: data = await f.read() return pickle.loads(data) async def _write_to_file(self, data: dict) -> None: - async with aiofiles.open(self.file_path, 'wb') as f: + async with aiofiles.open(self.file_path, "wb") as f: await f.write(pickle.dumps(data)) def create_dir(self): @@ -70,16 +77,27 @@ def create_dir(self): dirs, filename = os.path.split(self.file_path) os.makedirs(dirs, exist_ok=True) if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: + with open(self.file_path, "wb") as file: pickle.dump({}, file) - @with_lock - async def set_state(self, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None) -> bool: + async def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() if _key not in data: @@ -90,19 +108,43 @@ async def set_state(self, chat_id: int, user_id: int, state: str, return True @with_lock - async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + async def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() return data.get(_key, {}).get("state") @with_lock - async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + async def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() if _key in data: @@ -112,11 +154,24 @@ async def delete_state(self, chat_id: int, user_id: int, business_connection_id: return False @with_lock - async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None) -> bool: + async def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() state_data = data.get(_key, {}) @@ -129,19 +184,43 @@ async def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, return True @with_lock - async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + async def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() return data.get(_key, {}).get("data", {}) @with_lock - async def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + async def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() if _key in data: @@ -150,18 +229,41 @@ async def reset_data(self, chat_id: int, user_id: int, business_connection_id: O return True return False - def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: return StateDataContext( - self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, ) @with_lock - async def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + async def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = await self._read_from_file() data[_key]["data"] = data diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 69b18d5a1..eccff5c30 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,4 +1,3 @@ - redis_installed = True try: import redis @@ -17,8 +16,10 @@ def async_with_lock(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: async def wrapper(self, *args, **kwargs): async with self.lock: return await func(self, *args, **kwargs) + return wrapper + def async_with_pipeline(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: async def wrapper(self, *args, **kwargs): async with self.redis.pipeline() as pipe: @@ -26,8 +27,10 @@ async def wrapper(self, *args, **kwargs): result = await func(self, pipe, *args, **kwargs) await pipe.execute() return result + return wrapper + class StateRedisStorage(StateStorageBase): """ State storage based on Redis. @@ -62,16 +65,22 @@ class StateRedisStorage(StateStorageBase): :type separator: Optional[str] """ - def __init__(self, host='localhost', port=6379, db=0, password=None, - prefix='telebot', - redis_url=None, - connection_pool: 'ConnectionPool'=None, - separator: Optional[str] = ":", - ) -> None: - + + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + prefix="telebot", + redis_url=None, + connection_pool: "ConnectionPool" = None, + separator: Optional[str] = ":", + ) -> None: + if not redis_installed: raise ImportError("Please install redis using `pip install redis`") - + self.separator = separator self.prefix = prefix if not self.prefix: @@ -83,46 +92,105 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, self.redis = Redis(connection_pool=connection_pool) else: self.redis = Redis(host=host, port=port, db=db, password=password) - + self.lock = asyncio.Lock() @async_with_lock @async_with_pipeline - async def set_state(self, pipe, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + async def set_state( + self, + pipe, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: if hasattr(state, "name"): state = state.name - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) pipe.hget(_key, "data") result = await pipe.execute() data = result[0] if data is None: pipe.hset(_key, "data", json.dumps({})) - + await pipe.hset(_key, "state", state) - + return True - async def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + async def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) state_bytes = await self.redis.hget(_key, "state") - return state_bytes.decode('utf-8') if state_bytes else None - - async def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + return state_bytes.decode("utf-8") if state_bytes else None + + async def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) result = await self.redis.delete(_key) return result > 0 @async_with_lock @async_with_pipeline - async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None) -> bool: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + async def set_data( + self, + pipe, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) data = await pipe.hget(_key, "data") data = await pipe.execute() data = data[0] @@ -134,41 +202,97 @@ async def set_data(self, pipe, chat_id: int, user_id: int, key: str, value: Unio await pipe.hset(_key, "data", json.dumps(data)) return True - async def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + async def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) data = await self.redis.hget(_key, "data") return json.loads(data) if data else {} @async_with_lock @async_with_pipeline - async def reset_data(self, pipe, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + async def reset_data( + self, + pipe, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) if await pipe.exists(_key): await pipe.hset(_key, "data", "{}") else: return False return True - def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: - return StateDataContext(self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id) + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) @async_with_lock @async_with_pipeline - async def save(self, pipe, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: - _key = self._get_key(chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id) + async def save( + self, + pipe, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) if await pipe.exists(_key): await pipe.hset(_key, "data", json.dumps(data)) else: return False return True - - def migrate_format(self, bot_id: int, - prefix: Optional[str]="telebot_"): + + def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): """ Migrate from old to new format of keys. Run this function once to migrate all redis existing keys to new format. @@ -188,20 +312,22 @@ def migrate_format(self, bot_id: int, :type prefix: Optional[str] """ keys = self.redis.keys(f"{prefix}*") - + for key in keys: - old_key = key.decode('utf-8') + old_key = key.decode("utf-8") # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} value = self.redis.get(old_key) value = json.loads(value) - chat_id = old_key[len(prefix):] + chat_id = old_key[len(prefix) :] user_id = list(value.keys())[0] - state = value[user_id]['state'] - state_data = value[user_id]['data'] + state = value[user_id]["state"] + state_data = value[user_id]["data"] # set new format - new_key = self._get_key(int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id) + new_key = self._get_key( + int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id + ) self.redis.hset(new_key, "state", state) self.redis.hset(new_key, "data", json.dumps(state_data)) diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py index 015f135e4..2491608ee 100644 --- a/telebot/states/__init__.py +++ b/telebot/states/__init__.py @@ -1,8 +1,10 @@ """ Contains classes for states and state groups. """ + from telebot import types + class State: """ Class representing a state. @@ -12,9 +14,11 @@ class State: class MyStates(StatesGroup): my_state = State() # returns my_state:State string. """ + def __init__(self) -> None: self.name: str = None self.group: StatesGroup = None + def __str__(self) -> str: return f"<{self.name}>" @@ -28,12 +32,17 @@ class StatesGroup: class MyStates(StatesGroup): my_state = State() # returns my_state:State string. """ + def __init_subclass__(cls) -> None: state_list = [] for name, value in cls.__dict__.items(): - if not name.startswith('__') and not callable(value) and isinstance(value, State): + if ( + not name.startswith("__") + and not callable(value) + and isinstance(value, State) + ): # change value of that variable - value.name = ':'.join((cls.__name__, name)) + value.name = ":".join((cls.__name__, name)) value.group = cls state_list.append(value) cls._state_list = state_list @@ -42,41 +51,78 @@ def __init_subclass__(cls) -> None: def state_list(self): return self._state_list + def resolve_context(message, bot_id: int) -> tuple: # chat_id, user_id, business_connection_id, bot_id, message_thread_id - + # message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message if isinstance(message, types.Message): - return (message.chat.id, message.from_user.id, message.business_connection_id, bot_id, - message.message_thread_id if message.is_topic_message else None) - elif isinstance(message, types.CallbackQuery): # callback_query - return (message.message.chat.id, message.from_user.id, message.message.business_connection_id, bot_id, - message.message.message_thread_id if message.message.is_topic_message else None) - elif isinstance(message, types.BusinessConnection): # business_connection + return ( + message.chat.id, + message.from_user.id, + message.business_connection_id, + bot_id, + message.message_thread_id if message.is_topic_message else None, + ) + elif isinstance(message, types.CallbackQuery): # callback_query + return ( + message.message.chat.id, + message.from_user.id, + message.message.business_connection_id, + bot_id, + ( + message.message.message_thread_id + if message.message.is_topic_message + else None + ), + ) + elif isinstance(message, types.BusinessConnection): # business_connection return (message.user_chat_id, message.user.id, message.id, bot_id, None) - elif isinstance(message, types.BusinessMessagesDeleted): # deleted_business_messages - return (message.chat.id, message.chat.id, message.business_connection_id, bot_id, None) - elif isinstance(message, types.MessageReactionUpdated): # message_reaction + elif isinstance( + message, types.BusinessMessagesDeleted + ): # deleted_business_messages + return ( + message.chat.id, + message.chat.id, + message.business_connection_id, + bot_id, + None, + ) + elif isinstance(message, types.MessageReactionUpdated): # message_reaction return (message.chat.id, message.user.id, None, bot_id, None) - elif isinstance(message, types.MessageReactionCountUpdated): # message_reaction_count + elif isinstance( + message, types.MessageReactionCountUpdated + ): # message_reaction_count return (message.chat.id, None, None, bot_id, None) - elif isinstance(message, types.InlineQuery): # inline_query + elif isinstance(message, types.InlineQuery): # inline_query return (None, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result + elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result return (None, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.ShippingQuery): # shipping_query + elif isinstance(message, types.ShippingQuery): # shipping_query return (None, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query + elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query return (None, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.PollAnswer): # poll_answer + elif isinstance(message, types.PollAnswer): # poll_answer return (None, message.user.id, None, bot_id, None) - elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member + elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member return (message.chat.id, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.ChatJoinRequest): # chat_join_request + elif isinstance(message, types.ChatJoinRequest): # chat_join_request return (message.chat.id, message.from_user.id, None, bot_id, None) - elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost - return (message.chat.id, message.source.user.id if message.source else None, None, bot_id, None) - elif isinstance(message, types.ChatBoostUpdated): # chat_boost - return (message.chat.id, message.boost.source.user.id if message.boost.source else None, None, bot_id, None) + elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost + return ( + message.chat.id, + message.source.user.id if message.source else None, + None, + bot_id, + None, + ) + elif isinstance(message, types.ChatBoostUpdated): # chat_boost + return ( + message.chat.id, + message.boost.source.user.id if message.boost.source else None, + None, + bot_id, + None, + ) else: - pass # not yet supported :( \ No newline at end of file + pass # not yet supported :( diff --git a/telebot/states/aio/__init__.py b/telebot/states/aio/__init__.py index 60f9ea2b7..14ef19536 100644 --- a/telebot/states/aio/__init__.py +++ b/telebot/states/aio/__init__.py @@ -2,6 +2,6 @@ from .middleware import StateMiddleware __all__ = [ - 'StateContext', - 'StateMiddleware', -] \ No newline at end of file + "StateContext", + "StateMiddleware", +] diff --git a/telebot/states/aio/context.py b/telebot/states/aio/context.py index 956ca6837..4c9ad61a3 100644 --- a/telebot/states/aio/context.py +++ b/telebot/states/aio/context.py @@ -6,8 +6,7 @@ from typing import Union - -class StateContext(): +class StateContext: """ Class representing a state context. @@ -21,7 +20,7 @@ async def start_ex(message: types.Message, state_context: StateContext): await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. """ - + def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: self.message: Union[Message, CallbackQuery] = message self.bot: AsyncTeleBot = bot @@ -42,7 +41,9 @@ async def start_ex(message: types.Message, state_context: StateContext): await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) if isinstance(state, State): state = state.name return await self.bot.set_state( @@ -51,9 +52,9 @@ async def start_ex(message: types.Message, state_context: StateContext): state=state, business_connection_id=business_connection_id, bot_id=bot_id, - message_thread_id=message_thread_id + message_thread_id=message_thread_id, ) - + async def get(self) -> str: """ Get current state for current user. @@ -62,67 +63,75 @@ async def get(self) -> str: :rtype: str """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) return await self.bot.get_state( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, bot_id=bot_id, - message_thread_id=message_thread_id + message_thread_id=message_thread_id, ) - + async def delete(self) -> bool: """ Deletes state and data for current user. .. warning:: - + This method deletes state and associated data for current user. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) return await self.bot.delete_state( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, bot_id=bot_id, - message_thread_id=message_thread_id + message_thread_id=message_thread_id, ) - + async def reset_data(self) -> bool: """ - Reset data for current user. + Reset data for current user. State will not be changed. """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) return await self.bot.reset_data( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, bot_id=bot_id, - message_thread_id=message_thread_id + message_thread_id=message_thread_id, ) - + def data(self) -> dict: """ Get data for current user. .. code-block:: python3 - + with state_context.data() as data: print(data) data['name'] = 'John' """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) return self.bot.retrieve_data( chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, bot_id=bot_id, - message_thread_id=message_thread_id + message_thread_id=message_thread_id, ) - + async def add_data(self, **kwargs) -> None: """ Add data for current user. @@ -131,7 +140,9 @@ async def add_data(self, **kwargs) -> None: :type kwargs: dict """ - chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) return await self.bot.add_data( chat_id=chat_id, user_id=user_id, @@ -140,4 +151,3 @@ async def add_data(self, **kwargs) -> None: message_thread_id=message_thread_id, **kwargs ) - \ No newline at end of file diff --git a/telebot/states/aio/middleware.py b/telebot/states/aio/middleware.py index 546b29067..675b7b462 100644 --- a/telebot/states/aio/middleware.py +++ b/telebot/states/aio/middleware.py @@ -14,8 +14,8 @@ def __init__(self, bot: AsyncTeleBot) -> None: async def pre_process(self, message, data): state_context = StateContext(message, self.bot) - data['state_context'] = state_context - data['state'] = state_context # 2 ways to access state context + data["state_context"] = state_context + data["state"] = state_context # 2 ways to access state context async def post_process(self, message, data, exception): pass diff --git a/telebot/storage/__init__.py b/telebot/storage/__init__.py index 954c1b31d..bb3dd7f07 100644 --- a/telebot/storage/__init__.py +++ b/telebot/storage/__init__.py @@ -1,13 +1,13 @@ from telebot.storage.memory_storage import StateMemoryStorage from telebot.storage.redis_storage import StateRedisStorage from telebot.storage.pickle_storage import StatePickleStorage -from telebot.storage.base_storage import StateDataContext,StateStorageBase - - - +from telebot.storage.base_storage import StateDataContext, StateStorageBase __all__ = [ - 'StateStorageBase', 'StateDataContext', - 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' -] \ No newline at end of file + "StateStorageBase", + "StateDataContext", + "StateMemoryStorage", + "StateRedisStorage", + "StatePickleStorage", +] diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py index 25a6c6b04..9956c7ab9 100644 --- a/telebot/storage/base_storage.py +++ b/telebot/storage/base_storage.py @@ -1,5 +1,6 @@ import copy + class StateStorageBase: def __init__(self) -> None: pass @@ -15,30 +16,30 @@ def get_data(self, chat_id, user_id): Get data for a user in a particular chat. """ raise NotImplementedError - + def set_state(self, chat_id, user_id, state): """ Set state for a particular user. - ! Note that you should create a - record if it does not exist, and + ! Note that you should create a + record if it does not exist, and if a record with state already exists, you need to update a record. """ raise NotImplementedError - + def delete_state(self, chat_id, user_id): """ Delete state for a particular user. """ raise NotImplementedError - + def reset_data(self, chat_id, user_id): """ Reset data for a particular user in a chat. """ raise NotImplementedError - + def get_state(self, chat_id, user_id): raise NotImplementedError @@ -47,16 +48,16 @@ def get_interactive_data(self, chat_id, user_id): def save(self, chat_id, user_id, data): raise NotImplementedError - + def _get_key( - self, - chat_id: int, - user_id: int, - prefix: str, - separator: str, - business_connection_id: str=None, - message_thread_id: int=None, - bot_id: int=None + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str = None, + message_thread_id: int = None, + bot_id: int = None, ) -> str: """ Convert parameters to a key. @@ -74,18 +75,28 @@ def _get_key( return separator.join(params) - - - - class StateDataContext: """ Class for data. """ - def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ): + + def __init__( + self, + obj, + chat_id, + user_id, + business_connection_id=None, + message_thread_id=None, + bot_id=None, + ): self.obj = obj - res = obj.get_data(chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id) + res = obj.get_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) self.data = copy.deepcopy(res) self.chat_id = chat_id self.user_id = user_id @@ -93,10 +104,15 @@ def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_ self.business_connection_id = business_connection_id self.message_thread_id = message_thread_id - - def __enter__(self): return self.data def __exit__(self, exc_type, exc_val, exc_tb): - return self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id) \ No newline at end of file + return self.obj.save( + self.chat_id, + self.user_id, + self.data, + self.business_connection_id, + self.message_thread_id, + self.bot_id, + ) diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index dc88bcebc..21266632c 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -1,6 +1,7 @@ from telebot.storage.base_storage import StateStorageBase, StateDataContext from typing import Optional, Union + class StateMemoryStorage(StateStorageBase): """ Memory storage for states. @@ -18,72 +19,115 @@ class StateMemoryStorage(StateStorageBase): :param prefix: Prefix for keys, default is "telebot". :type prefix: Optional[str] """ - def __init__(self, - separator: Optional[str]=":", - prefix: Optional[str]="telebot" - ) -> None: + + def __init__( + self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot" + ) -> None: self.separator = separator self.prefix = prefix if not self.prefix: raise ValueError("Prefix cannot be empty") - - self.data = {} # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id - def set_state( - self, chat_id: int, user_id: int, state: str, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self.data = ( + {} + ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: if hasattr(state, "name"): state = state.name _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: self.data[_key] = {"state": state, "data": {}} else: self.data[_key]["state"] = state - + return True - + def get_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Union[str, None]: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return None - + return self.data[_key]["state"] - + def delete_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) - + if self.data.get(_key) is None: return False - + del self.data[_key] return True - - + def set_data( - self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, - bot_id: Optional[int]=None) -> bool: - + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: @@ -91,57 +135,91 @@ def set_data( self.data[_key]["data"][key] = value return True - def get_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> dict: - + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) return self.data.get(_key, {}).get("data", {}) - + def reset_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: - + _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return False self.data[_key]["data"] = {} return True - + def get_interactive_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Optional[dict]: return StateDataContext( - self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, ) - + def save( - self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str]=None, - message_thread_id: Optional[int]=None, bot_id: Optional[int]=None + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, - message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) if self.data.get(_key) is None: return False self.data[_key]["data"] = data return True - + def __str__(self) -> str: return f"" - - - diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index fc8dc3289..e11a05043 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -4,12 +4,15 @@ from typing import Optional, Union, Callable from telebot.storage.base_storage import StateStorageBase, StateDataContext + def with_lock(func: Callable) -> Callable: def wrapper(self, *args, **kwargs): with self.lock: return func(self, *args, **kwargs) + return wrapper + class StatePickleStorage(StateStorageBase): """ State storage based on pickle file. @@ -19,7 +22,7 @@ class StatePickleStorage(StateStorageBase): This storage is not recommended for production use. Data may be corrupted. If you face a case where states do not work as expected, try to use another storage. - + .. code-block:: python3 storage = StatePickleStorage() @@ -35,8 +38,12 @@ class StatePickleStorage(StateStorageBase): :type separator: Optional[str] """ - def __init__(self, file_path: str = "./.state-save/states.pkl", - prefix='telebot', separator: Optional[str] = ":") -> None: + def __init__( + self, + file_path: str = "./.state-save/states.pkl", + prefix="telebot", + separator: Optional[str] = ":", + ) -> None: self.file_path = file_path self.prefix = prefix self.separator = separator @@ -45,11 +52,11 @@ def __init__(self, file_path: str = "./.state-save/states.pkl", self.create_dir() def _read_from_file(self) -> dict: - with open(self.file_path, 'rb') as f: + with open(self.file_path, "rb") as f: return pickle.load(f) def _write_to_file(self, data: dict) -> None: - with open(self.file_path, 'wb') as f: + with open(self.file_path, "wb") as f: pickle.dump(data, f) def create_dir(self): @@ -59,15 +66,27 @@ def create_dir(self): dirs, filename = os.path.split(self.file_path) os.makedirs(dirs, exist_ok=True) if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: + with open(self.file_path, "wb") as file: pickle.dump({}, file) @with_lock - def set_state(self, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None) -> bool: + def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() if _key not in data: @@ -78,19 +97,43 @@ def set_state(self, chat_id: int, user_id: int, state: str, return True @with_lock - def get_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Union[str, None]: + def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() return data.get(_key, {}).get("state") - + @with_lock - def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() if _key in data: @@ -100,11 +143,24 @@ def delete_state(self, chat_id: int, user_id: int, business_connection_id: Optio return False @with_lock - def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None) -> bool: + def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() state_data = data.get(_key, {}) @@ -117,19 +173,43 @@ def set_data(self, chat_id: int, user_id: int, key: str, value: Union[str, int, return True @with_lock - def get_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> dict: + def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() return data.get(_key, {}).get("data", {}) @with_lock - def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() if _key in data: @@ -138,23 +218,46 @@ def reset_data(self, chat_id: int, user_id: int, business_connection_id: Optiona return True return False - def get_interactive_data(self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> Optional[dict]: + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: return StateDataContext( - self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, ) @with_lock - def save(self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None) -> bool: + def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self._read_from_file() data[_key]["data"] = data self._write_to_file(data) return True - + def __str__(self) -> str: return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 1d1413fa2..19f5c2cd5 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -8,6 +8,7 @@ except ImportError: redis_installed = False + class StateRedisStorage(StateStorageBase): """ State storage based on Redis. @@ -42,39 +43,56 @@ class StateRedisStorage(StateStorageBase): :type separator: Optional[str] """ - def __init__(self, host='localhost', port=6379, db=0, password=None, - prefix='telebot', - redis_url=None, - connection_pool: 'redis.ConnectionPool'=None, - separator: Optional[str]=":", - ) -> None: - + + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + prefix="telebot", + redis_url=None, + connection_pool: "redis.ConnectionPool" = None, + separator: Optional[str] = ":", + ) -> None: + if not redis_installed: - raise ImportError("Redis is not installed. Please install it via pip install redis") + raise ImportError( + "Redis is not installed. Please install it via pip install redis" + ) self.separator = separator self.prefix = prefix if not self.prefix: raise ValueError("Prefix cannot be empty") - + if redis_url: self.redis = redis.Redis.from_url(redis_url) elif connection_pool: self.redis = redis.Redis(connection_pool=connection_pool) else: self.redis = redis.Redis(host=host, port=port, db=db, password=password) - def set_state( - self, chat_id: int, user_id: int, state: str, - business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: if hasattr(state, "name"): state = state.name _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) def set_state_action(pipe): @@ -89,36 +107,67 @@ def set_state_action(pipe): pipe.hset(_key, "data", json.dumps(data)) pipe.hset(_key, "state", state) - + self.redis.transaction(set_state_action, _key) return True def get_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Union[str, None]: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) state_bytes = self.redis.hget(_key, "state") - return state_bytes.decode('utf-8') if state_bytes else None + return state_bytes.decode("utf-8") if state_bytes else None def delete_state( - self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) return self.redis.delete(_key) > 0 def set_data( - self, chat_id: int, user_id: int, key: str, value: Union[str, int, float, dict], - business_connection_id: Optional[str] = None, message_thread_id: Optional[int] = None, - bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) def set_data_action(pipe): @@ -136,21 +185,41 @@ def set_data_action(pipe): return True def get_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> dict: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) data = self.redis.hget(_key, "data") return json.loads(data) if data else {} def reset_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) def reset_data_action(pipe): @@ -164,21 +233,39 @@ def reset_data_action(pipe): return True def get_interactive_data( - self, chat_id: int, user_id: int, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> Optional[dict]: return StateDataContext( - self, chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id, - message_thread_id=message_thread_id, bot_id=bot_id + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, ) def save( - self, chat_id: int, user_id: int, data: dict, business_connection_id: Optional[str] = None, - message_thread_id: Optional[int] = None, bot_id: Optional[int] = None + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, ) -> bool: _key = self._get_key( - chat_id, user_id, self.prefix, self.separator, business_connection_id, - message_thread_id, bot_id + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, ) def save_action(pipe): @@ -190,9 +277,8 @@ def save_action(pipe): self.redis.transaction(save_action, _key) return True - - def migrate_format(self, bot_id: int, - prefix: Optional[str]="telebot_"): + + def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): """ Migrate from old to new format of keys. Run this function once to migrate all redis existing keys to new format. @@ -212,20 +298,22 @@ def migrate_format(self, bot_id: int, :type prefix: Optional[str] """ keys = self.redis.keys(f"{prefix}*") - + for key in keys: - old_key = key.decode('utf-8') + old_key = key.decode("utf-8") # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} value = self.redis.get(old_key) value = json.loads(value) - chat_id = old_key[len(prefix):] + chat_id = old_key[len(prefix) :] user_id = list(value.keys())[0] - state = value[user_id]['state'] - state_data = value[user_id]['data'] + state = value[user_id]["state"] + state_data = value[user_id]["data"] # set new format - new_key = self._get_key(int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id) + new_key = self._get_key( + int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id + ) self.redis.hset(new_key, "state", state) self.redis.hset(new_key, "data", json.dumps(state_data)) From dbfa514fc389d527ed368cf4630e40b957d88d6e Mon Sep 17 00:00:00 2001 From: _run Date: Sat, 27 Jul 2024 15:12:27 +0500 Subject: [PATCH 19/24] code cleanuop & renamed aio -> asyncio --- .../asynchronous_telebot/custom_states.py | 98 +++++++++++++------ examples/custom_states.py | 91 ++++++++++++----- telebot/states/{aio => asyncio}/__init__.py | 0 telebot/states/{aio => asyncio}/context.py | 0 telebot/states/{aio => asyncio}/middleware.py | 0 5 files changed, 136 insertions(+), 53 deletions(-) rename telebot/states/{aio => asyncio}/__init__.py (100%) rename telebot/states/{aio => asyncio}/context.py (100%) rename telebot/states/{aio => asyncio}/middleware.py (100%) diff --git a/examples/asynchronous_telebot/custom_states.py b/examples/asynchronous_telebot/custom_states.py index eefc6ebbb..f0111e27e 100644 --- a/examples/asynchronous_telebot/custom_states.py +++ b/examples/asynchronous_telebot/custom_states.py @@ -1,13 +1,13 @@ -from telebot import async_telebot -from telebot import asyncio_filters, types -from telebot.states import State, StatesGroup -from telebot.states.aio.context import StateContext +from telebot import async_telebot, asyncio_filters, types from telebot.asyncio_storage import StateMemoryStorage +from telebot.states import State, StatesGroup +from telebot.states.asyncio.context import StateContext # Initialize the bot state_storage = StateMemoryStorage() # don't use this in production; switch to redis bot = async_telebot.AsyncTeleBot("TOKEN", state_storage=state_storage) + # Define states class MyStates(StatesGroup): name = State() @@ -15,25 +15,39 @@ class MyStates(StatesGroup): color = State() hobby = State() + # Start command handler -@bot.message_handler(commands=['start']) +@bot.message_handler(commands=["start"]) async def start_ex(message: types.Message, state: StateContext): await state.set(MyStates.name) - await bot.send_message(message.chat.id, 'Hello! What is your first name?', reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, + "Hello! What is your first name?", + reply_to_message_id=message.message_id, + ) + # Cancel command handler -@bot.message_handler(state="*", commands=['cancel']) +@bot.message_handler(state="*", commands=["cancel"]) async def any_state(message: types.Message, state: StateContext): await state.delete() - await bot.send_message(message.chat.id, 'Your information has been cleared. Type /start to begin again.', reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, + "Your information has been cleared. Type /start to begin again.", + reply_to_message_id=message.message_id, + ) + # Handler for name input @bot.message_handler(state=MyStates.name) async def name_get(message: types.Message, state: StateContext): await state.set(MyStates.age) - await bot.send_message(message.chat.id, "How old are you?", reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, "How old are you?", reply_to_message_id=message.message_id + ) await state.add_data(name=message.text) + # Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) async def ask_color(message: types.Message, state: StateContext): @@ -46,7 +60,13 @@ async def ask_color(message: types.Message, state: StateContext): buttons = [types.KeyboardButton(color) for color in colors] keyboard.add(*buttons) - await bot.send_message(message.chat.id, "What is your favorite color? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, + "What is your favorite color? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + # Handler for color input @bot.message_handler(state=MyStates.color) @@ -60,15 +80,23 @@ async def ask_hobby(message: types.Message, state: StateContext): buttons = [types.KeyboardButton(hobby) for hobby in hobbies] keyboard.add(*buttons) - await bot.send_message(message.chat.id, "What is one of your hobbies? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, + "What is one of your hobbies? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + # Handler for hobby input; use filters to ease validation -@bot.message_handler(state=MyStates.hobby, text=['Reading', 'Traveling', 'Gaming', 'Cooking']) +@bot.message_handler( + state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"] +) async def finish(message: types.Message, state: StateContext): async with state.data() as data: - name = data.get('name') - age = data.get('age') - color = data.get('color') + name = data.get("name") + age = data.get("age") + color = data.get("color") hobby = message.text # Get the hobby from the message text # Provide a fun fact based on color @@ -79,24 +107,36 @@ async def finish(message: types.Message, state: StateContext): "Yellow": "Yellow is a cheerful color often associated with happiness.", "Purple": "Purple signifies royalty and luxury.", "Orange": "Orange is a vibrant color that stimulates enthusiasm.", - "Other": "Colors have various meanings depending on context." + "Other": "Colors have various meanings depending on context.", } - color_fact = color_facts.get(color, "Colors have diverse meanings, and yours is unique!") - - msg = (f"Thank you for sharing! Here is a summary of your information:\n" - f"First Name: {name}\n" - f"Age: {age}\n" - f"Favorite Color: {color}\n" - f"Fun Fact about your color: {color_fact}\n" - f"Favorite Hobby: {hobby}") - - await bot.send_message(message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id) + color_fact = color_facts.get( + color, "Colors have diverse meanings, and yours is unique!" + ) + + msg = ( + f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}" + ) + + await bot.send_message( + message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id + ) await state.delete() + # Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) async def age_incorrect(message: types.Message): - await bot.send_message(message.chat.id, 'Please enter a valid number for age.', reply_to_message_id=message.message_id) + await bot.send_message( + message.chat.id, + "Please enter a valid number for age.", + reply_to_message_id=message.message_id, + ) + # Add custom filters bot.add_custom_filter(asyncio_filters.StateFilter(bot)) @@ -104,9 +144,11 @@ async def age_incorrect(message: types.Message): bot.add_custom_filter(asyncio_filters.TextMatchFilter()) # necessary for state parameter in handlers. -from telebot.states.aio.middleware import StateMiddleware +from telebot.states.asyncio.middleware import StateMiddleware + bot.setup_middleware(StateMiddleware(bot)) # Start polling import asyncio + asyncio.run(bot.polling()) diff --git a/examples/custom_states.py b/examples/custom_states.py index a488664b2..131dd1d42 100644 --- a/examples/custom_states.py +++ b/examples/custom_states.py @@ -6,8 +6,8 @@ # Initialize the bot state_storage = StateMemoryStorage() # don't use this in production; switch to redis -bot = telebot.TeleBot("TOKEN", state_storage=state_storage, - use_class_middlewares=True) +bot = telebot.TeleBot("TOKEN", state_storage=state_storage, use_class_middlewares=True) + # Define states class MyStates(StatesGroup): @@ -16,25 +16,39 @@ class MyStates(StatesGroup): color = State() hobby = State() + # Start command handler -@bot.message_handler(commands=['start']) +@bot.message_handler(commands=["start"]) def start_ex(message: types.Message, state: StateContext): state.set(MyStates.name) - bot.send_message(message.chat.id, 'Hello! What is your first name?', reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, + "Hello! What is your first name?", + reply_to_message_id=message.message_id, + ) + # Cancel command handler -@bot.message_handler(state="*", commands=['cancel']) +@bot.message_handler(state="*", commands=["cancel"]) def any_state(message: types.Message, state: StateContext): state.delete() - bot.send_message(message.chat.id, 'Your information has been cleared. Type /start to begin again.', reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, + "Your information has been cleared. Type /start to begin again.", + reply_to_message_id=message.message_id, + ) + # Handler for name input @bot.message_handler(state=MyStates.name) def name_get(message: types.Message, state: StateContext): state.set(MyStates.age) - bot.send_message(message.chat.id, "How old are you?", reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, "How old are you?", reply_to_message_id=message.message_id + ) state.add_data(name=message.text) + # Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) def ask_color(message: types.Message, state: StateContext): @@ -47,7 +61,13 @@ def ask_color(message: types.Message, state: StateContext): buttons = [types.KeyboardButton(color) for color in colors] keyboard.add(*buttons) - bot.send_message(message.chat.id, "What is your favorite color? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, + "What is your favorite color? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + # Handler for color input @bot.message_handler(state=MyStates.color) @@ -61,15 +81,23 @@ def ask_hobby(message: types.Message, state: StateContext): buttons = [types.KeyboardButton(hobby) for hobby in hobbies] keyboard.add(*buttons) - bot.send_message(message.chat.id, "What is one of your hobbies? Choose from the options below.", reply_markup=keyboard, reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, + "What is one of your hobbies? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + # Handler for hobby input -@bot.message_handler(state=MyStates.hobby, text=['Reading', 'Traveling', 'Gaming', 'Cooking']) +@bot.message_handler( + state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"] +) def finish(message: types.Message, state: StateContext): with state.data() as data: - name = data.get('name') - age = data.get('age') - color = data.get('color') + name = data.get("name") + age = data.get("age") + color = data.get("color") hobby = message.text # Get the hobby from the message text # Provide a fun fact based on color @@ -80,24 +108,36 @@ def finish(message: types.Message, state: StateContext): "Yellow": "Yellow is a cheerful color often associated with happiness.", "Purple": "Purple signifies royalty and luxury.", "Orange": "Orange is a vibrant color that stimulates enthusiasm.", - "Other": "Colors have various meanings depending on context." + "Other": "Colors have various meanings depending on context.", } - color_fact = color_facts.get(color, "Colors have diverse meanings, and yours is unique!") - - msg = (f"Thank you for sharing! Here is a summary of your information:\n" - f"First Name: {name}\n" - f"Age: {age}\n" - f"Favorite Color: {color}\n" - f"Fun Fact about your color: {color_fact}\n" - f"Favorite Hobby: {hobby}") - - bot.send_message(message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id) + color_fact = color_facts.get( + color, "Colors have diverse meanings, and yours is unique!" + ) + + msg = ( + f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}" + ) + + bot.send_message( + message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id + ) state.delete() + # Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) def age_incorrect(message: types.Message): - bot.send_message(message.chat.id, 'Please enter a valid number for age.', reply_to_message_id=message.message_id) + bot.send_message( + message.chat.id, + "Please enter a valid number for age.", + reply_to_message_id=message.message_id, + ) + # Add custom filters bot.add_custom_filter(custom_filters.StateFilter(bot)) @@ -106,6 +146,7 @@ def age_incorrect(message: types.Message): # necessary for state parameter in handlers. from telebot.states.sync.middleware import StateMiddleware + bot.setup_middleware(StateMiddleware(bot)) # Start polling diff --git a/telebot/states/aio/__init__.py b/telebot/states/asyncio/__init__.py similarity index 100% rename from telebot/states/aio/__init__.py rename to telebot/states/asyncio/__init__.py diff --git a/telebot/states/aio/context.py b/telebot/states/asyncio/context.py similarity index 100% rename from telebot/states/aio/context.py rename to telebot/states/asyncio/context.py diff --git a/telebot/states/aio/middleware.py b/telebot/states/asyncio/middleware.py similarity index 100% rename from telebot/states/aio/middleware.py rename to telebot/states/asyncio/middleware.py From 2dbf19004c74a7f8e73f37a460484fdec023ccb3 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 28 Jul 2024 14:26:58 +0500 Subject: [PATCH 20/24] Remove apihelper following the validate_token not using getMe --- telebot/async_telebot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index d60a4a090..705012f77 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -18,7 +18,7 @@ from inspect import signature, iscoroutinefunction -from telebot import util, types, asyncio_helper, apihelper # have to use sync +from telebot import util, types, asyncio_helper import asyncio from telebot import asyncio_filters From b10e8d749440cc596fdba05b1c5d1a33e4b86196 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 28 Jul 2024 14:48:12 +0500 Subject: [PATCH 21/24] Reverted changes regarding self._user, fixed validate_token=False causing error when extracting a bot id --- telebot/__init__.py | 6 +----- telebot/async_telebot.py | 6 ++---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index d0ab64d5a..19f85f2d2 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -195,8 +195,7 @@ def __init__( if validate_token: util.validate_token(self.token) - - self.bot_id = util.extract_bot_id(self.token) # subject to change in future, unspecified + self.bot_id = util.extract_bot_id(self.token) # subject to change in future, unspecified # logs-related if colorful_logs: @@ -1184,9 +1183,6 @@ def polling(self, non_stop: Optional[bool]=False, skip_pending: Optional[bool]=F if restart_on_change: self._setup_change_detector(path_to_watch) - if not self._user: - self._user = self.get_me() - logger.info('Starting your bot with username: [@%s]', self.user.username) if self.threaded: diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 705012f77..5302000ff 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -193,8 +193,7 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ if validate_token: util.validate_token(self.token) - - self.bot_id: int = util.extract_bot_id(self.token) # subject to change, unspecified + self.bot_id: int = util.extract_bot_id(self.token) # subject to change, unspecified @property @@ -437,8 +436,7 @@ async def _process_polling(self, non_stop: bool=False, interval: int=0, timeout: # show warning logger.warning("Setting non_stop to False will stop polling on API and system exceptions.") - if not self._user: - self._user = await self.get_me() + self._user = await self.get_me() logger.info('Starting your bot with username: [@%s]', self.user.username) From 6108e358134177d07d90f5fe5907217140c51a40 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 28 Jul 2024 14:49:00 +0500 Subject: [PATCH 22/24] fix docstring --- telebot/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 19f85f2d2..74c08a4b8 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -6699,7 +6699,7 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ - Reset data for a user in chat: sets the 'data' fieldi to an empty dictionary. + Reset data for a user in chat: sets the 'data' field to an empty dictionary. :param user_id: User's identifier :type user_id: :obj:`int` From 30ebe756ac98010380a83114d84525ae968f0086 Mon Sep 17 00:00:00 2001 From: _run Date: Tue, 30 Jul 2024 15:48:26 +0500 Subject: [PATCH 23/24] make extract_bot_id return None in case validation fails --- telebot/__init__.py | 4 ++-- telebot/async_telebot.py | 4 ++-- telebot/util.py | 6 +++++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index 74c08a4b8..fc7f3cdc3 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -191,11 +191,11 @@ def __init__( self.allow_sending_without_reply = allow_sending_without_reply self.webhook_listener = None self._user = None - self.bot_id: int = None if validate_token: util.validate_token(self.token) - self.bot_id = util.extract_bot_id(self.token) # subject to change in future, unspecified + + self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change in future, unspecified # logs-related if colorful_logs: diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 5302000ff..ec782ff91 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -189,11 +189,11 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ self.middlewares = [] self._user = None # set during polling - self.bot_id: int = None if validate_token: util.validate_token(self.token) - self.bot_id: int = util.extract_bot_id(self.token) # subject to change, unspecified + + self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change, unspecified @property diff --git a/telebot/util.py b/telebot/util.py index c8ef526c5..0448893e1 100644 --- a/telebot/util.py +++ b/telebot/util.py @@ -698,7 +698,11 @@ def validate_token(token) -> bool: return True -def extract_bot_id(token) -> str: +def extract_bot_id(token) -> Union[int, None]: + try: + validate_token(token) + except ValueError: + return None return int(token.split(':')[0]) From d44ebce95916f845b0493fdf395cd3f8b70ee1d0 Mon Sep 17 00:00:00 2001 From: _run Date: Sun, 11 Aug 2024 16:57:42 +0500 Subject: [PATCH 24/24] Fix versions --- telebot/__init__.py | 2 +- telebot/async_telebot.py | 2 +- telebot/asyncio_storage/redis_storage.py | 2 +- telebot/storage/redis_storage.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/telebot/__init__.py b/telebot/__init__.py index fc7f3cdc3..79db7c8d9 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -6657,7 +6657,7 @@ def set_state(self, user_id: int, state: Union[str, State], chat_id: Optional[in Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. - .. versionchanged:: 4.22.0 + .. versionchanged:: 4.23.0 Added additional parameters to support topics, business connections, and message threads. diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index ec782ff91..92bf2e127 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -7856,7 +7856,7 @@ async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. - .. versionchanged:: 4.22.0 + .. versionchanged:: 4.23.0 Added additional parameters to support topics, business connections, and message threads. diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index eccff5c30..08e255809 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -297,7 +297,7 @@ def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): Migrate from old to new format of keys. Run this function once to migrate all redis existing keys to new format. - Starting from version 4.22.0, the format of keys has been changed: + Starting from version 4.23.0, the format of keys has been changed: :value - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} - New format: diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 19f5c2cd5..c9935ac5e 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -283,7 +283,7 @@ def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): Migrate from old to new format of keys. Run this function once to migrate all redis existing keys to new format. - Starting from version 4.22.0, the format of keys has been changed: + Starting from version 4.23.0, the format of keys has been changed: :value - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} - New format: