diff --git a/examples/asynchronous_telebot/custom_states.py b/examples/asynchronous_telebot/custom_states.py index aefad9809..f0111e27e 100644 --- a/examples/asynchronous_telebot/custom_states.py +++ b/examples/asynchronous_telebot/custom_states.py @@ -1,91 +1,154 @@ -from telebot import asyncio_filters -from telebot.async_telebot import AsyncTeleBot - -# list of storages, you can use any storage +from telebot import async_telebot, asyncio_filters, types from telebot.asyncio_storage import StateMemoryStorage +from telebot.states import State, StatesGroup +from telebot.states.asyncio.context import StateContext -# new feature for states. -from telebot.asyncio_handler_backends import State, StatesGroup - -# default state storage is statememorystorage -bot = AsyncTeleBot('TOKEN', state_storage=StateMemoryStorage()) +# Initialize the bot +state_storage = StateMemoryStorage() # don't use this in production; switch to redis +bot = async_telebot.AsyncTeleBot("TOKEN", state_storage=state_storage) -# Just create different statesgroup +# Define states class MyStates(StatesGroup): - name = State() # statesgroup should contain states - surname = State() + name = State() age = State() + color = State() + hobby = State() +# Start command handler +@bot.message_handler(commands=["start"]) +async def start_ex(message: types.Message, state: StateContext): + await state.set(MyStates.name) + await bot.send_message( + message.chat.id, + "Hello! What is your first name?", + reply_to_message_id=message.message_id, + ) -# set_state -> sets a new state -# delete_state -> delets state if exists -# get_state -> returns state if exists +# Cancel command handler +@bot.message_handler(state="*", commands=["cancel"]) +async def any_state(message: types.Message, state: StateContext): + await state.delete() + await bot.send_message( + message.chat.id, + "Your information has been cleared. Type /start to begin again.", + reply_to_message_id=message.message_id, + ) -@bot.message_handler(commands=['start']) -async def start_ex(message): - """ - Start command. Here we are starting state - """ - await bot.set_state(message.from_user.id, MyStates.name, message.chat.id) - await bot.send_message(message.chat.id, 'Hi, write me a name') - - - -@bot.message_handler(state="*", commands='cancel') -async def any_state(message): - """ - Cancel state - """ - await bot.send_message(message.chat.id, "Your state was cancelled.") - await bot.delete_state(message.from_user.id, message.chat.id) +# Handler for name input @bot.message_handler(state=MyStates.name) -async def name_get(message): - """ - State 1. Will process when user's state is MyStates.name. - """ - await bot.send_message(message.chat.id, f'Now write me a surname') - await bot.set_state(message.from_user.id, MyStates.surname, message.chat.id) - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['name'] = message.text - - -@bot.message_handler(state=MyStates.surname) -async def ask_age(message): - """ - State 2. Will process when user's state is MyStates.surname. - """ - await bot.send_message(message.chat.id, "What is your age?") - await bot.set_state(message.from_user.id, MyStates.age, message.chat.id) - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['surname'] = message.text - -# result +async def name_get(message: types.Message, state: StateContext): + await state.set(MyStates.age) + await bot.send_message( + message.chat.id, "How old are you?", reply_to_message_id=message.message_id + ) + await state.add_data(name=message.text) + + +# Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) -async def ready_for_answer(message): - """ - State 3. Will process when user's state is MyStates.age. - """ - async with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - await bot.send_message(message.chat.id, "Ready, take a look:\nName: {name}\nSurname: {surname}\nAge: {age}".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html") - await bot.delete_state(message.from_user.id, message.chat.id) - -#incorrect number +async def ask_color(message: types.Message, state: StateContext): + await state.set(MyStates.color) + await state.add_data(age=message.text) + + # Define reply keyboard for color selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"] + buttons = [types.KeyboardButton(color) for color in colors] + keyboard.add(*buttons) + + await bot.send_message( + message.chat.id, + "What is your favorite color? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + + +# Handler for color input +@bot.message_handler(state=MyStates.color) +async def ask_hobby(message: types.Message, state: StateContext): + await state.set(MyStates.hobby) + await state.add_data(color=message.text) + + # Define reply keyboard for hobby selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + hobbies = ["Reading", "Traveling", "Gaming", "Cooking"] + buttons = [types.KeyboardButton(hobby) for hobby in hobbies] + keyboard.add(*buttons) + + await bot.send_message( + message.chat.id, + "What is one of your hobbies? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + + +# Handler for hobby input; use filters to ease validation +@bot.message_handler( + state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"] +) +async def finish(message: types.Message, state: StateContext): + async with state.data() as data: + name = data.get("name") + age = data.get("age") + color = data.get("color") + hobby = message.text # Get the hobby from the message text + + # Provide a fun fact based on color + color_facts = { + "Red": "Red is often associated with excitement and passion.", + "Green": "Green is the color of nature and tranquility.", + "Blue": "Blue is known for its calming and serene effects.", + "Yellow": "Yellow is a cheerful color often associated with happiness.", + "Purple": "Purple signifies royalty and luxury.", + "Orange": "Orange is a vibrant color that stimulates enthusiasm.", + "Other": "Colors have various meanings depending on context.", + } + color_fact = color_facts.get( + color, "Colors have diverse meanings, and yours is unique!" + ) + + msg = ( + f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}" + ) + + await bot.send_message( + message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id + ) + await state.delete() + + +# Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) -async def age_incorrect(message): - """ - Will process for wrong input when state is MyState.age - """ - await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') +async def age_incorrect(message: types.Message): + await bot.send_message( + message.chat.id, + "Please enter a valid number for age.", + reply_to_message_id=message.message_id, + ) -# register filters +# Add custom filters bot.add_custom_filter(asyncio_filters.StateFilter(bot)) bot.add_custom_filter(asyncio_filters.IsDigitFilter()) +bot.add_custom_filter(asyncio_filters.TextMatchFilter()) + +# necessary for state parameter in handlers. +from telebot.states.asyncio.middleware import StateMiddleware +bot.setup_middleware(StateMiddleware(bot)) +# Start polling import asyncio -asyncio.run(bot.polling()) \ No newline at end of file + +asyncio.run(bot.polling()) diff --git a/examples/custom_states.py b/examples/custom_states.py index a02761286..131dd1d42 100644 --- a/examples/custom_states.py +++ b/examples/custom_states.py @@ -1,106 +1,153 @@ -import telebot # telebot - -from telebot import custom_filters -from telebot.handler_backends import State, StatesGroup #States - -# States storage +import telebot +from telebot import custom_filters, types +from telebot.states import State, StatesGroup +from telebot.states.sync.context import StateContext from telebot.storage import StateMemoryStorage +# Initialize the bot +state_storage = StateMemoryStorage() # don't use this in production; switch to redis +bot = telebot.TeleBot("TOKEN", state_storage=state_storage, use_class_middlewares=True) -# Starting from version 4.4.0+, we support storages. -# StateRedisStorage -> Redis-based storage. -# StatePickleStorage -> Pickle-based storage. -# For redis, you will need to install redis. -# Pass host, db, password, or anything else, -# if you need to change config for redis. -# Pickle requires path. Default path is in folder .state-saves. -# If you were using older version of pytba for pickle, -# you need to migrate from old pickle to new by using -# StatePickleStorage().convert_old_to_new() - - - -# Now, you can pass storage to bot. -state_storage = StateMemoryStorage() # you can init here another storage - -bot = telebot.TeleBot("TOKEN", -state_storage=state_storage) - -# States group. +# Define states class MyStates(StatesGroup): - # Just name variables differently - name = State() # creating instances of State class is enough from now - surname = State() + name = State() age = State() + color = State() + hobby = State() +# Start command handler +@bot.message_handler(commands=["start"]) +def start_ex(message: types.Message, state: StateContext): + state.set(MyStates.name) + bot.send_message( + message.chat.id, + "Hello! What is your first name?", + reply_to_message_id=message.message_id, + ) -@bot.message_handler(commands=['start']) -def start_ex(message): - """ - Start command. Here we are starting state - """ - bot.set_state(message.from_user.id, MyStates.name, message.chat.id) - bot.send_message(message.chat.id, 'Hi, write me a name') - +# Cancel command handler +@bot.message_handler(state="*", commands=["cancel"]) +def any_state(message: types.Message, state: StateContext): + state.delete() + bot.send_message( + message.chat.id, + "Your information has been cleared. Type /start to begin again.", + reply_to_message_id=message.message_id, + ) -# Any state -@bot.message_handler(state="*", commands=['cancel']) -def any_state(message): - """ - Cancel state - """ - bot.send_message(message.chat.id, "Your state was cancelled.") - bot.delete_state(message.from_user.id, message.chat.id) +# Handler for name input @bot.message_handler(state=MyStates.name) -def name_get(message): - """ - State 1. Will process when user's state is MyStates.name. - """ - bot.send_message(message.chat.id, 'Now write me a surname') - bot.set_state(message.from_user.id, MyStates.surname, message.chat.id) - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['name'] = message.text - - -@bot.message_handler(state=MyStates.surname) -def ask_age(message): - """ - State 2. Will process when user's state is MyStates.surname. - """ - bot.send_message(message.chat.id, "What is your age?") - bot.set_state(message.from_user.id, MyStates.age, message.chat.id) - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - data['surname'] = message.text - -# result +def name_get(message: types.Message, state: StateContext): + state.set(MyStates.age) + bot.send_message( + message.chat.id, "How old are you?", reply_to_message_id=message.message_id + ) + state.add_data(name=message.text) + + +# Handler for age input @bot.message_handler(state=MyStates.age, is_digit=True) -def ready_for_answer(message): - """ - State 3. Will process when user's state is MyStates.age. - """ - with bot.retrieve_data(message.from_user.id, message.chat.id) as data: - msg = ("Ready, take a look:\n" - f"Name: {data['name']}\n" - f"Surname: {data['surname']}\n" - f"Age: {message.text}") - bot.send_message(message.chat.id, msg, parse_mode="html") - bot.delete_state(message.from_user.id, message.chat.id) - -#incorrect number +def ask_color(message: types.Message, state: StateContext): + state.set(MyStates.color) + state.add_data(age=message.text) + + # Define reply keyboard for color selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"] + buttons = [types.KeyboardButton(color) for color in colors] + keyboard.add(*buttons) + + bot.send_message( + message.chat.id, + "What is your favorite color? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + + +# Handler for color input +@bot.message_handler(state=MyStates.color) +def ask_hobby(message: types.Message, state: StateContext): + state.set(MyStates.hobby) + state.add_data(color=message.text) + + # Define reply keyboard for hobby selection + keyboard = types.ReplyKeyboardMarkup(row_width=2) + hobbies = ["Reading", "Traveling", "Gaming", "Cooking"] + buttons = [types.KeyboardButton(hobby) for hobby in hobbies] + keyboard.add(*buttons) + + bot.send_message( + message.chat.id, + "What is one of your hobbies? Choose from the options below.", + reply_markup=keyboard, + reply_to_message_id=message.message_id, + ) + + +# Handler for hobby input +@bot.message_handler( + state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"] +) +def finish(message: types.Message, state: StateContext): + with state.data() as data: + name = data.get("name") + age = data.get("age") + color = data.get("color") + hobby = message.text # Get the hobby from the message text + + # Provide a fun fact based on color + color_facts = { + "Red": "Red is often associated with excitement and passion.", + "Green": "Green is the color of nature and tranquility.", + "Blue": "Blue is known for its calming and serene effects.", + "Yellow": "Yellow is a cheerful color often associated with happiness.", + "Purple": "Purple signifies royalty and luxury.", + "Orange": "Orange is a vibrant color that stimulates enthusiasm.", + "Other": "Colors have various meanings depending on context.", + } + color_fact = color_facts.get( + color, "Colors have diverse meanings, and yours is unique!" + ) + + msg = ( + f"Thank you for sharing! Here is a summary of your information:\n" + f"First Name: {name}\n" + f"Age: {age}\n" + f"Favorite Color: {color}\n" + f"Fun Fact about your color: {color_fact}\n" + f"Favorite Hobby: {hobby}" + ) + + bot.send_message( + message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id + ) + state.delete() + + +# Handler for incorrect age input @bot.message_handler(state=MyStates.age, is_digit=False) -def age_incorrect(message): - """ - Wrong response for MyStates.age - """ - bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number') +def age_incorrect(message: types.Message): + bot.send_message( + message.chat.id, + "Please enter a valid number for age.", + reply_to_message_id=message.message_id, + ) -# register filters +# Add custom filters bot.add_custom_filter(custom_filters.StateFilter(bot)) bot.add_custom_filter(custom_filters.IsDigitFilter()) +bot.add_custom_filter(custom_filters.TextMatchFilter()) + +# necessary for state parameter in handlers. +from telebot.states.sync.middleware import StateMiddleware + +bot.setup_middleware(StateMiddleware(bot)) -bot.infinity_polling(skip_pending=True) +# Start polling +bot.infinity_polling() diff --git a/pyproject.toml b/pyproject.toml index a45a8dd3b..9ef5d86f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ Issues = "https://github.com/eternnoir/pyTelegramBotAPI/issues" json = ["ujson"] PIL = ["Pillow"] redis = ["redis>=3.4.1"] -aioredis = ["aioredis"] aiohttp = ["aiohttp"] fastapi = ["fastapi"] uvicorn = ["uvicorn"] diff --git a/telebot/__init__.py b/telebot/__init__.py index 97c3ea23c..2c26b3b75 100644 --- a/telebot/__init__.py +++ b/telebot/__init__.py @@ -7,7 +7,7 @@ import threading import time import traceback -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union, Dict # these imports are used to avoid circular import error import telebot.util @@ -153,9 +153,14 @@ class TeleBot: :param allow_sending_without_reply: Default value for allow_sending_without_reply, defaults to None :type allow_sending_without_reply: :obj:`bool`, optional - :param colorful_logs: Outputs colorful logs :type colorful_logs: :obj:`bool`, optional + + :param validate_token: Validate token, defaults to True; + :type validate_token: :obj:`bool`, optional + + :raises ImportError: If coloredlogs module is not installed and colorful_logs is True + :raises ValueError: If token is invalid """ def __init__( @@ -169,7 +174,8 @@ def __init__( disable_notification: Optional[bool]=None, protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, - colorful_logs: Optional[bool]=False + colorful_logs: Optional[bool]=False, + validate_token: Optional[bool]=True ): # update-related @@ -187,6 +193,11 @@ def __init__( self.webhook_listener = None self._user = None + if validate_token: + util.validate_token(self.token) + + self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change in future, unspecified + # logs-related if colorful_logs: try: @@ -281,7 +292,7 @@ def __init__( self.threaded = threaded if self.threaded: self.worker_pool = util.ThreadPool(self, num_threads=num_threads) - + @property def user(self) -> types.User: """ @@ -6642,7 +6653,9 @@ def setup_middleware(self, middleware: BaseMiddleware): self.middlewares.append(middleware) - def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None) -> None: + def set_state(self, user_id: int, state: Union[str, State], chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: """ Sets a new state of a user. @@ -6652,25 +6665,49 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. + .. versionchanged:: 4.23.0 + + Added additional parameters to support topics, business connections, and message threads. + + .. seealso:: + + For more details, visit the `custom_states.py example `_. + :param user_id: User's identifier :type user_id: :obj:`int` - :param state: new state. can be string, integer, or :class:`telebot.types.State` + :param state: new state. can be string, or :class:`telebot.types.State` :type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id - self.current_states.set_state(chat_id, user_id, state) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.set_state( + chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def reset_data(self, user_id: int, chat_id: Optional[int]=None): + def reset_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ - Reset data for a user in chat. + Reset data for a user in chat: sets the 'data' field to an empty dictionary. :param user_id: User's identifier :type user_id: :obj:`int` @@ -6678,16 +6715,34 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id - self.current_states.reset_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None: + def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ - Delete the current state of a user. + Fully deletes the storage record of a user in chat. + + .. warning:: + + This does NOT set state to None, but deletes the object from storage. :param user_id: User's identifier :type user_id: :obj:`int` @@ -6695,14 +6750,19 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None: :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id - self.current_states.delete_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Any]: + def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]: """ Returns context manager with data for a user in chat. @@ -6712,34 +6772,70 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[A :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: int, optional + + :param business_connection_id: Business identifier + :type business_connection_id: str, optional + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: int, optional + :return: Context manager with data for a user in chat :rtype: Optional[Any] """ if chat_id is None: chat_id = user_id - return self.current_states.get_interactive_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id) - def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union[int, str, State]]: + def get_state(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. + .. warning:: + + Even if you are using :class:`telebot.types.State`, this method will return a string. + When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name + :param user_id: User's identifier :type user_id: :obj:`int` :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: state of a user :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State` """ if chat_id is None: chat_id = user_id - return self.current_states.get_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) - def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): + def add_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None, + **kwargs) -> None: """ Add data to states. @@ -6749,13 +6845,25 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :param kwargs: Data to add :return: None """ if chat_id is None: chat_id = user_id + if bot_id is None: + bot_id = self.bot_id for key, value in kwargs.items(): - self.current_states.set_data(chat_id, user_id, key, value) + self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) def register_next_step_handler_by_chat_id( diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py index 66ba987fb..4e92f50d3 100644 --- a/telebot/async_telebot.py +++ b/telebot/async_telebot.py @@ -4,7 +4,7 @@ import logging import re import traceback -from typing import Any, Awaitable, Callable, List, Optional, Union +from typing import Any, Awaitable, Callable, List, Optional, Union, Dict import sys # this imports are used to avoid circular import error @@ -117,6 +117,11 @@ class AsyncTeleBot: :param colorful_logs: Outputs colorful logs :type colorful_logs: :obj:`bool`, optional + :param validate_token: Validate token, defaults to True; + :type validate_token: :obj:`bool`, optional + + :raises ImportError: If coloredlogs module is not installed and colorful_logs is True + :raises ValueError: If token is invalid """ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[int]=None, @@ -126,7 +131,8 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ disable_notification: Optional[bool]=None, protect_content: Optional[bool]=None, allow_sending_without_reply: Optional[bool]=None, - colorful_logs: Optional[bool]=False) -> None: + colorful_logs: Optional[bool]=False, + validate_token: Optional[bool]=True) -> None: # update-related self.token = token @@ -184,6 +190,12 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[ self._user = None # set during polling + if validate_token: + util.validate_token(self.token) + + self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change, unspecified + + @property def user(self): return self._user @@ -7837,7 +7849,10 @@ async def get_forum_topic_icon_stickers(self) -> List[types.Sticker]: """ return await asyncio_helper.get_forum_topic_icon_stickers(self.token) - async def set_state(self, user_id: int, state: Union[State, int, str], chat_id: Optional[int]=None): + + async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None) -> bool: """ Sets a new state of a user. @@ -7847,22 +7862,47 @@ async def set_state(self, user_id: int, state: Union[State, int, str], chat_id: Otherwise, if you only set user_id, chat_id will equal to user_id, this means that state will be set for the user in his private chat with a bot. + .. versionchanged:: 4.23.0 + + Added additional parameters to support topics, business connections, and message threads. + + .. seealso:: + + For more details, visit the `custom_states.py example `_. + :param user_id: User's identifier :type user_id: :obj:`int` - :param state: new state. can be string, integer, or :class:`telebot.types.State` + :param state: new state. can be string, or :class:`telebot.types.State` :type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State` :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + + :return: True on success + :rtype: :obj:`bool` """ - if not chat_id: + if chat_id is None: chat_id = user_id - await self.current_states.set_state(chat_id, user_id, state) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.set_state( + chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - async def reset_data(self, user_id: int, chat_id: Optional[int]=None): + async def reset_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ Reset data for a user in chat. @@ -7872,13 +7912,28 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's identifier :type chat_id: :obj:`int` - :return: None + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + + :return: True on success + :rtype: :obj:`bool` """ if chat_id is None: chat_id = user_id - await self.current_states.reset_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - async def delete_state(self, user_id: int, chat_id: Optional[int]=None): + async def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool: """ Delete the current state of a user. @@ -7890,11 +7945,16 @@ async def delete_state(self, user_id: int, chat_id: Optional[int]=None): :return: None """ - if not chat_id: + if chat_id is None: chat_id = user_id - await self.current_states.delete_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - def retrieve_data(self, user_id: int, chat_id: Optional[int]=None): + def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]: """ Returns context manager with data for a user in chat. @@ -7904,32 +7964,70 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None): :param chat_id: Chat's unique identifier, defaults to user_id :type chat_id: int, optional + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: int, optional + + :param business_connection_id: Business identifier + :type business_connection_id: str, optional + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: int, optional + :return: Context manager with data for a user in chat - :rtype: Optional[Any] + :rtype: :obj:`dict` """ - if not chat_id: + if chat_id is None: chat_id = user_id - return self.current_states.get_interactive_data(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id) + - async def get_state(self, user_id, chat_id: Optional[int]=None): + async def get_state(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str: """ Gets current state of a user. Not recommended to use this method. But it is ok for debugging. + .. warning:: + + Even if you are using :class:`telebot.types.State`, this method will return a string. + When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name + :param user_id: User's identifier :type user_id: :obj:`int` :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :return: state of a user - :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State` + :rtype: :obj:`str` """ - if not chat_id: + if chat_id is None: chat_id = user_id - return await self.current_states.get_state(chat_id, user_id) + if bot_id is None: + bot_id = self.bot_id + return await self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) + - async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): + async def add_data(self, user_id: int, chat_id: Optional[int]=None, + business_connection_id: Optional[str]=None, + message_thread_id: Optional[int]=None, + bot_id: Optional[int]=None, + **kwargs) -> None: """ Add data to states. @@ -7939,10 +8037,22 @@ async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs): :param chat_id: Chat's identifier :type chat_id: :obj:`int` + :param bot_id: Bot's identifier, defaults to current bot id + :type bot_id: :obj:`int` + + :param business_connection_id: Business identifier + :type business_connection_id: :obj:`str` + + :param message_thread_id: Identifier of the message thread + :type message_thread_id: :obj:`int` + :param kwargs: Data to add :return: None """ - if not chat_id: + if chat_id is None: chat_id = user_id + if bot_id is None: + bot_id = self.bot_id for key, value in kwargs.items(): - await self.current_states.set_data(chat_id, user_id, key, value) + await self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id, + business_connection_id=business_connection_id, message_thread_id=message_thread_id) diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py index f4e594a7a..20abe5464 100644 --- a/telebot/asyncio_filters.py +++ b/telebot/asyncio_filters.py @@ -3,6 +3,7 @@ from telebot.asyncio_handler_backends import State from telebot import types +from telebot.states import resolve_context class SimpleCustomFilter(ABC): @@ -396,19 +397,11 @@ async def check(self, message, text): """ :meta private: """ - if text == '*': return True - - # needs to work with callbackquery - if isinstance(message, types.Message): - chat_id = message.chat.id - user_id = message.from_user.id - - if isinstance(message, types.CallbackQuery): - - chat_id = message.message.chat.id - user_id = message.from_user.id - message = message.message + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) + if chat_id is None: + chat_id = user_id # May change in future if isinstance(text, list): new_text = [] @@ -418,21 +411,24 @@ async def check(self, message, text): text = new_text elif isinstance(text, State): text = text.name - - if message.chat.type in ['group', 'supergroup']: - group_state = await self.bot.current_states.get_state(chat_id, user_id) - if group_state == text: - return True - elif type(text) is list and group_state in text: - return True - - - else: - user_state = await self.bot.current_states.get_state(chat_id, user_id) - if user_state == text: - return True - elif type(text) is list and user_state in text: - return True + + user_state = await self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + # CHANGED BEHAVIOUR + if text == "*" and user_state is not None: + return True + + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True + return False class IsDigitFilter(SimpleCustomFilter): diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py index 0861a9893..6c96cc2d7 100644 --- a/telebot/asyncio_handler_backends.py +++ b/telebot/asyncio_handler_backends.py @@ -1,6 +1,7 @@ """ File with all middleware classes, states. """ +from telebot.states import State, StatesGroup class BaseMiddleware: @@ -48,46 +49,6 @@ async def post_process(self, message, data, exception): raise NotImplementedError -class State: - """ - Class representing a state. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init__(self) -> None: - self.name = None - - def __str__(self) -> str: - return self.name - - -class StatesGroup: - """ - Class representing common states. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init_subclass__(cls) -> None: - state_list = [] - for name, value in cls.__dict__.items(): - if not name.startswith('__') and not callable(value) and isinstance(value, State): - # change value of that variable - value.name = ':'.join((cls.__name__, name)) - value.group = cls - state_list.append(value) - cls._state_list = state_list - - @classmethod - def state_list(self): - return self._state_list - - class SkipHandler: """ Class for skipping handlers. diff --git a/telebot/asyncio_storage/__init__.py b/telebot/asyncio_storage/__init__.py index 892f0af94..82a8c817a 100644 --- a/telebot/asyncio_storage/__init__.py +++ b/telebot/asyncio_storage/__init__.py @@ -1,13 +1,13 @@ from telebot.asyncio_storage.memory_storage import StateMemoryStorage from telebot.asyncio_storage.redis_storage import StateRedisStorage from telebot.asyncio_storage.pickle_storage import StatePickleStorage -from telebot.asyncio_storage.base_storage import StateContext,StateStorageBase - - - +from telebot.asyncio_storage.base_storage import StateDataContext, StateStorageBase __all__ = [ - 'StateStorageBase', 'StateContext', - 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' -] \ No newline at end of file + "StateStorageBase", + "StateDataContext", + "StateMemoryStorage", + "StateRedisStorage", + "StatePickleStorage", +] diff --git a/telebot/asyncio_storage/base_storage.py b/telebot/asyncio_storage/base_storage.py index 38615c4c2..20d0b0587 100644 --- a/telebot/asyncio_storage/base_storage.py +++ b/telebot/asyncio_storage/base_storage.py @@ -16,53 +16,108 @@ async def get_data(self, chat_id, user_id): Get data for a user in a particular chat. """ raise NotImplementedError - + async def set_state(self, chat_id, user_id, state): """ Set state for a particular user. - ! Note that you should create a - record if it does not exist, and + ! Note that you should create a + record if it does not exist, and if a record with state already exists, you need to update a record. """ raise NotImplementedError - + async def delete_state(self, chat_id, user_id): """ Delete state for a particular user. """ raise NotImplementedError - + async def reset_data(self, chat_id, user_id): """ Reset data for a particular user in a chat. """ raise NotImplementedError - + async def get_state(self, chat_id, user_id): raise NotImplementedError - + + def get_interactive_data(self, chat_id, user_id): + """ + Should be sync, but should provide a context manager + with __aenter__ and __aexit__ methods. + """ + raise NotImplementedError + async def save(self, chat_id, user_id, data): raise NotImplementedError + def _get_key( + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str = None, + message_thread_id: int = None, + bot_id: int = None, + ) -> str: + """ + Convert parameters to a key. + """ + params = [prefix] + if bot_id: + params.append(str(bot_id)) + if business_connection_id: + params.append(business_connection_id) + if message_thread_id: + params.append(str(message_thread_id)) + params.append(str(chat_id)) + params.append(str(user_id)) -class StateContext: + return separator.join(params) + + +class StateDataContext: """ Class for data. """ - def __init__(self, obj, chat_id, user_id): + def __init__( + self, + obj, + chat_id, + user_id, + business_connection_id=None, + message_thread_id=None, + bot_id=None, + ): self.obj = obj self.data = None self.chat_id = chat_id self.user_id = user_id - - + self.bot_id = bot_id + self.business_connection_id = business_connection_id + self.message_thread_id = message_thread_id async def __aenter__(self): - self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id)) + data = await self.obj.get_data( + chat_id=self.chat_id, + user_id=self.user_id, + business_connection_id=self.business_connection_id, + message_thread_id=self.message_thread_id, + bot_id=self.bot_id, + ) + self.data = copy.deepcopy(data) return self.data async def __aexit__(self, exc_type, exc_val, exc_tb): - return await self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file + return await self.obj.save( + self.chat_id, + self.user_id, + self.data, + self.business_connection_id, + self.message_thread_id, + self.bot_id, + ) diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py index 45c2ad914..e17d5fe8c 100644 --- a/telebot/asyncio_storage/memory_storage.py +++ b/telebot/asyncio_storage/memory_storage.py @@ -1,66 +1,225 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext +from typing import Optional, Union + class StateMemoryStorage(StateStorageBase): - def __init__(self) -> None: - self.data = {} - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - - async def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): + """ + Memory storage for states. + + Stores states in memory as a dictionary. + + .. code-block:: python3 + + storage = StateMemoryStorage() + bot = AsyncTeleBot(token, storage=storage) + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + """ + + def __init__( + self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot" + ) -> None: + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + + self.data = ( + {} + ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + + async def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + if hasattr(state, "name"): state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + self.data[_key] = {"state": state, "data": {}} + else: + self.data[_key]["state"] = state + + return True + + async def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return None + + return self.data[_key]["state"] + + async def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + + del self.data[_key] + return True + + async def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + raise RuntimeError(f"MemoryStorage: key {_key} does not exist.") + self.data[_key]["data"][key] = value + return True + + async def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + return self.data.get(_key, {}).get("data", {}) + + async def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = {} return True - - async def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - - return True - - return False - - - async def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - async def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - async def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - return True - return False - - async def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) - - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) - - async def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data \ No newline at end of file + + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) + + async def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = data + return True + + def __str__(self) -> str: + return f"" diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py index cf446d85b..723672034 100644 --- a/telebot/asyncio_storage/pickle_storage.py +++ b/telebot/asyncio_storage/pickle_storage.py @@ -1,28 +1,74 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext -import os +aiofiles_installed = True +try: + import aiofiles +except ImportError: + aiofiles_installed = False +import os import pickle +import asyncio +from typing import Optional, Union, Callable, Any + +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + + +def with_lock(func: Callable) -> Callable: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + + return wrapper class StatePickleStorage(StateStorageBase): - def __init__(self, file_path="./.state-save/states.pkl") -> None: + """ + State storage based on pickle file. + + .. warning:: + + This storage is not recommended for production use. + Data may be corrupted. If you face a case where states do not work as expected, + try to use another storage. + + .. code-block:: python3 + + storage = StatePickleStorage() + bot = AsyncTeleBot(token, storage=storage) + + :param file_path: Path to file where states will be stored. + :type file_path: str + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ + + def __init__( + self, + file_path: str = "./.state-save/states.pkl", + prefix="telebot", + separator: Optional[str] = ":", + ) -> None: + + if not aiofiles_installed: + raise ImportError("Please install aiofiles using `pip install aiofiles`") + self.file_path = file_path + self.prefix = prefix + self.separator = separator + self.lock = asyncio.Lock() self.create_dir() - self.data = self.read() - - async def convert_old_to_new(self): - # old looks like: - # {1: {'state': 'start', 'data': {'name': 'John'}} - # we should update old version pickle to new. - # new looks like: - # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} - new_data = {} - for key, value in self.data.items(): - # this returns us id and dict with data and state - new_data[key] = {key: value} # convert this to new - # pass it to global data - self.data = new_data - self.update_data() # update data in file + + async def _read_from_file(self) -> dict: + async with aiofiles.open(self.file_path, "rb") as f: + data = await f.read() + return pickle.loads(data) + + async def _write_to_file(self, data: dict) -> None: + async with aiofiles.open(self.file_path, "wb") as f: + await f.write(pickle.dumps(data)) def create_dir(self): """ @@ -31,80 +77,198 @@ def create_dir(self): dirs, filename = os.path.split(self.file_path) os.makedirs(dirs, exist_ok=True) if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: + with open(self.file_path, "wb") as file: pickle.dump({}, file) - def read(self): - file = open(self.file_path, 'rb') - data = pickle.load(file) - file.close() - return data - - def update_data(self): - file = open(self.file_path, 'wb+') - pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) - file.close() - - async def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): - state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - self.update_data() - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - self.update_data() - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} - self.update_data() + @with_lock + async def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + await self._write_to_file(data) return True - - async def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - self.update_data() - return True + @with_lock + async def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + return data.get(_key, {}).get("state") + + @with_lock + async def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + if _key in data: + del data[_key] + await self._write_to_file(data) + return True return False - - async def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - async def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - async def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - self.update_data() - return True + @with_lock + async def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + if _key not in data: + raise RuntimeError(f"StatePickleStorage: key {_key} does not exist.") + else: + data[_key]["data"][key] = value + await self._write_to_file(data) + return True + + @with_lock + async def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + return data.get(_key, {}).get("data", {}) + + @with_lock + async def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + if _key in data: + data[_key]["data"] = {} + await self._write_to_file(data) + return True return False - async def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - self.update_data() - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + @with_lock + async def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self._read_from_file() + data[_key]["data"] = data + await self._write_to_file(data) + return True - async def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data - self.update_data() \ No newline at end of file + def __str__(self) -> str: + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py index 84db253e5..08e255809 100644 --- a/telebot/asyncio_storage/redis_storage.py +++ b/telebot/asyncio_storage/redis_storage.py @@ -1,179 +1,339 @@ -from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext -import json - redis_installed = True -is_actual_aioredis = False try: - import aioredis - is_actual_aioredis = True + import redis + from redis.asyncio import Redis, ConnectionPool except ImportError: - try: - from redis import asyncio as aioredis - except ImportError: - redis_installed = False + redis_installed = False + +import json +from typing import Optional, Union, Callable, Coroutine +import asyncio + +from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext + + +def async_with_lock(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.lock: + return await func(self, *args, **kwargs) + + return wrapper + + +def async_with_pipeline(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]: + async def wrapper(self, *args, **kwargs): + async with self.redis.pipeline() as pipe: + pipe.multi() + result = await func(self, pipe, *args, **kwargs) + await pipe.execute() + return result + + return wrapper class StateRedisStorage(StateStorageBase): """ - This class is for Redis storage. - This will work only for states. - To use it, just pass this class to: - TeleBot(storage=StateRedisStorage()) + State storage based on Redis. + + .. code-block:: python3 + + storage = StateRedisStorage(...) + bot = AsyncTeleBot(token, storage=storage) + + :param host: Redis host, default is "localhost". + :type host: str + + :param port: Redis port, default is 6379. + :type port: int + + :param db: Redis database, default is 0. + :type db: int + + :param password: Redis password, default is None. + :type password: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param redis_url: Redis URL, default is None. + :type redis_url: Optional[str] + + :param connection_pool: Redis connection pool, default is None. + :type connection_pool: Optional[ConnectionPool] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ - def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None): + + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + prefix="telebot", + redis_url=None, + connection_pool: "ConnectionPool" = None, + separator: Optional[str] = ":", + ) -> None: + if not redis_installed: - raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"') + raise ImportError("Please install redis using `pip install redis`") + + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") - if is_actual_aioredis: - aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0])) - if aioredis_version < (2,): - raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0') if redis_url: - self.redis = aioredis.Redis.from_url(redis_url) + self.redis = redis.asyncio.from_url(redis_url) + elif connection_pool: + self.redis = Redis(connection_pool=connection_pool) else: - self.redis = aioredis.Redis(host=host, port=port, db=db, password=password) + self.redis = Redis(host=host, port=port, db=db, password=password) - self.prefix = prefix - #self.con = Redis(connection_pool=self.redis) -> use this when necessary - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - async def get_record(self, key): - """ - Function to get record from database. - It has nothing to do with states. - Made for backward compatibility - """ - result = await self.redis.get(self.prefix+str(key)) - if result: return json.loads(result) - return + self.lock = asyncio.Lock() + + @async_with_lock + @async_with_pipeline + async def set_state( + self, + pipe, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + if hasattr(state, "name"): + state = state.name + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + pipe.hget(_key, "data") + result = await pipe.execute() + data = result[0] + if data is None: + pipe.hset(_key, "data", json.dumps({})) + + await pipe.hset(_key, "state", state) - async def set_record(self, key, value): - """ - Function to set record to database. - It has nothing to do with states. - Made for backward compatibility - """ - - await self.redis.set(self.prefix+str(key), json.dumps(value)) return True - async def delete_record(self, key): - """ - Function to delete record from database. - It has nothing to do with states. - Made for backward compatibility - """ - await self.redis.delete(self.prefix+str(key)) + async def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + state_bytes = await self.redis.hget(_key, "state") + return state_bytes.decode("utf-8") if state_bytes else None + + async def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + result = await self.redis.delete(_key) + return result > 0 + + @async_with_lock + @async_with_pipeline + async def set_data( + self, + pipe, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await pipe.hget(_key, "data") + data = await pipe.execute() + data = data[0] + if data is None: + raise RuntimeError(f"StateRedisStorage: key {_key} does not exist.") + else: + data = json.loads(data) + data[key] = value + await pipe.hset(_key, "data", json.dumps(data)) return True - async def set_state(self, chat_id, user_id, state): - """ - Set state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if hasattr(state, 'name'): - state = state.name - if response: - if user_id in response: - response[user_id]['state'] = state - else: - response[user_id] = {'state': state, 'data': {}} + async def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = await self.redis.hget(_key, "data") + return json.loads(data) if data else {} + + @async_with_lock + @async_with_pipeline + async def reset_data( + self, + pipe, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + if await pipe.exists(_key): + await pipe.hset(_key, "data", "{}") else: - response = {user_id: {'state': state, 'data': {}}} - await self.set_record(chat_id, response) + return False + return True + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) + + @async_with_lock + @async_with_pipeline + async def save( + self, + pipe, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + if await pipe.exists(_key): + await pipe.hset(_key, "data", json.dumps(data)) + else: + return False return True - - async def delete_state(self, chat_id, user_id): - """ - Delete state for a particular user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - del response[user_id] - if user_id == str(chat_id): - await self.delete_record(chat_id) - return True - else: await self.set_record(chat_id, response) - return True - return False - - async def get_value(self, chat_id, user_id, key): - """ - Get value for a data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - if key in response[user_id]['data']: - return response[user_id]['data'][key] - return None - - async def get_state(self, chat_id, user_id): - """ - Get state of a user in a chat. + + def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['state'] + Migrate from old to new format of keys. + Run this function once to migrate all redis existing keys to new format. - return None + Starting from version 4.23.0, the format of keys has been changed: + :value + - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + - New format: + {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}} - async def get_data(self, chat_id, user_id): - """ - Get data of particular user in a particular chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['data'] - return None - - async def reset_data(self, chat_id, user_id): - """ - Reset data of a user in a chat. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = {} - await self.set_record(chat_id, response) - return True - - async def set_data(self, chat_id, user_id, key, value): - """ - Set data without interactive data. - """ - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'][key] = value - await self.set_record(chat_id, response) - return True - return False - - def get_interactive_data(self, chat_id, user_id): - """ - Get Data in interactive way. - You can use with() with this function. + This function will help you to migrate from the old format to the new one in order to avoid data loss. + + :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response. + :type bot_id: int + + :param prefix: Prefix for keys, default is "telebot_"(old default value) + :type prefix: Optional[str] """ - return StateContext(self, chat_id, user_id) - - async def save(self, chat_id, user_id, data): - response = await self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = data - await self.set_record(chat_id, response) - return True + keys = self.redis.keys(f"{prefix}*") + + for key in keys: + old_key = key.decode("utf-8") + # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + value = self.redis.get(old_key) + value = json.loads(value) + + chat_id = old_key[len(prefix) :] + user_id = list(value.keys())[0] + state = value[user_id]["state"] + state_data = value[user_id]["data"] + + # set new format + new_key = self._get_key( + int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id + ) + self.redis.hset(new_key, "state", state) + self.redis.hset(new_key, "data", json.dumps(state_data)) + + # delete old key + self.redis.delete(old_key) + + def __str__(self) -> str: + # include some connection info + return f"StateRedisStorage({self.redis})" diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py index a1c0aba14..2e91e55ea 100644 --- a/telebot/custom_filters.py +++ b/telebot/custom_filters.py @@ -4,6 +4,7 @@ from telebot import types +from telebot.states import resolve_context class SimpleCustomFilter(ABC): """ @@ -402,21 +403,11 @@ def check(self, message, text): """ :meta private: """ - if text == '*': return True - # needs to work with callbackquery - if isinstance(message, types.Message): - chat_id = message.chat.id - user_id = message.from_user.id + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id) - if isinstance(message, types.CallbackQuery): - - chat_id = message.message.chat.id - user_id = message.from_user.id - message = message.message - - - + if chat_id is None: + chat_id = user_id # May change in future if isinstance(text, list): new_text = [] @@ -427,21 +418,23 @@ def check(self, message, text): elif isinstance(text, State): text = text.name - if message.chat.type in ['group', 'supergroup']: - group_state = self.bot.current_states.get_state(chat_id, user_id) - if group_state == text: - return True - elif isinstance(text, list) and group_state in text: - return True - - - else: - user_state = self.bot.current_states.get_state(chat_id, user_id) - if user_state == text: - return True - elif isinstance(text, list) and user_state in text: - return True - + user_state = self.bot.current_states.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + # CHANGED BEHAVIOUR + if text == "*" and user_state is not None: + return True + + if user_state == text: + return True + elif type(text) is list and user_state in text: + return True + return False class IsDigitFilter(SimpleCustomFilter): """ diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py index b95861e0b..12f8c89bf 100644 --- a/telebot/handler_backends.py +++ b/telebot/handler_backends.py @@ -9,6 +9,8 @@ except: redis_installed = False +# backward compatibility +from telebot.states import State, StatesGroup class HandlerBackend(object): """ @@ -160,45 +162,6 @@ def get_handlers(self, handler_group_id): return handlers -class State: - """ - Class representing a state. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init__(self) -> None: - self.name = None - def __str__(self) -> str: - return self.name - - -class StatesGroup: - """ - Class representing common states. - - .. code-block:: python3 - - class MyStates(StatesGroup): - my_state = State() # returns my_state:State string. - """ - def __init_subclass__(cls) -> None: - state_list = [] - for name, value in cls.__dict__.items(): - if not name.startswith('__') and not callable(value) and isinstance(value, State): - # change value of that variable - value.name = ':'.join((cls.__name__, name)) - value.group = cls - state_list.append(value) - cls._state_list = state_list - - @classmethod - def state_list(self): - return self._state_list - - class BaseMiddleware: """ Base class for middleware. @@ -292,4 +255,4 @@ def start2(message): """ def __init__(self) -> None: - pass + pass \ No newline at end of file diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py new file mode 100644 index 000000000..2491608ee --- /dev/null +++ b/telebot/states/__init__.py @@ -0,0 +1,128 @@ +""" +Contains classes for states and state groups. +""" + +from telebot import types + + +class State: + """ + Class representing a state. + + .. code-block:: python3 + + class MyStates(StatesGroup): + my_state = State() # returns my_state:State string. + """ + + def __init__(self) -> None: + self.name: str = None + self.group: StatesGroup = None + + def __str__(self) -> str: + return f"<{self.name}>" + + +class StatesGroup: + """ + Class representing common states. + + .. code-block:: python3 + + class MyStates(StatesGroup): + my_state = State() # returns my_state:State string. + """ + + def __init_subclass__(cls) -> None: + state_list = [] + for name, value in cls.__dict__.items(): + if ( + not name.startswith("__") + and not callable(value) + and isinstance(value, State) + ): + # change value of that variable + value.name = ":".join((cls.__name__, name)) + value.group = cls + state_list.append(value) + cls._state_list = state_list + + @classmethod + def state_list(self): + return self._state_list + + +def resolve_context(message, bot_id: int) -> tuple: + # chat_id, user_id, business_connection_id, bot_id, message_thread_id + + # message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message + if isinstance(message, types.Message): + return ( + message.chat.id, + message.from_user.id, + message.business_connection_id, + bot_id, + message.message_thread_id if message.is_topic_message else None, + ) + elif isinstance(message, types.CallbackQuery): # callback_query + return ( + message.message.chat.id, + message.from_user.id, + message.message.business_connection_id, + bot_id, + ( + message.message.message_thread_id + if message.message.is_topic_message + else None + ), + ) + elif isinstance(message, types.BusinessConnection): # business_connection + return (message.user_chat_id, message.user.id, message.id, bot_id, None) + elif isinstance( + message, types.BusinessMessagesDeleted + ): # deleted_business_messages + return ( + message.chat.id, + message.chat.id, + message.business_connection_id, + bot_id, + None, + ) + elif isinstance(message, types.MessageReactionUpdated): # message_reaction + return (message.chat.id, message.user.id, None, bot_id, None) + elif isinstance( + message, types.MessageReactionCountUpdated + ): # message_reaction_count + return (message.chat.id, None, None, bot_id, None) + elif isinstance(message, types.InlineQuery): # inline_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ShippingQuery): # shipping_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query + return (None, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.PollAnswer): # poll_answer + return (None, message.user.id, None, bot_id, None) + elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatJoinRequest): # chat_join_request + return (message.chat.id, message.from_user.id, None, bot_id, None) + elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost + return ( + message.chat.id, + message.source.user.id if message.source else None, + None, + bot_id, + None, + ) + elif isinstance(message, types.ChatBoostUpdated): # chat_boost + return ( + message.chat.id, + message.boost.source.user.id if message.boost.source else None, + None, + bot_id, + None, + ) + else: + pass # not yet supported :( diff --git a/telebot/states/asyncio/__init__.py b/telebot/states/asyncio/__init__.py new file mode 100644 index 000000000..14ef19536 --- /dev/null +++ b/telebot/states/asyncio/__init__.py @@ -0,0 +1,7 @@ +from .context import StateContext +from .middleware import StateMiddleware + +__all__ = [ + "StateContext", + "StateMiddleware", +] diff --git a/telebot/states/asyncio/context.py b/telebot/states/asyncio/context.py new file mode 100644 index 000000000..4c9ad61a3 --- /dev/null +++ b/telebot/states/asyncio/context.py @@ -0,0 +1,153 @@ +from telebot.states import State, StatesGroup +from telebot.types import CallbackQuery, Message +from telebot.async_telebot import AsyncTeleBot +from telebot.states import resolve_context + +from typing import Union + + +class StateContext: + """ + Class representing a state context. + + Passed through a middleware to provide easy way to set states. + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + async def start_ex(message: types.Message, state_context: StateContext): + await state_context.set(MyStates.name) + await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. + """ + + def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: + self.message: Union[Message, CallbackQuery] = message + self.bot: AsyncTeleBot = bot + self.bot_id = self.bot.bot_id + + async def set(self, state: Union[State, str]) -> bool: + """ + Set state for current user. + + :param state: State object or state name. + :type state: Union[State, str] + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + async def start_ex(message: types.Message, state_context: StateContext): + await state_context.set(MyStates.name) + await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + if isinstance(state, State): + state = state.name + return await self.bot.set_state( + chat_id=chat_id, + user_id=user_id, + state=state, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + ) + + async def get(self) -> str: + """ + Get current state for current user. + + :return: Current state name. + :rtype: str + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + return await self.bot.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + ) + + async def delete(self) -> bool: + """ + Deletes state and data for current user. + + .. warning:: + + This method deletes state and associated data for current user. + """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + return await self.bot.delete_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + ) + + async def reset_data(self) -> bool: + """ + Reset data for current user. + State will not be changed. + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + return await self.bot.reset_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + ) + + def data(self) -> dict: + """ + Get data for current user. + + .. code-block:: python3 + + with state_context.data() as data: + print(data) + data['name'] = 'John' + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + return self.bot.retrieve_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + ) + + async def add_data(self, **kwargs) -> None: + """ + Add data for current user. + + :param kwargs: Data to add. + :type kwargs: dict + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = ( + resolve_context(self.message, self.bot.bot_id) + ) + return await self.bot.add_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + **kwargs + ) diff --git a/telebot/states/asyncio/middleware.py b/telebot/states/asyncio/middleware.py new file mode 100644 index 000000000..675b7b462 --- /dev/null +++ b/telebot/states/asyncio/middleware.py @@ -0,0 +1,21 @@ +from telebot.asyncio_handler_backends import BaseMiddleware +from telebot.async_telebot import AsyncTeleBot +from telebot.states.sync.context import StateContext +from telebot.util import update_types +from telebot import types + + +class StateMiddleware(BaseMiddleware): + + def __init__(self, bot: AsyncTeleBot) -> None: + self.update_sensitive = False + self.update_types = update_types + self.bot: AsyncTeleBot = bot + + async def pre_process(self, message, data): + state_context = StateContext(message, self.bot) + data["state_context"] = state_context + data["state"] = state_context # 2 ways to access state context + + async def post_process(self, message, data, exception): + pass diff --git a/telebot/states/sync/__init__.py b/telebot/states/sync/__init__.py new file mode 100644 index 000000000..60f9ea2b7 --- /dev/null +++ b/telebot/states/sync/__init__.py @@ -0,0 +1,7 @@ +from .context import StateContext +from .middleware import StateMiddleware + +__all__ = [ + 'StateContext', + 'StateMiddleware', +] \ No newline at end of file diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py new file mode 100644 index 000000000..c0611f6bb --- /dev/null +++ b/telebot/states/sync/context.py @@ -0,0 +1,143 @@ +from telebot.states import State, StatesGroup +from telebot.types import CallbackQuery, Message +from telebot import TeleBot, types +from telebot.states import resolve_context + +from typing import Union + + + +class StateContext(): + """ + Class representing a state context. + + Passed through a middleware to provide easy way to set states. + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + def start_ex(message: types.Message, state_context: StateContext): + state_context.set(MyStates.name) + bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available. + """ + + def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None: + self.message: Union[Message, CallbackQuery] = message + self.bot: TeleBot = bot + self.bot_id = self.bot.bot_id + + def set(self, state: Union[State, str]) -> bool: + """ + Set state for current user. + + :param state: State object or state name. + :type state: Union[State, str] + + .. code-block:: python3 + + @bot.message_handler(commands=['start']) + def start_ex(message: types.Message, state_context: StateContext): + state_context.set(MyStates.name) + bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id) + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + if isinstance(state, State): + state = state.name + return self.bot.set_state( + chat_id=chat_id, + user_id=user_id, + state=state, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def get(self) -> str: + """ + Get current state for current user. + + :return: Current state name. + :rtype: str + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.get_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def delete(self) -> bool: + """ + Deletes state and data for current user. + + .. warning:: + + This method deletes state and associated data for current user. + """ + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.delete_state( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def reset_data(self) -> bool: + """ + Reset data for current user. + State will not be changed. + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.reset_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def data(self) -> dict: + """ + Get data for current user. + + .. code-block:: python3 + + with state_context.data() as data: + print(data) + data['name'] = 'John' + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.retrieve_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id + ) + + def add_data(self, **kwargs) -> None: + """ + Add data for current user. + + :param kwargs: Data to add. + :type kwargs: dict + """ + + chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id) + return self.bot.add_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + bot_id=bot_id, + message_thread_id=message_thread_id, + **kwargs + ) + \ No newline at end of file diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py new file mode 100644 index 000000000..4a252158b --- /dev/null +++ b/telebot/states/sync/middleware.py @@ -0,0 +1,21 @@ +from telebot.handler_backends import BaseMiddleware +from telebot import TeleBot +from telebot.states.sync.context import StateContext +from telebot.util import update_types +from telebot import types + + +class StateMiddleware(BaseMiddleware): + + def __init__(self, bot: TeleBot) -> None: + self.update_sensitive = False + self.update_types = update_types + self.bot: TeleBot = bot + + def pre_process(self, message, data): + state_context = StateContext(message, self.bot) + data['state_context'] = state_context + data['state'] = state_context # 2 ways to access state context + + def post_process(self, message, data, exception): + pass diff --git a/telebot/storage/__init__.py b/telebot/storage/__init__.py index 59e2b058c..bb3dd7f07 100644 --- a/telebot/storage/__init__.py +++ b/telebot/storage/__init__.py @@ -1,13 +1,13 @@ from telebot.storage.memory_storage import StateMemoryStorage from telebot.storage.redis_storage import StateRedisStorage from telebot.storage.pickle_storage import StatePickleStorage -from telebot.storage.base_storage import StateContext,StateStorageBase - - - +from telebot.storage.base_storage import StateDataContext, StateStorageBase __all__ = [ - 'StateStorageBase', 'StateContext', - 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage' -] \ No newline at end of file + "StateStorageBase", + "StateDataContext", + "StateMemoryStorage", + "StateRedisStorage", + "StatePickleStorage", +] diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py index 92b31ba85..9956c7ab9 100644 --- a/telebot/storage/base_storage.py +++ b/telebot/storage/base_storage.py @@ -1,5 +1,6 @@ import copy + class StateStorageBase: def __init__(self) -> None: pass @@ -15,30 +16,30 @@ def get_data(self, chat_id, user_id): Get data for a user in a particular chat. """ raise NotImplementedError - + def set_state(self, chat_id, user_id, state): """ Set state for a particular user. - ! Note that you should create a - record if it does not exist, and + ! Note that you should create a + record if it does not exist, and if a record with state already exists, you need to update a record. """ raise NotImplementedError - + def delete_state(self, chat_id, user_id): """ Delete state for a particular user. """ raise NotImplementedError - + def reset_data(self, chat_id, user_id): """ Reset data for a particular user in a chat. """ raise NotImplementedError - + def get_state(self, chat_id, user_id): raise NotImplementedError @@ -48,21 +49,70 @@ def get_interactive_data(self, chat_id, user_id): def save(self, chat_id, user_id, data): raise NotImplementedError + def _get_key( + self, + chat_id: int, + user_id: int, + prefix: str, + separator: str, + business_connection_id: str = None, + message_thread_id: int = None, + bot_id: int = None, + ) -> str: + """ + Convert parameters to a key. + """ + params = [prefix] + if bot_id: + params.append(str(bot_id)) + if business_connection_id: + params.append(business_connection_id) + if message_thread_id: + params.append(str(message_thread_id)) + params.append(str(chat_id)) + params.append(str(user_id)) + + return separator.join(params) -class StateContext: +class StateDataContext: """ Class for data. """ - def __init__(self , obj, chat_id, user_id) -> None: + + def __init__( + self, + obj, + chat_id, + user_id, + business_connection_id=None, + message_thread_id=None, + bot_id=None, + ): self.obj = obj - self.data = copy.deepcopy(obj.get_data(chat_id, user_id)) + res = obj.get_data( + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) + self.data = copy.deepcopy(res) self.chat_id = chat_id self.user_id = user_id - + self.bot_id = bot_id + self.business_connection_id = business_connection_id + self.message_thread_id = message_thread_id def __enter__(self): return self.data def __exit__(self, exc_type, exc_val, exc_tb): - return self.obj.save(self.chat_id, self.user_id, self.data) \ No newline at end of file + return self.obj.save( + self.chat_id, + self.user_id, + self.data, + self.business_connection_id, + self.message_thread_id, + self.bot_id, + ) diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py index 7d71c7ccd..21266632c 100644 --- a/telebot/storage/memory_storage.py +++ b/telebot/storage/memory_storage.py @@ -1,69 +1,225 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext +from telebot.storage.base_storage import StateStorageBase, StateDataContext +from typing import Optional, Union class StateMemoryStorage(StateStorageBase): - def __init__(self) -> None: - super().__init__() - self.data = {} - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - - - def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): + """ + Memory storage for states. + + Stores states in memory as a dictionary. + + .. code-block:: python3 + + storage = StateMemoryStorage() + bot = TeleBot(token, storage=storage) + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + """ + + def __init__( + self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot" + ) -> None: + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + + self.data = ( + {} + ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id + + def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + if hasattr(state, "name"): state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + self.data[_key] = {"state": state, "data": {}} + else: + self.data[_key]["state"] = state + return True - - def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - - return True - - return False - - - def get_state(self, chat_id, user_id): - - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - return True - return False - - def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) - - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) - - def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data \ No newline at end of file + + def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return None + + return self.data[_key]["state"] + + def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + + del self.data[_key] + return True + + def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + raise RuntimeError(f"StateMemoryStorage: key {_key} does not exist.") + self.data[_key]["data"][key] = value + return True + + def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + return self.data.get(_key, {}).get("data", {}) + + def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = {} + return True + + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) + + def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + if self.data.get(_key) is None: + return False + self.data[_key]["data"] = data + return True + + def __str__(self) -> str: + return f"" diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py index 68c9fbed5..e11a05043 100644 --- a/telebot/storage/pickle_storage.py +++ b/telebot/storage/pickle_storage.py @@ -1,34 +1,63 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext import os - import pickle +import threading +from typing import Optional, Union, Callable +from telebot.storage.base_storage import StateStorageBase, StateDataContext + + +def with_lock(func: Callable) -> Callable: + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + + return wrapper class StatePickleStorage(StateStorageBase): - def __init__(self, file_path="./.state-save/states.pkl") -> None: - super().__init__() + """ + State storage based on pickle file. + + .. warning:: + + This storage is not recommended for production use. + Data may be corrupted. If you face a case where states do not work as expected, + try to use another storage. + + .. code-block:: python3 + + storage = StatePickleStorage() + bot = TeleBot(token, storage=storage) + + :param file_path: Path to file where states will be stored. + :type file_path: str + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ + + def __init__( + self, + file_path: str = "./.state-save/states.pkl", + prefix="telebot", + separator: Optional[str] = ":", + ) -> None: self.file_path = file_path + self.prefix = prefix + self.separator = separator + self.lock = threading.Lock() + self.create_dir() - self.data = self.read() - def convert_old_to_new(self): - """ - Use this function to convert old storage to new storage. - This function is for people who was using pickle storage - that was in version <=4.3.1. - """ - # old looks like: - # {1: {'state': 'start', 'data': {'name': 'John'}} - # we should update old version pickle to new. - # new looks like: - # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}} - new_data = {} - for key, value in self.data.items(): - # this returns us id and dict with data and state - new_data[key] = {key: value} # convert this to new - # pass it to global data - self.data = new_data - self.update_data() # update data in file + def _read_from_file(self) -> dict: + with open(self.file_path, "rb") as f: + return pickle.load(f) + + def _write_to_file(self, data: dict) -> None: + with open(self.file_path, "wb") as f: + pickle.dump(data, f) def create_dir(self): """ @@ -37,80 +66,198 @@ def create_dir(self): dirs, filename = os.path.split(self.file_path) os.makedirs(dirs, exist_ok=True) if not os.path.isfile(self.file_path): - with open(self.file_path,'wb') as file: + with open(self.file_path, "wb") as file: pickle.dump({}, file) - def read(self): - file = open(self.file_path, 'rb') - data = pickle.load(file) - file.close() - return data - - def update_data(self): - file = open(self.file_path, 'wb+') - pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL) - file.close() - - def set_state(self, chat_id, user_id, state): - if hasattr(state, 'name'): - state = state.name - if chat_id in self.data: - if user_id in self.data[chat_id]: - self.data[chat_id][user_id]['state'] = state - self.update_data() - return True - else: - self.data[chat_id][user_id] = {'state': state, 'data': {}} - self.update_data() - return True - self.data[chat_id] = {user_id: {'state': state, 'data': {}}} - self.update_data() + @with_lock + def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + if _key not in data: + data[_key] = {"state": state, "data": {}} + else: + data[_key]["state"] = state + self._write_to_file(data) return True - - def delete_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - del self.data[chat_id][user_id] - if chat_id == user_id: - del self.data[chat_id] - self.update_data() - return True + @with_lock + def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + return data.get(_key, {}).get("state") + + @with_lock + def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + if _key in data: + del data[_key] + self._write_to_file(data) + return True return False - - def get_state(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['state'] - - return None - def get_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - return self.data[chat_id][user_id]['data'] - - return None - - def reset_data(self, chat_id, user_id): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'] = {} - self.update_data() - return True + @with_lock + def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + state_data = data.get(_key, {}) + state_data["data"][key] = value + + if _key not in data: + raise RuntimeError(f"PickleStorage: key {_key} does not exist.") + + self._write_to_file(data) + return True + + @with_lock + def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + return data.get(_key, {}).get("data", {}) + + @with_lock + def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + if _key in data: + data[_key]["data"] = {} + self._write_to_file(data) + return True return False - def set_data(self, chat_id, user_id, key, value): - if self.data.get(chat_id): - if self.data[chat_id].get(user_id): - self.data[chat_id][user_id]['data'][key] = value - self.update_data() - return True - raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id)) + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) - def get_interactive_data(self, chat_id, user_id): - return StateContext(self, chat_id, user_id) + @with_lock + def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self._read_from_file() + data[_key]["data"] = data + self._write_to_file(data) + return True - def save(self, chat_id, user_id, data): - self.data[chat_id][user_id]['data'] = data - self.update_data() + def __str__(self) -> str: + return f"StatePickleStorage({self.file_path}, {self.prefix})" diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py index 3fac57c46..c9935ac5e 100644 --- a/telebot/storage/redis_storage.py +++ b/telebot/storage/redis_storage.py @@ -1,183 +1,324 @@ -from telebot.storage.base_storage import StateStorageBase, StateContext import json +from telebot.storage.base_storage import StateStorageBase, StateDataContext +from typing import Optional, Union redis_installed = True try: - from redis import Redis, ConnectionPool - -except: + import redis +except ImportError: redis_installed = False + class StateRedisStorage(StateStorageBase): """ - This class is for Redis storage. - This will work only for states. - To use it, just pass this class to: - TeleBot(storage=StateRedisStorage()) + State storage based on Redis. + + .. code-block:: python3 + + storage = StateRedisStorage(...) + bot = TeleBot(token, storage=storage) + + :param host: Redis host, default is "localhost". + :type host: str + + :param port: Redis port, default is 6379. + :type port: int + + :param db: Redis database, default is 0. + :type db: int + + :param password: Redis password, default is None. + :type password: Optional[str] + + :param prefix: Prefix for keys, default is "telebot". + :type prefix: Optional[str] + + :param redis_url: Redis URL, default is None. + :type redis_url: Optional[str] + + :param connection_pool: Redis connection pool, default is None. + :type connection_pool: Optional[ConnectionPool] + + :param separator: Separator for keys, default is ":". + :type separator: Optional[str] + """ - def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None): - super().__init__() + + def __init__( + self, + host="localhost", + port=6379, + db=0, + password=None, + prefix="telebot", + redis_url=None, + connection_pool: "redis.ConnectionPool" = None, + separator: Optional[str] = ":", + ) -> None: + + if not redis_installed: + raise ImportError( + "Redis is not installed. Please install it via pip install redis" + ) + + self.separator = separator + self.prefix = prefix + if not self.prefix: + raise ValueError("Prefix cannot be empty") + if redis_url: - self.redis = ConnectionPool.from_url(redis_url) + self.redis = redis.Redis.from_url(redis_url) + elif connection_pool: + self.redis = redis.Redis(connection_pool=connection_pool) else: - self.redis = ConnectionPool(host=host, port=port, db=db, password=password) - #self.con = Redis(connection_pool=self.redis) -> use this when necessary - # - # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...} - self.prefix = prefix - if not redis_installed: - raise Exception("Redis is not installed. Install it via 'pip install redis'") - - def get_record(self, key): - """ - Function to get record from database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - result = connection.get(self.prefix+str(key)) - connection.close() - if result: return json.loads(result) - return + self.redis = redis.Redis(host=host, port=port, db=db, password=password) - def set_record(self, key, value): - """ - Function to set record to database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - connection.set(self.prefix+str(key), json.dumps(value)) - connection.close() + def set_state( + self, + chat_id: int, + user_id: int, + state: str, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + if hasattr(state, "name"): + state = state.name + + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + def set_state_action(pipe): + pipe.multi() + + data = pipe.hget(_key, "data") + result = pipe.execute() + data = result[0] + if data is None: + # If data is None, set it to an empty dictionary + data = {} + pipe.hset(_key, "data", json.dumps(data)) + + pipe.hset(_key, "state", state) + + self.redis.transaction(set_state_action, _key) return True - def delete_record(self, key): - """ - Function to delete record from database. - It has nothing to do with states. - Made for backward compatibility - """ - connection = Redis(connection_pool=self.redis) - connection.delete(self.prefix+str(key)) - connection.close() + def get_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Union[str, None]: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + state_bytes = self.redis.hget(_key, "state") + return state_bytes.decode("utf-8") if state_bytes else None + + def delete_state( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + return self.redis.delete(_key) > 0 + + def set_data( + self, + chat_id: int, + user_id: int, + key: str, + value: Union[str, int, float, dict], + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + + def set_data_action(pipe): + pipe.multi() + data = pipe.hget(_key, "data") + data = data.execute()[0] + if data is None: + raise RuntimeError(f"RedisStorage: key {_key} does not exist.") + else: + data = json.loads(data) + data[key] = value + pipe.hset(_key, "data", json.dumps(data)) + + self.redis.transaction(set_data_action, _key) return True - def set_state(self, chat_id, user_id, state): - """ - Set state for a particular user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if hasattr(state, 'name'): - state = state.name + def get_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> dict: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) + data = self.redis.hget(_key, "data") + return json.loads(data) if data else {} + + def reset_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) - if response: - if user_id in response: - response[user_id]['state'] = state + def reset_data_action(pipe): + pipe.multi() + if pipe.exists(_key): + pipe.hset(_key, "data", "{}") else: - response[user_id] = {'state': state, 'data': {}} - else: - response = {user_id: {'state': state, 'data': {}}} - self.set_record(chat_id, response) + return False + self.redis.transaction(reset_data_action, _key) return True - - def delete_state(self, chat_id, user_id): - """ - Delete state for a particular user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - del response[user_id] - if user_id == str(chat_id): - self.delete_record(chat_id) - return True - else: self.set_record(chat_id, response) - return True - return False - - - def get_value(self, chat_id, user_id, key): - """ - Get value for a data of a user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - if key in response[user_id]['data']: - return response[user_id]['data'][key] - return None - - - def get_state(self, chat_id, user_id): - """ - Get state of a user in a chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['state'] - return None + def get_interactive_data( + self, + chat_id: int, + user_id: int, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> Optional[dict]: + return StateDataContext( + self, + chat_id=chat_id, + user_id=user_id, + business_connection_id=business_connection_id, + message_thread_id=message_thread_id, + bot_id=bot_id, + ) + def save( + self, + chat_id: int, + user_id: int, + data: dict, + business_connection_id: Optional[str] = None, + message_thread_id: Optional[int] = None, + bot_id: Optional[int] = None, + ) -> bool: + _key = self._get_key( + chat_id, + user_id, + self.prefix, + self.separator, + business_connection_id, + message_thread_id, + bot_id, + ) - def get_data(self, chat_id, user_id): - """ - Get data of particular user in a particular chat. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - return response[user_id]['data'] - return None + def save_action(pipe): + pipe.multi() + if pipe.exists(_key): + pipe.hset(_key, "data", json.dumps(data)) + else: + return False + self.redis.transaction(save_action, _key) + return True - def reset_data(self, chat_id, user_id): - """ - Reset data of a user in a chat. + def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"): """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = {} - self.set_record(chat_id, response) - return True + Migrate from old to new format of keys. + Run this function once to migrate all redis existing keys to new format. - + Starting from version 4.23.0, the format of keys has been changed: + :value + - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + - New format: + {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}} + This function will help you to migrate from the old format to the new one in order to avoid data loss. - def set_data(self, chat_id, user_id, key, value): - """ - Set data without interactive data. - """ - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'][key] = value - self.set_record(chat_id, response) - return True - return False - - def get_interactive_data(self, chat_id, user_id): - """ - Get Data in interactive way. - You can use with() with this function. + :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response. + :type bot_id: int + + :param prefix: Prefix for keys, default is "telebot_"(old default value) + :type prefix: Optional[str] """ - return StateContext(self, chat_id, user_id) - - def save(self, chat_id, user_id, data): - response = self.get_record(chat_id) - user_id = str(user_id) - if response: - if user_id in response: - response[user_id]['data'] = data - self.set_record(chat_id, response) - return True - + keys = self.redis.keys(f"{prefix}*") + + for key in keys: + old_key = key.decode("utf-8") + # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...} + value = self.redis.get(old_key) + value = json.loads(value) + + chat_id = old_key[len(prefix) :] + user_id = list(value.keys())[0] + state = value[user_id]["state"] + state_data = value[user_id]["data"] + + # set new format + new_key = self._get_key( + int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id + ) + self.redis.hset(new_key, "state", state) + self.redis.hset(new_key, "data", json.dumps(state_data)) + + # delete old key + self.redis.delete(old_key) + + def __str__(self) -> str: + return f"StateRedisStorage({self.redis})" diff --git a/telebot/util.py b/telebot/util.py index 295c0d1aa..0448893e1 100644 --- a/telebot/util.py +++ b/telebot/util.py @@ -686,6 +686,26 @@ def validate_web_app_data(token: str, raw_init_data: str): return hmac.new(secret_key.digest(), data_check_string.encode(), sha256).hexdigest() == init_data_hash +def validate_token(token) -> bool: + if any(char.isspace() for char in token): + raise ValueError('Token must not contain spaces') + + if ':' not in token: + raise ValueError('Token must contain a colon') + + if len(token.split(':')) != 2: + raise ValueError('Token must contain exactly 2 parts separated by a colon') + + return True + +def extract_bot_id(token) -> Union[int, None]: + try: + validate_token(token) + except ValueError: + return None + return int(token.split(':')[0]) + + __all__ = ( "content_type_media", "content_type_service", "update_types", "WorkerThread", "AsyncTask", "CustomRequestResponse", @@ -696,5 +716,5 @@ def validate_web_app_data(token: str, raw_init_data: str): "split_string", "smart_split", "escape", "user_link", "quick_markup", "antiflood", "parse_web_app_data", "validate_web_app_data", "or_set", "or_clear", "orify", "OrEvent", "per_thread", - "webhook_google_functions" + "webhook_google_functions", "validate_token", "extract_bot_id" ) diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py index bb541bf8a..b74d6fad0 100644 --- a/tests/test_handler_backends.py +++ b/tests/test_handler_backends.py @@ -19,7 +19,7 @@ @pytest.fixture() def telegram_bot(): - return telebot.TeleBot('', threaded=False) + return telebot.TeleBot('1234:test', threaded=False) @pytest.fixture