diff --git a/examples/asynchronous_telebot/custom_states.py b/examples/asynchronous_telebot/custom_states.py
index aefad9809..f0111e27e 100644
--- a/examples/asynchronous_telebot/custom_states.py
+++ b/examples/asynchronous_telebot/custom_states.py
@@ -1,91 +1,154 @@
-from telebot import asyncio_filters
-from telebot.async_telebot import AsyncTeleBot
-
-# list of storages, you can use any storage
+from telebot import async_telebot, asyncio_filters, types
from telebot.asyncio_storage import StateMemoryStorage
+from telebot.states import State, StatesGroup
+from telebot.states.asyncio.context import StateContext
-# new feature for states.
-from telebot.asyncio_handler_backends import State, StatesGroup
-
-# default state storage is statememorystorage
-bot = AsyncTeleBot('TOKEN', state_storage=StateMemoryStorage())
+# Initialize the bot
+state_storage = StateMemoryStorage() # don't use this in production; switch to redis
+bot = async_telebot.AsyncTeleBot("TOKEN", state_storage=state_storage)
-# Just create different statesgroup
+# Define states
class MyStates(StatesGroup):
- name = State() # statesgroup should contain states
- surname = State()
+ name = State()
age = State()
+ color = State()
+ hobby = State()
+# Start command handler
+@bot.message_handler(commands=["start"])
+async def start_ex(message: types.Message, state: StateContext):
+ await state.set(MyStates.name)
+ await bot.send_message(
+ message.chat.id,
+ "Hello! What is your first name?",
+ reply_to_message_id=message.message_id,
+ )
-# set_state -> sets a new state
-# delete_state -> delets state if exists
-# get_state -> returns state if exists
+# Cancel command handler
+@bot.message_handler(state="*", commands=["cancel"])
+async def any_state(message: types.Message, state: StateContext):
+ await state.delete()
+ await bot.send_message(
+ message.chat.id,
+ "Your information has been cleared. Type /start to begin again.",
+ reply_to_message_id=message.message_id,
+ )
-@bot.message_handler(commands=['start'])
-async def start_ex(message):
- """
- Start command. Here we are starting state
- """
- await bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
- await bot.send_message(message.chat.id, 'Hi, write me a name')
-
-
-
-@bot.message_handler(state="*", commands='cancel')
-async def any_state(message):
- """
- Cancel state
- """
- await bot.send_message(message.chat.id, "Your state was cancelled.")
- await bot.delete_state(message.from_user.id, message.chat.id)
+# Handler for name input
@bot.message_handler(state=MyStates.name)
-async def name_get(message):
- """
- State 1. Will process when user's state is MyStates.name.
- """
- await bot.send_message(message.chat.id, f'Now write me a surname')
- await bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
- async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- data['name'] = message.text
-
-
-@bot.message_handler(state=MyStates.surname)
-async def ask_age(message):
- """
- State 2. Will process when user's state is MyStates.surname.
- """
- await bot.send_message(message.chat.id, "What is your age?")
- await bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
- async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- data['surname'] = message.text
-
-# result
+async def name_get(message: types.Message, state: StateContext):
+ await state.set(MyStates.age)
+ await bot.send_message(
+ message.chat.id, "How old are you?", reply_to_message_id=message.message_id
+ )
+ await state.add_data(name=message.text)
+
+
+# Handler for age input
@bot.message_handler(state=MyStates.age, is_digit=True)
-async def ready_for_answer(message):
- """
- State 3. Will process when user's state is MyStates.age.
- """
- async with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- await bot.send_message(message.chat.id, "Ready, take a look:\nName: {name}\nSurname: {surname}\nAge: {age}".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
- await bot.delete_state(message.from_user.id, message.chat.id)
-
-#incorrect number
+async def ask_color(message: types.Message, state: StateContext):
+ await state.set(MyStates.color)
+ await state.add_data(age=message.text)
+
+ # Define reply keyboard for color selection
+ keyboard = types.ReplyKeyboardMarkup(row_width=2)
+ colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"]
+ buttons = [types.KeyboardButton(color) for color in colors]
+ keyboard.add(*buttons)
+
+ await bot.send_message(
+ message.chat.id,
+ "What is your favorite color? Choose from the options below.",
+ reply_markup=keyboard,
+ reply_to_message_id=message.message_id,
+ )
+
+
+# Handler for color input
+@bot.message_handler(state=MyStates.color)
+async def ask_hobby(message: types.Message, state: StateContext):
+ await state.set(MyStates.hobby)
+ await state.add_data(color=message.text)
+
+ # Define reply keyboard for hobby selection
+ keyboard = types.ReplyKeyboardMarkup(row_width=2)
+ hobbies = ["Reading", "Traveling", "Gaming", "Cooking"]
+ buttons = [types.KeyboardButton(hobby) for hobby in hobbies]
+ keyboard.add(*buttons)
+
+ await bot.send_message(
+ message.chat.id,
+ "What is one of your hobbies? Choose from the options below.",
+ reply_markup=keyboard,
+ reply_to_message_id=message.message_id,
+ )
+
+
+# Handler for hobby input; use filters to ease validation
+@bot.message_handler(
+ state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"]
+)
+async def finish(message: types.Message, state: StateContext):
+ async with state.data() as data:
+ name = data.get("name")
+ age = data.get("age")
+ color = data.get("color")
+ hobby = message.text # Get the hobby from the message text
+
+ # Provide a fun fact based on color
+ color_facts = {
+ "Red": "Red is often associated with excitement and passion.",
+ "Green": "Green is the color of nature and tranquility.",
+ "Blue": "Blue is known for its calming and serene effects.",
+ "Yellow": "Yellow is a cheerful color often associated with happiness.",
+ "Purple": "Purple signifies royalty and luxury.",
+ "Orange": "Orange is a vibrant color that stimulates enthusiasm.",
+ "Other": "Colors have various meanings depending on context.",
+ }
+ color_fact = color_facts.get(
+ color, "Colors have diverse meanings, and yours is unique!"
+ )
+
+ msg = (
+ f"Thank you for sharing! Here is a summary of your information:\n"
+ f"First Name: {name}\n"
+ f"Age: {age}\n"
+ f"Favorite Color: {color}\n"
+ f"Fun Fact about your color: {color_fact}\n"
+ f"Favorite Hobby: {hobby}"
+ )
+
+ await bot.send_message(
+ message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id
+ )
+ await state.delete()
+
+
+# Handler for incorrect age input
@bot.message_handler(state=MyStates.age, is_digit=False)
-async def age_incorrect(message):
- """
- Will process for wrong input when state is MyState.age
- """
- await bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
+async def age_incorrect(message: types.Message):
+ await bot.send_message(
+ message.chat.id,
+ "Please enter a valid number for age.",
+ reply_to_message_id=message.message_id,
+ )
-# register filters
+# Add custom filters
bot.add_custom_filter(asyncio_filters.StateFilter(bot))
bot.add_custom_filter(asyncio_filters.IsDigitFilter())
+bot.add_custom_filter(asyncio_filters.TextMatchFilter())
+
+# necessary for state parameter in handlers.
+from telebot.states.asyncio.middleware import StateMiddleware
+bot.setup_middleware(StateMiddleware(bot))
+# Start polling
import asyncio
-asyncio.run(bot.polling())
\ No newline at end of file
+
+asyncio.run(bot.polling())
diff --git a/examples/custom_states.py b/examples/custom_states.py
index a02761286..131dd1d42 100644
--- a/examples/custom_states.py
+++ b/examples/custom_states.py
@@ -1,106 +1,153 @@
-import telebot # telebot
-
-from telebot import custom_filters
-from telebot.handler_backends import State, StatesGroup #States
-
-# States storage
+import telebot
+from telebot import custom_filters, types
+from telebot.states import State, StatesGroup
+from telebot.states.sync.context import StateContext
from telebot.storage import StateMemoryStorage
+# Initialize the bot
+state_storage = StateMemoryStorage() # don't use this in production; switch to redis
+bot = telebot.TeleBot("TOKEN", state_storage=state_storage, use_class_middlewares=True)
-# Starting from version 4.4.0+, we support storages.
-# StateRedisStorage -> Redis-based storage.
-# StatePickleStorage -> Pickle-based storage.
-# For redis, you will need to install redis.
-# Pass host, db, password, or anything else,
-# if you need to change config for redis.
-# Pickle requires path. Default path is in folder .state-saves.
-# If you were using older version of pytba for pickle,
-# you need to migrate from old pickle to new by using
-# StatePickleStorage().convert_old_to_new()
-
-
-
-# Now, you can pass storage to bot.
-state_storage = StateMemoryStorage() # you can init here another storage
-
-bot = telebot.TeleBot("TOKEN",
-state_storage=state_storage)
-
-# States group.
+# Define states
class MyStates(StatesGroup):
- # Just name variables differently
- name = State() # creating instances of State class is enough from now
- surname = State()
+ name = State()
age = State()
+ color = State()
+ hobby = State()
+# Start command handler
+@bot.message_handler(commands=["start"])
+def start_ex(message: types.Message, state: StateContext):
+ state.set(MyStates.name)
+ bot.send_message(
+ message.chat.id,
+ "Hello! What is your first name?",
+ reply_to_message_id=message.message_id,
+ )
-@bot.message_handler(commands=['start'])
-def start_ex(message):
- """
- Start command. Here we are starting state
- """
- bot.set_state(message.from_user.id, MyStates.name, message.chat.id)
- bot.send_message(message.chat.id, 'Hi, write me a name')
-
+# Cancel command handler
+@bot.message_handler(state="*", commands=["cancel"])
+def any_state(message: types.Message, state: StateContext):
+ state.delete()
+ bot.send_message(
+ message.chat.id,
+ "Your information has been cleared. Type /start to begin again.",
+ reply_to_message_id=message.message_id,
+ )
-# Any state
-@bot.message_handler(state="*", commands=['cancel'])
-def any_state(message):
- """
- Cancel state
- """
- bot.send_message(message.chat.id, "Your state was cancelled.")
- bot.delete_state(message.from_user.id, message.chat.id)
+# Handler for name input
@bot.message_handler(state=MyStates.name)
-def name_get(message):
- """
- State 1. Will process when user's state is MyStates.name.
- """
- bot.send_message(message.chat.id, 'Now write me a surname')
- bot.set_state(message.from_user.id, MyStates.surname, message.chat.id)
- with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- data['name'] = message.text
-
-
-@bot.message_handler(state=MyStates.surname)
-def ask_age(message):
- """
- State 2. Will process when user's state is MyStates.surname.
- """
- bot.send_message(message.chat.id, "What is your age?")
- bot.set_state(message.from_user.id, MyStates.age, message.chat.id)
- with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- data['surname'] = message.text
-
-# result
+def name_get(message: types.Message, state: StateContext):
+ state.set(MyStates.age)
+ bot.send_message(
+ message.chat.id, "How old are you?", reply_to_message_id=message.message_id
+ )
+ state.add_data(name=message.text)
+
+
+# Handler for age input
@bot.message_handler(state=MyStates.age, is_digit=True)
-def ready_for_answer(message):
- """
- State 3. Will process when user's state is MyStates.age.
- """
- with bot.retrieve_data(message.from_user.id, message.chat.id) as data:
- msg = ("Ready, take a look:\n"
- f"Name: {data['name']}\n"
- f"Surname: {data['surname']}\n"
- f"Age: {message.text}")
- bot.send_message(message.chat.id, msg, parse_mode="html")
- bot.delete_state(message.from_user.id, message.chat.id)
-
-#incorrect number
+def ask_color(message: types.Message, state: StateContext):
+ state.set(MyStates.color)
+ state.add_data(age=message.text)
+
+ # Define reply keyboard for color selection
+ keyboard = types.ReplyKeyboardMarkup(row_width=2)
+ colors = ["Red", "Green", "Blue", "Yellow", "Purple", "Orange", "Other"]
+ buttons = [types.KeyboardButton(color) for color in colors]
+ keyboard.add(*buttons)
+
+ bot.send_message(
+ message.chat.id,
+ "What is your favorite color? Choose from the options below.",
+ reply_markup=keyboard,
+ reply_to_message_id=message.message_id,
+ )
+
+
+# Handler for color input
+@bot.message_handler(state=MyStates.color)
+def ask_hobby(message: types.Message, state: StateContext):
+ state.set(MyStates.hobby)
+ state.add_data(color=message.text)
+
+ # Define reply keyboard for hobby selection
+ keyboard = types.ReplyKeyboardMarkup(row_width=2)
+ hobbies = ["Reading", "Traveling", "Gaming", "Cooking"]
+ buttons = [types.KeyboardButton(hobby) for hobby in hobbies]
+ keyboard.add(*buttons)
+
+ bot.send_message(
+ message.chat.id,
+ "What is one of your hobbies? Choose from the options below.",
+ reply_markup=keyboard,
+ reply_to_message_id=message.message_id,
+ )
+
+
+# Handler for hobby input
+@bot.message_handler(
+ state=MyStates.hobby, text=["Reading", "Traveling", "Gaming", "Cooking"]
+)
+def finish(message: types.Message, state: StateContext):
+ with state.data() as data:
+ name = data.get("name")
+ age = data.get("age")
+ color = data.get("color")
+ hobby = message.text # Get the hobby from the message text
+
+ # Provide a fun fact based on color
+ color_facts = {
+ "Red": "Red is often associated with excitement and passion.",
+ "Green": "Green is the color of nature and tranquility.",
+ "Blue": "Blue is known for its calming and serene effects.",
+ "Yellow": "Yellow is a cheerful color often associated with happiness.",
+ "Purple": "Purple signifies royalty and luxury.",
+ "Orange": "Orange is a vibrant color that stimulates enthusiasm.",
+ "Other": "Colors have various meanings depending on context.",
+ }
+ color_fact = color_facts.get(
+ color, "Colors have diverse meanings, and yours is unique!"
+ )
+
+ msg = (
+ f"Thank you for sharing! Here is a summary of your information:\n"
+ f"First Name: {name}\n"
+ f"Age: {age}\n"
+ f"Favorite Color: {color}\n"
+ f"Fun Fact about your color: {color_fact}\n"
+ f"Favorite Hobby: {hobby}"
+ )
+
+ bot.send_message(
+ message.chat.id, msg, parse_mode="html", reply_to_message_id=message.message_id
+ )
+ state.delete()
+
+
+# Handler for incorrect age input
@bot.message_handler(state=MyStates.age, is_digit=False)
-def age_incorrect(message):
- """
- Wrong response for MyStates.age
- """
- bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')
+def age_incorrect(message: types.Message):
+ bot.send_message(
+ message.chat.id,
+ "Please enter a valid number for age.",
+ reply_to_message_id=message.message_id,
+ )
-# register filters
+# Add custom filters
bot.add_custom_filter(custom_filters.StateFilter(bot))
bot.add_custom_filter(custom_filters.IsDigitFilter())
+bot.add_custom_filter(custom_filters.TextMatchFilter())
+
+# necessary for state parameter in handlers.
+from telebot.states.sync.middleware import StateMiddleware
+
+bot.setup_middleware(StateMiddleware(bot))
-bot.infinity_polling(skip_pending=True)
+# Start polling
+bot.infinity_polling()
diff --git a/pyproject.toml b/pyproject.toml
index a45a8dd3b..9ef5d86f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,6 @@ Issues = "https://github.com/eternnoir/pyTelegramBotAPI/issues"
json = ["ujson"]
PIL = ["Pillow"]
redis = ["redis>=3.4.1"]
-aioredis = ["aioredis"]
aiohttp = ["aiohttp"]
fastapi = ["fastapi"]
uvicorn = ["uvicorn"]
diff --git a/telebot/__init__.py b/telebot/__init__.py
index 97c3ea23c..2c26b3b75 100644
--- a/telebot/__init__.py
+++ b/telebot/__init__.py
@@ -7,7 +7,7 @@
import threading
import time
import traceback
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union, Dict
# these imports are used to avoid circular import error
import telebot.util
@@ -153,9 +153,14 @@ class TeleBot:
:param allow_sending_without_reply: Default value for allow_sending_without_reply, defaults to None
:type allow_sending_without_reply: :obj:`bool`, optional
-
:param colorful_logs: Outputs colorful logs
:type colorful_logs: :obj:`bool`, optional
+
+ :param validate_token: Validate token, defaults to True;
+ :type validate_token: :obj:`bool`, optional
+
+ :raises ImportError: If coloredlogs module is not installed and colorful_logs is True
+ :raises ValueError: If token is invalid
"""
def __init__(
@@ -169,7 +174,8 @@ def __init__(
disable_notification: Optional[bool]=None,
protect_content: Optional[bool]=None,
allow_sending_without_reply: Optional[bool]=None,
- colorful_logs: Optional[bool]=False
+ colorful_logs: Optional[bool]=False,
+ validate_token: Optional[bool]=True
):
# update-related
@@ -187,6 +193,11 @@ def __init__(
self.webhook_listener = None
self._user = None
+ if validate_token:
+ util.validate_token(self.token)
+
+ self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change in future, unspecified
+
# logs-related
if colorful_logs:
try:
@@ -281,7 +292,7 @@ def __init__(
self.threaded = threaded
if self.threaded:
self.worker_pool = util.ThreadPool(self, num_threads=num_threads)
-
+
@property
def user(self) -> types.User:
"""
@@ -6642,7 +6653,9 @@ def setup_middleware(self, middleware: BaseMiddleware):
self.middlewares.append(middleware)
- def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None) -> None:
+ def set_state(self, user_id: int, state: Union[str, State], chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None,
+ bot_id: Optional[int]=None) -> bool:
"""
Sets a new state of a user.
@@ -6652,25 +6665,49 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option
Otherwise, if you only set user_id, chat_id will equal to user_id, this means that
state will be set for the user in his private chat with a bot.
+ .. versionchanged:: 4.23.0
+
+ Added additional parameters to support topics, business connections, and message threads.
+
+ .. seealso::
+
+ For more details, visit the `custom_states.py example `_.
+
:param user_id: User's identifier
:type user_id: :obj:`int`
- :param state: new state. can be string, integer, or :class:`telebot.types.State`
+ :param state: new state. can be string, or :class:`telebot.types.State`
:type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
- :return: None
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
+ :return: True on success
+ :rtype: :obj:`bool`
"""
if chat_id is None:
chat_id = user_id
- self.current_states.set_state(chat_id, user_id, state)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.set_state(
+ chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
- def reset_data(self, user_id: int, chat_id: Optional[int]=None):
+ def reset_data(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool:
"""
- Reset data for a user in chat.
+ Reset data for a user in chat: sets the 'data' field to an empty dictionary.
:param user_id: User's identifier
:type user_id: :obj:`int`
@@ -6678,16 +6715,34 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
- :return: None
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
+ :return: True on success
+ :rtype: :obj:`bool`
"""
if chat_id is None:
chat_id = user_id
- self.current_states.reset_data(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
- def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
+ def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool:
"""
- Delete the current state of a user.
+ Fully deletes the storage record of a user in chat.
+
+ .. warning::
+
+ This does NOT set state to None, but deletes the object from storage.
:param user_id: User's identifier
:type user_id: :obj:`int`
@@ -6695,14 +6750,19 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
- :return: None
+ :return: True on success
+ :rtype: :obj:`bool`
"""
if chat_id is None:
chat_id = user_id
- self.current_states.delete_state(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
- def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Any]:
+ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]:
"""
Returns context manager with data for a user in chat.
@@ -6712,34 +6772,70 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[A
:param chat_id: Chat's unique identifier, defaults to user_id
:type chat_id: int, optional
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: int, optional
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: str, optional
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: int, optional
+
:return: Context manager with data for a user in chat
:rtype: Optional[Any]
"""
if chat_id is None:
chat_id = user_id
- return self.current_states.get_interactive_data(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id)
- def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union[int, str, State]]:
+ def get_state(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str:
"""
Gets current state of a user.
Not recommended to use this method. But it is ok for debugging.
+ .. warning::
+
+ Even if you are using :class:`telebot.types.State`, this method will return a string.
+ When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name
+
:param user_id: User's identifier
:type user_id: :obj:`int`
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
:return: state of a user
:rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
"""
if chat_id is None:
chat_id = user_id
- return self.current_states.get_state(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
- def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
+ def add_data(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None,
+ bot_id: Optional[int]=None,
+ **kwargs) -> None:
"""
Add data to states.
@@ -6749,13 +6845,25 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
:param kwargs: Data to add
:return: None
"""
if chat_id is None:
chat_id = user_id
+ if bot_id is None:
+ bot_id = self.bot_id
for key, value in kwargs.items():
- self.current_states.set_data(chat_id, user_id, key, value)
+ self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
def register_next_step_handler_by_chat_id(
diff --git a/telebot/async_telebot.py b/telebot/async_telebot.py
index 66ba987fb..4e92f50d3 100644
--- a/telebot/async_telebot.py
+++ b/telebot/async_telebot.py
@@ -4,7 +4,7 @@
import logging
import re
import traceback
-from typing import Any, Awaitable, Callable, List, Optional, Union
+from typing import Any, Awaitable, Callable, List, Optional, Union, Dict
import sys
# this imports are used to avoid circular import error
@@ -117,6 +117,11 @@ class AsyncTeleBot:
:param colorful_logs: Outputs colorful logs
:type colorful_logs: :obj:`bool`, optional
+ :param validate_token: Validate token, defaults to True;
+ :type validate_token: :obj:`bool`, optional
+
+ :raises ImportError: If coloredlogs module is not installed and colorful_logs is True
+ :raises ValueError: If token is invalid
"""
def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[int]=None,
@@ -126,7 +131,8 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[
disable_notification: Optional[bool]=None,
protect_content: Optional[bool]=None,
allow_sending_without_reply: Optional[bool]=None,
- colorful_logs: Optional[bool]=False) -> None:
+ colorful_logs: Optional[bool]=False,
+ validate_token: Optional[bool]=True) -> None:
# update-related
self.token = token
@@ -184,6 +190,12 @@ def __init__(self, token: str, parse_mode: Optional[str]=None, offset: Optional[
self._user = None # set during polling
+ if validate_token:
+ util.validate_token(self.token)
+
+ self.bot_id: Union[int, None] = util.extract_bot_id(self.token) # subject to change, unspecified
+
+
@property
def user(self):
return self._user
@@ -7837,7 +7849,10 @@ async def get_forum_topic_icon_stickers(self) -> List[types.Sticker]:
"""
return await asyncio_helper.get_forum_topic_icon_stickers(self.token)
- async def set_state(self, user_id: int, state: Union[State, int, str], chat_id: Optional[int]=None):
+
+ async def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None,
+ bot_id: Optional[int]=None) -> bool:
"""
Sets a new state of a user.
@@ -7847,22 +7862,47 @@ async def set_state(self, user_id: int, state: Union[State, int, str], chat_id:
Otherwise, if you only set user_id, chat_id will equal to user_id, this means that
state will be set for the user in his private chat with a bot.
+ .. versionchanged:: 4.23.0
+
+ Added additional parameters to support topics, business connections, and message threads.
+
+ .. seealso::
+
+ For more details, visit the `custom_states.py example `_.
+
:param user_id: User's identifier
:type user_id: :obj:`int`
- :param state: new state. can be string, integer, or :class:`telebot.types.State`
+ :param state: new state. can be string, or :class:`telebot.types.State`
:type state: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
- :return: None
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
+ :return: True on success
+ :rtype: :obj:`bool`
"""
- if not chat_id:
+ if chat_id is None:
chat_id = user_id
- await self.current_states.set_state(chat_id, user_id, state)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return await self.current_states.set_state(
+ chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
+
- async def reset_data(self, user_id: int, chat_id: Optional[int]=None):
+ async def reset_data(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool:
"""
Reset data for a user in chat.
@@ -7872,13 +7912,28 @@ async def reset_data(self, user_id: int, chat_id: Optional[int]=None):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
- :return: None
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
+ :return: True on success
+ :rtype: :obj:`bool`
"""
if chat_id is None:
chat_id = user_id
- await self.current_states.reset_data(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return await self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
+
- async def delete_state(self, user_id: int, chat_id: Optional[int]=None):
+ async def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> bool:
"""
Delete the current state of a user.
@@ -7890,11 +7945,16 @@ async def delete_state(self, user_id: int, chat_id: Optional[int]=None):
:return: None
"""
- if not chat_id:
+ if chat_id is None:
chat_id = user_id
- await self.current_states.delete_state(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return await self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
+
- def retrieve_data(self, user_id: int, chat_id: Optional[int]=None):
+ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]:
"""
Returns context manager with data for a user in chat.
@@ -7904,32 +7964,70 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None):
:param chat_id: Chat's unique identifier, defaults to user_id
:type chat_id: int, optional
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: int, optional
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: str, optional
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: int, optional
+
:return: Context manager with data for a user in chat
- :rtype: Optional[Any]
+ :rtype: :obj:`dict`
"""
- if not chat_id:
+ if chat_id is None:
chat_id = user_id
- return self.current_states.get_interactive_data(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id)
+
- async def get_state(self, user_id, chat_id: Optional[int]=None):
+ async def get_state(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> str:
"""
Gets current state of a user.
Not recommended to use this method. But it is ok for debugging.
+ .. warning::
+
+ Even if you are using :class:`telebot.types.State`, this method will return a string.
+ When comparing(not recommended), you should compare this string with :class:`telebot.types.State`.name
+
:param user_id: User's identifier
:type user_id: :obj:`int`
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
:return: state of a user
- :rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
+ :rtype: :obj:`str`
"""
- if not chat_id:
+ if chat_id is None:
chat_id = user_id
- return await self.current_states.get_state(chat_id, user_id)
+ if bot_id is None:
+ bot_id = self.bot_id
+ return await self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
+
- async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
+ async def add_data(self, user_id: int, chat_id: Optional[int]=None,
+ business_connection_id: Optional[str]=None,
+ message_thread_id: Optional[int]=None,
+ bot_id: Optional[int]=None,
+ **kwargs) -> None:
"""
Add data to states.
@@ -7939,10 +8037,22 @@ async def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
:param chat_id: Chat's identifier
:type chat_id: :obj:`int`
+ :param bot_id: Bot's identifier, defaults to current bot id
+ :type bot_id: :obj:`int`
+
+ :param business_connection_id: Business identifier
+ :type business_connection_id: :obj:`str`
+
+ :param message_thread_id: Identifier of the message thread
+ :type message_thread_id: :obj:`int`
+
:param kwargs: Data to add
:return: None
"""
- if not chat_id:
+ if chat_id is None:
chat_id = user_id
+ if bot_id is None:
+ bot_id = self.bot_id
for key, value in kwargs.items():
- await self.current_states.set_data(chat_id, user_id, key, value)
+ await self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id,
+ business_connection_id=business_connection_id, message_thread_id=message_thread_id)
diff --git a/telebot/asyncio_filters.py b/telebot/asyncio_filters.py
index f4e594a7a..20abe5464 100644
--- a/telebot/asyncio_filters.py
+++ b/telebot/asyncio_filters.py
@@ -3,6 +3,7 @@
from telebot.asyncio_handler_backends import State
from telebot import types
+from telebot.states import resolve_context
class SimpleCustomFilter(ABC):
@@ -396,19 +397,11 @@ async def check(self, message, text):
"""
:meta private:
"""
- if text == '*': return True
-
- # needs to work with callbackquery
- if isinstance(message, types.Message):
- chat_id = message.chat.id
- user_id = message.from_user.id
-
- if isinstance(message, types.CallbackQuery):
-
- chat_id = message.message.chat.id
- user_id = message.from_user.id
- message = message.message
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id)
+ if chat_id is None:
+ chat_id = user_id # May change in future
if isinstance(text, list):
new_text = []
@@ -418,21 +411,24 @@ async def check(self, message, text):
text = new_text
elif isinstance(text, State):
text = text.name
-
- if message.chat.type in ['group', 'supergroup']:
- group_state = await self.bot.current_states.get_state(chat_id, user_id)
- if group_state == text:
- return True
- elif type(text) is list and group_state in text:
- return True
-
-
- else:
- user_state = await self.bot.current_states.get_state(chat_id, user_id)
- if user_state == text:
- return True
- elif type(text) is list and user_state in text:
- return True
+
+ user_state = await self.bot.current_states.get_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ # CHANGED BEHAVIOUR
+ if text == "*" and user_state is not None:
+ return True
+
+ if user_state == text:
+ return True
+ elif type(text) is list and user_state in text:
+ return True
+ return False
class IsDigitFilter(SimpleCustomFilter):
diff --git a/telebot/asyncio_handler_backends.py b/telebot/asyncio_handler_backends.py
index 0861a9893..6c96cc2d7 100644
--- a/telebot/asyncio_handler_backends.py
+++ b/telebot/asyncio_handler_backends.py
@@ -1,6 +1,7 @@
"""
File with all middleware classes, states.
"""
+from telebot.states import State, StatesGroup
class BaseMiddleware:
@@ -48,46 +49,6 @@ async def post_process(self, message, data, exception):
raise NotImplementedError
-class State:
- """
- Class representing a state.
-
- .. code-block:: python3
-
- class MyStates(StatesGroup):
- my_state = State() # returns my_state:State string.
- """
- def __init__(self) -> None:
- self.name = None
-
- def __str__(self) -> str:
- return self.name
-
-
-class StatesGroup:
- """
- Class representing common states.
-
- .. code-block:: python3
-
- class MyStates(StatesGroup):
- my_state = State() # returns my_state:State string.
- """
- def __init_subclass__(cls) -> None:
- state_list = []
- for name, value in cls.__dict__.items():
- if not name.startswith('__') and not callable(value) and isinstance(value, State):
- # change value of that variable
- value.name = ':'.join((cls.__name__, name))
- value.group = cls
- state_list.append(value)
- cls._state_list = state_list
-
- @classmethod
- def state_list(self):
- return self._state_list
-
-
class SkipHandler:
"""
Class for skipping handlers.
diff --git a/telebot/asyncio_storage/__init__.py b/telebot/asyncio_storage/__init__.py
index 892f0af94..82a8c817a 100644
--- a/telebot/asyncio_storage/__init__.py
+++ b/telebot/asyncio_storage/__init__.py
@@ -1,13 +1,13 @@
from telebot.asyncio_storage.memory_storage import StateMemoryStorage
from telebot.asyncio_storage.redis_storage import StateRedisStorage
from telebot.asyncio_storage.pickle_storage import StatePickleStorage
-from telebot.asyncio_storage.base_storage import StateContext,StateStorageBase
-
-
-
+from telebot.asyncio_storage.base_storage import StateDataContext, StateStorageBase
__all__ = [
- 'StateStorageBase', 'StateContext',
- 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
-]
\ No newline at end of file
+ "StateStorageBase",
+ "StateDataContext",
+ "StateMemoryStorage",
+ "StateRedisStorage",
+ "StatePickleStorage",
+]
diff --git a/telebot/asyncio_storage/base_storage.py b/telebot/asyncio_storage/base_storage.py
index 38615c4c2..20d0b0587 100644
--- a/telebot/asyncio_storage/base_storage.py
+++ b/telebot/asyncio_storage/base_storage.py
@@ -16,53 +16,108 @@ async def get_data(self, chat_id, user_id):
Get data for a user in a particular chat.
"""
raise NotImplementedError
-
+
async def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user.
- ! Note that you should create a
- record if it does not exist, and
+ ! Note that you should create a
+ record if it does not exist, and
if a record with state already exists,
you need to update a record.
"""
raise NotImplementedError
-
+
async def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user.
"""
raise NotImplementedError
-
+
async def reset_data(self, chat_id, user_id):
"""
Reset data for a particular user in a chat.
"""
raise NotImplementedError
-
+
async def get_state(self, chat_id, user_id):
raise NotImplementedError
-
+
+ def get_interactive_data(self, chat_id, user_id):
+ """
+ Should be sync, but should provide a context manager
+ with __aenter__ and __aexit__ methods.
+ """
+ raise NotImplementedError
+
async def save(self, chat_id, user_id, data):
raise NotImplementedError
+ def _get_key(
+ self,
+ chat_id: int,
+ user_id: int,
+ prefix: str,
+ separator: str,
+ business_connection_id: str = None,
+ message_thread_id: int = None,
+ bot_id: int = None,
+ ) -> str:
+ """
+ Convert parameters to a key.
+ """
+ params = [prefix]
+ if bot_id:
+ params.append(str(bot_id))
+ if business_connection_id:
+ params.append(business_connection_id)
+ if message_thread_id:
+ params.append(str(message_thread_id))
+ params.append(str(chat_id))
+ params.append(str(user_id))
-class StateContext:
+ return separator.join(params)
+
+
+class StateDataContext:
"""
Class for data.
"""
- def __init__(self, obj, chat_id, user_id):
+ def __init__(
+ self,
+ obj,
+ chat_id,
+ user_id,
+ business_connection_id=None,
+ message_thread_id=None,
+ bot_id=None,
+ ):
self.obj = obj
self.data = None
self.chat_id = chat_id
self.user_id = user_id
-
-
+ self.bot_id = bot_id
+ self.business_connection_id = business_connection_id
+ self.message_thread_id = message_thread_id
async def __aenter__(self):
- self.data = copy.deepcopy(await self.obj.get_data(self.chat_id, self.user_id))
+ data = await self.obj.get_data(
+ chat_id=self.chat_id,
+ user_id=self.user_id,
+ business_connection_id=self.business_connection_id,
+ message_thread_id=self.message_thread_id,
+ bot_id=self.bot_id,
+ )
+ self.data = copy.deepcopy(data)
return self.data
async def __aexit__(self, exc_type, exc_val, exc_tb):
- return await self.obj.save(self.chat_id, self.user_id, self.data)
\ No newline at end of file
+ return await self.obj.save(
+ self.chat_id,
+ self.user_id,
+ self.data,
+ self.business_connection_id,
+ self.message_thread_id,
+ self.bot_id,
+ )
diff --git a/telebot/asyncio_storage/memory_storage.py b/telebot/asyncio_storage/memory_storage.py
index 45c2ad914..e17d5fe8c 100644
--- a/telebot/asyncio_storage/memory_storage.py
+++ b/telebot/asyncio_storage/memory_storage.py
@@ -1,66 +1,225 @@
-from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
+from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext
+from typing import Optional, Union
+
class StateMemoryStorage(StateStorageBase):
- def __init__(self) -> None:
- self.data = {}
- #
- # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
-
-
- async def set_state(self, chat_id, user_id, state):
- if hasattr(state, 'name'):
+ """
+ Memory storage for states.
+
+ Stores states in memory as a dictionary.
+
+ .. code-block:: python3
+
+ storage = StateMemoryStorage()
+ bot = AsyncTeleBot(token, storage=storage)
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+ """
+
+ def __init__(
+ self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot"
+ ) -> None:
+ self.separator = separator
+ self.prefix = prefix
+ if not self.prefix:
+ raise ValueError("Prefix cannot be empty")
+
+ self.data = (
+ {}
+ ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id
+
+ async def set_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ if hasattr(state, "name"):
state = state.name
- if chat_id in self.data:
- if user_id in self.data[chat_id]:
- self.data[chat_id][user_id]['state'] = state
- return True
- else:
- self.data[chat_id][user_id] = {'state': state, 'data': {}}
- return True
- self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ self.data[_key] = {"state": state, "data": {}}
+ else:
+ self.data[_key]["state"] = state
+
+ return True
+
+ async def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return None
+
+ return self.data[_key]["state"]
+
+ async def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+
+ del self.data[_key]
+ return True
+
+ async def set_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ raise RuntimeError(f"MemoryStorage: key {_key} does not exist.")
+ self.data[_key]["data"][key] = value
+ return True
+
+ async def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ return self.data.get(_key, {}).get("data", {})
+
+ async def reset_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+ self.data[_key]["data"] = {}
return True
-
- async def delete_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- del self.data[chat_id][user_id]
- if chat_id == user_id:
- del self.data[chat_id]
-
- return True
-
- return False
-
-
- async def get_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['state']
-
- return None
- async def get_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['data']
-
- return None
-
- async def reset_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'] = {}
- return True
- return False
-
- async def set_data(self, chat_id, user_id, key, value):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'][key] = value
- return True
- raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
-
- def get_interactive_data(self, chat_id, user_id):
- return StateContext(self, chat_id, user_id)
-
- async def save(self, chat_id, user_id, data):
- self.data[chat_id][user_id]['data'] = data
\ No newline at end of file
+
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
+
+ async def save(
+ self,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+ self.data[_key]["data"] = data
+ return True
+
+ def __str__(self) -> str:
+ return f""
diff --git a/telebot/asyncio_storage/pickle_storage.py b/telebot/asyncio_storage/pickle_storage.py
index cf446d85b..723672034 100644
--- a/telebot/asyncio_storage/pickle_storage.py
+++ b/telebot/asyncio_storage/pickle_storage.py
@@ -1,28 +1,74 @@
-from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
-import os
+aiofiles_installed = True
+try:
+ import aiofiles
+except ImportError:
+ aiofiles_installed = False
+import os
import pickle
+import asyncio
+from typing import Optional, Union, Callable, Any
+
+from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext
+
+
+def with_lock(func: Callable) -> Callable:
+ async def wrapper(self, *args, **kwargs):
+ async with self.lock:
+ return await func(self, *args, **kwargs)
+
+ return wrapper
class StatePickleStorage(StateStorageBase):
- def __init__(self, file_path="./.state-save/states.pkl") -> None:
+ """
+ State storage based on pickle file.
+
+ .. warning::
+
+ This storage is not recommended for production use.
+ Data may be corrupted. If you face a case where states do not work as expected,
+ try to use another storage.
+
+ .. code-block:: python3
+
+ storage = StatePickleStorage()
+ bot = AsyncTeleBot(token, storage=storage)
+
+ :param file_path: Path to file where states will be stored.
+ :type file_path: str
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+ """
+
+ def __init__(
+ self,
+ file_path: str = "./.state-save/states.pkl",
+ prefix="telebot",
+ separator: Optional[str] = ":",
+ ) -> None:
+
+ if not aiofiles_installed:
+ raise ImportError("Please install aiofiles using `pip install aiofiles`")
+
self.file_path = file_path
+ self.prefix = prefix
+ self.separator = separator
+ self.lock = asyncio.Lock()
self.create_dir()
- self.data = self.read()
-
- async def convert_old_to_new(self):
- # old looks like:
- # {1: {'state': 'start', 'data': {'name': 'John'}}
- # we should update old version pickle to new.
- # new looks like:
- # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
- new_data = {}
- for key, value in self.data.items():
- # this returns us id and dict with data and state
- new_data[key] = {key: value} # convert this to new
- # pass it to global data
- self.data = new_data
- self.update_data() # update data in file
+
+ async def _read_from_file(self) -> dict:
+ async with aiofiles.open(self.file_path, "rb") as f:
+ data = await f.read()
+ return pickle.loads(data)
+
+ async def _write_to_file(self, data: dict) -> None:
+ async with aiofiles.open(self.file_path, "wb") as f:
+ await f.write(pickle.dumps(data))
def create_dir(self):
"""
@@ -31,80 +77,198 @@ def create_dir(self):
dirs, filename = os.path.split(self.file_path)
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
- with open(self.file_path,'wb') as file:
+ with open(self.file_path, "wb") as file:
pickle.dump({}, file)
- def read(self):
- file = open(self.file_path, 'rb')
- data = pickle.load(file)
- file.close()
- return data
-
- def update_data(self):
- file = open(self.file_path, 'wb+')
- pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
- file.close()
-
- async def set_state(self, chat_id, user_id, state):
- if hasattr(state, 'name'):
- state = state.name
- if chat_id in self.data:
- if user_id in self.data[chat_id]:
- self.data[chat_id][user_id]['state'] = state
- self.update_data()
- return True
- else:
- self.data[chat_id][user_id] = {'state': state, 'data': {}}
- self.update_data()
- return True
- self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
- self.update_data()
+ @with_lock
+ async def set_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ if _key not in data:
+ data[_key] = {"state": state, "data": {}}
+ else:
+ data[_key]["state"] = state
+ await self._write_to_file(data)
return True
-
- async def delete_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- del self.data[chat_id][user_id]
- if chat_id == user_id:
- del self.data[chat_id]
- self.update_data()
- return True
+ @with_lock
+ async def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ return data.get(_key, {}).get("state")
+
+ @with_lock
+ async def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ if _key in data:
+ del data[_key]
+ await self._write_to_file(data)
+ return True
return False
-
- async def get_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['state']
-
- return None
- async def get_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['data']
-
- return None
-
- async def reset_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'] = {}
- self.update_data()
- return True
+ @with_lock
+ async def set_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ state_data = data.get(_key, {})
+ state_data["data"][key] = value
+ if _key not in data:
+ raise RuntimeError(f"StatePickleStorage: key {_key} does not exist.")
+ else:
+ data[_key]["data"][key] = value
+ await self._write_to_file(data)
+ return True
+
+ @with_lock
+ async def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ return data.get(_key, {}).get("data", {})
+
+ @with_lock
+ async def reset_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ if _key in data:
+ data[_key]["data"] = {}
+ await self._write_to_file(data)
+ return True
return False
- async def set_data(self, chat_id, user_id, key, value):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'][key] = value
- self.update_data()
- return True
- raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
- def get_interactive_data(self, chat_id, user_id):
- return StateContext(self, chat_id, user_id)
+ @with_lock
+ async def save(
+ self,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self._read_from_file()
+ data[_key]["data"] = data
+ await self._write_to_file(data)
+ return True
- async def save(self, chat_id, user_id, data):
- self.data[chat_id][user_id]['data'] = data
- self.update_data()
\ No newline at end of file
+ def __str__(self) -> str:
+ return f"StatePickleStorage({self.file_path}, {self.prefix})"
diff --git a/telebot/asyncio_storage/redis_storage.py b/telebot/asyncio_storage/redis_storage.py
index 84db253e5..08e255809 100644
--- a/telebot/asyncio_storage/redis_storage.py
+++ b/telebot/asyncio_storage/redis_storage.py
@@ -1,179 +1,339 @@
-from telebot.asyncio_storage.base_storage import StateStorageBase, StateContext
-import json
-
redis_installed = True
-is_actual_aioredis = False
try:
- import aioredis
- is_actual_aioredis = True
+ import redis
+ from redis.asyncio import Redis, ConnectionPool
except ImportError:
- try:
- from redis import asyncio as aioredis
- except ImportError:
- redis_installed = False
+ redis_installed = False
+
+import json
+from typing import Optional, Union, Callable, Coroutine
+import asyncio
+
+from telebot.asyncio_storage.base_storage import StateStorageBase, StateDataContext
+
+
+def async_with_lock(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
+ async def wrapper(self, *args, **kwargs):
+ async with self.lock:
+ return await func(self, *args, **kwargs)
+
+ return wrapper
+
+
+def async_with_pipeline(func: Callable[..., Coroutine]) -> Callable[..., Coroutine]:
+ async def wrapper(self, *args, **kwargs):
+ async with self.redis.pipeline() as pipe:
+ pipe.multi()
+ result = await func(self, pipe, *args, **kwargs)
+ await pipe.execute()
+ return result
+
+ return wrapper
class StateRedisStorage(StateStorageBase):
"""
- This class is for Redis storage.
- This will work only for states.
- To use it, just pass this class to:
- TeleBot(storage=StateRedisStorage())
+ State storage based on Redis.
+
+ .. code-block:: python3
+
+ storage = StateRedisStorage(...)
+ bot = AsyncTeleBot(token, storage=storage)
+
+ :param host: Redis host, default is "localhost".
+ :type host: str
+
+ :param port: Redis port, default is 6379.
+ :type port: int
+
+ :param db: Redis database, default is 0.
+ :type db: int
+
+ :param password: Redis password, default is None.
+ :type password: Optional[str]
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+
+ :param redis_url: Redis URL, default is None.
+ :type redis_url: Optional[str]
+
+ :param connection_pool: Redis connection pool, default is None.
+ :type connection_pool: Optional[ConnectionPool]
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+
"""
- def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None):
+
+ def __init__(
+ self,
+ host="localhost",
+ port=6379,
+ db=0,
+ password=None,
+ prefix="telebot",
+ redis_url=None,
+ connection_pool: "ConnectionPool" = None,
+ separator: Optional[str] = ":",
+ ) -> None:
+
if not redis_installed:
- raise ImportError('AioRedis is not installed. Install it via "pip install aioredis"')
+ raise ImportError("Please install redis using `pip install redis`")
+
+ self.separator = separator
+ self.prefix = prefix
+ if not self.prefix:
+ raise ValueError("Prefix cannot be empty")
- if is_actual_aioredis:
- aioredis_version = tuple(map(int, aioredis.__version__.split(".")[0]))
- if aioredis_version < (2,):
- raise ImportError('Invalid aioredis version. Aioredis version should be >= 2.0.0')
if redis_url:
- self.redis = aioredis.Redis.from_url(redis_url)
+ self.redis = redis.asyncio.from_url(redis_url)
+ elif connection_pool:
+ self.redis = Redis(connection_pool=connection_pool)
else:
- self.redis = aioredis.Redis(host=host, port=port, db=db, password=password)
+ self.redis = Redis(host=host, port=port, db=db, password=password)
- self.prefix = prefix
- #self.con = Redis(connection_pool=self.redis) -> use this when necessary
- #
- # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
-
- async def get_record(self, key):
- """
- Function to get record from database.
- It has nothing to do with states.
- Made for backward compatibility
- """
- result = await self.redis.get(self.prefix+str(key))
- if result: return json.loads(result)
- return
+ self.lock = asyncio.Lock()
+
+ @async_with_lock
+ @async_with_pipeline
+ async def set_state(
+ self,
+ pipe,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ if hasattr(state, "name"):
+ state = state.name
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ pipe.hget(_key, "data")
+ result = await pipe.execute()
+ data = result[0]
+ if data is None:
+ pipe.hset(_key, "data", json.dumps({}))
+
+ await pipe.hset(_key, "state", state)
- async def set_record(self, key, value):
- """
- Function to set record to database.
- It has nothing to do with states.
- Made for backward compatibility
- """
-
- await self.redis.set(self.prefix+str(key), json.dumps(value))
return True
- async def delete_record(self, key):
- """
- Function to delete record from database.
- It has nothing to do with states.
- Made for backward compatibility
- """
- await self.redis.delete(self.prefix+str(key))
+ async def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ state_bytes = await self.redis.hget(_key, "state")
+ return state_bytes.decode("utf-8") if state_bytes else None
+
+ async def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ result = await self.redis.delete(_key)
+ return result > 0
+
+ @async_with_lock
+ @async_with_pipeline
+ async def set_data(
+ self,
+ pipe,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await pipe.hget(_key, "data")
+ data = await pipe.execute()
+ data = data[0]
+ if data is None:
+ raise RuntimeError(f"StateRedisStorage: key {_key} does not exist.")
+ else:
+ data = json.loads(data)
+ data[key] = value
+ await pipe.hset(_key, "data", json.dumps(data))
return True
- async def set_state(self, chat_id, user_id, state):
- """
- Set state for a particular user in a chat.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if hasattr(state, 'name'):
- state = state.name
- if response:
- if user_id in response:
- response[user_id]['state'] = state
- else:
- response[user_id] = {'state': state, 'data': {}}
+ async def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = await self.redis.hget(_key, "data")
+ return json.loads(data) if data else {}
+
+ @async_with_lock
+ @async_with_pipeline
+ async def reset_data(
+ self,
+ pipe,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ if await pipe.exists(_key):
+ await pipe.hset(_key, "data", "{}")
else:
- response = {user_id: {'state': state, 'data': {}}}
- await self.set_record(chat_id, response)
+ return False
+ return True
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
+
+ @async_with_lock
+ @async_with_pipeline
+ async def save(
+ self,
+ pipe,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ if await pipe.exists(_key):
+ await pipe.hset(_key, "data", json.dumps(data))
+ else:
+ return False
return True
-
- async def delete_state(self, chat_id, user_id):
- """
- Delete state for a particular user in a chat.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- del response[user_id]
- if user_id == str(chat_id):
- await self.delete_record(chat_id)
- return True
- else: await self.set_record(chat_id, response)
- return True
- return False
-
- async def get_value(self, chat_id, user_id, key):
- """
- Get value for a data of a user in a chat.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- if key in response[user_id]['data']:
- return response[user_id]['data'][key]
- return None
-
- async def get_state(self, chat_id, user_id):
- """
- Get state of a user in a chat.
+
+ def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"):
"""
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- return response[user_id]['state']
+ Migrate from old to new format of keys.
+ Run this function once to migrate all redis existing keys to new format.
- return None
+ Starting from version 4.23.0, the format of keys has been changed:
+ :value
+ - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...}
+ - New format:
+ {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}}
- async def get_data(self, chat_id, user_id):
- """
- Get data of particular user in a particular chat.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- return response[user_id]['data']
- return None
-
- async def reset_data(self, chat_id, user_id):
- """
- Reset data of a user in a chat.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'] = {}
- await self.set_record(chat_id, response)
- return True
-
- async def set_data(self, chat_id, user_id, key, value):
- """
- Set data without interactive data.
- """
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'][key] = value
- await self.set_record(chat_id, response)
- return True
- return False
-
- def get_interactive_data(self, chat_id, user_id):
- """
- Get Data in interactive way.
- You can use with() with this function.
+ This function will help you to migrate from the old format to the new one in order to avoid data loss.
+
+ :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response.
+ :type bot_id: int
+
+ :param prefix: Prefix for keys, default is "telebot_"(old default value)
+ :type prefix: Optional[str]
"""
- return StateContext(self, chat_id, user_id)
-
- async def save(self, chat_id, user_id, data):
- response = await self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'] = data
- await self.set_record(chat_id, response)
- return True
+ keys = self.redis.keys(f"{prefix}*")
+
+ for key in keys:
+ old_key = key.decode("utf-8")
+ # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...}
+ value = self.redis.get(old_key)
+ value = json.loads(value)
+
+ chat_id = old_key[len(prefix) :]
+ user_id = list(value.keys())[0]
+ state = value[user_id]["state"]
+ state_data = value[user_id]["data"]
+
+ # set new format
+ new_key = self._get_key(
+ int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id
+ )
+ self.redis.hset(new_key, "state", state)
+ self.redis.hset(new_key, "data", json.dumps(state_data))
+
+ # delete old key
+ self.redis.delete(old_key)
+
+ def __str__(self) -> str:
+ # include some connection info
+ return f"StateRedisStorage({self.redis})"
diff --git a/telebot/custom_filters.py b/telebot/custom_filters.py
index a1c0aba14..2e91e55ea 100644
--- a/telebot/custom_filters.py
+++ b/telebot/custom_filters.py
@@ -4,6 +4,7 @@
from telebot import types
+from telebot.states import resolve_context
class SimpleCustomFilter(ABC):
"""
@@ -402,21 +403,11 @@ def check(self, message, text):
"""
:meta private:
"""
- if text == '*': return True
- # needs to work with callbackquery
- if isinstance(message, types.Message):
- chat_id = message.chat.id
- user_id = message.from_user.id
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(message, self.bot._user.id)
- if isinstance(message, types.CallbackQuery):
-
- chat_id = message.message.chat.id
- user_id = message.from_user.id
- message = message.message
-
-
-
+ if chat_id is None:
+ chat_id = user_id # May change in future
if isinstance(text, list):
new_text = []
@@ -427,21 +418,23 @@ def check(self, message, text):
elif isinstance(text, State):
text = text.name
- if message.chat.type in ['group', 'supergroup']:
- group_state = self.bot.current_states.get_state(chat_id, user_id)
- if group_state == text:
- return True
- elif isinstance(text, list) and group_state in text:
- return True
-
-
- else:
- user_state = self.bot.current_states.get_state(chat_id, user_id)
- if user_state == text:
- return True
- elif isinstance(text, list) and user_state in text:
- return True
-
+ user_state = self.bot.current_states.get_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ # CHANGED BEHAVIOUR
+ if text == "*" and user_state is not None:
+ return True
+
+ if user_state == text:
+ return True
+ elif type(text) is list and user_state in text:
+ return True
+ return False
class IsDigitFilter(SimpleCustomFilter):
"""
diff --git a/telebot/handler_backends.py b/telebot/handler_backends.py
index b95861e0b..12f8c89bf 100644
--- a/telebot/handler_backends.py
+++ b/telebot/handler_backends.py
@@ -9,6 +9,8 @@
except:
redis_installed = False
+# backward compatibility
+from telebot.states import State, StatesGroup
class HandlerBackend(object):
"""
@@ -160,45 +162,6 @@ def get_handlers(self, handler_group_id):
return handlers
-class State:
- """
- Class representing a state.
-
- .. code-block:: python3
-
- class MyStates(StatesGroup):
- my_state = State() # returns my_state:State string.
- """
- def __init__(self) -> None:
- self.name = None
- def __str__(self) -> str:
- return self.name
-
-
-class StatesGroup:
- """
- Class representing common states.
-
- .. code-block:: python3
-
- class MyStates(StatesGroup):
- my_state = State() # returns my_state:State string.
- """
- def __init_subclass__(cls) -> None:
- state_list = []
- for name, value in cls.__dict__.items():
- if not name.startswith('__') and not callable(value) and isinstance(value, State):
- # change value of that variable
- value.name = ':'.join((cls.__name__, name))
- value.group = cls
- state_list.append(value)
- cls._state_list = state_list
-
- @classmethod
- def state_list(self):
- return self._state_list
-
-
class BaseMiddleware:
"""
Base class for middleware.
@@ -292,4 +255,4 @@ def start2(message):
"""
def __init__(self) -> None:
- pass
+ pass
\ No newline at end of file
diff --git a/telebot/states/__init__.py b/telebot/states/__init__.py
new file mode 100644
index 000000000..2491608ee
--- /dev/null
+++ b/telebot/states/__init__.py
@@ -0,0 +1,128 @@
+"""
+Contains classes for states and state groups.
+"""
+
+from telebot import types
+
+
+class State:
+ """
+ Class representing a state.
+
+ .. code-block:: python3
+
+ class MyStates(StatesGroup):
+ my_state = State() # returns my_state:State string.
+ """
+
+ def __init__(self) -> None:
+ self.name: str = None
+ self.group: StatesGroup = None
+
+ def __str__(self) -> str:
+ return f"<{self.name}>"
+
+
+class StatesGroup:
+ """
+ Class representing common states.
+
+ .. code-block:: python3
+
+ class MyStates(StatesGroup):
+ my_state = State() # returns my_state:State string.
+ """
+
+ def __init_subclass__(cls) -> None:
+ state_list = []
+ for name, value in cls.__dict__.items():
+ if (
+ not name.startswith("__")
+ and not callable(value)
+ and isinstance(value, State)
+ ):
+ # change value of that variable
+ value.name = ":".join((cls.__name__, name))
+ value.group = cls
+ state_list.append(value)
+ cls._state_list = state_list
+
+ @classmethod
+ def state_list(self):
+ return self._state_list
+
+
+def resolve_context(message, bot_id: int) -> tuple:
+ # chat_id, user_id, business_connection_id, bot_id, message_thread_id
+
+ # message, edited_message, channel_post, edited_channel_post, business_message, edited_business_message
+ if isinstance(message, types.Message):
+ return (
+ message.chat.id,
+ message.from_user.id,
+ message.business_connection_id,
+ bot_id,
+ message.message_thread_id if message.is_topic_message else None,
+ )
+ elif isinstance(message, types.CallbackQuery): # callback_query
+ return (
+ message.message.chat.id,
+ message.from_user.id,
+ message.message.business_connection_id,
+ bot_id,
+ (
+ message.message.message_thread_id
+ if message.message.is_topic_message
+ else None
+ ),
+ )
+ elif isinstance(message, types.BusinessConnection): # business_connection
+ return (message.user_chat_id, message.user.id, message.id, bot_id, None)
+ elif isinstance(
+ message, types.BusinessMessagesDeleted
+ ): # deleted_business_messages
+ return (
+ message.chat.id,
+ message.chat.id,
+ message.business_connection_id,
+ bot_id,
+ None,
+ )
+ elif isinstance(message, types.MessageReactionUpdated): # message_reaction
+ return (message.chat.id, message.user.id, None, bot_id, None)
+ elif isinstance(
+ message, types.MessageReactionCountUpdated
+ ): # message_reaction_count
+ return (message.chat.id, None, None, bot_id, None)
+ elif isinstance(message, types.InlineQuery): # inline_query
+ return (None, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.ChosenInlineResult): # chosen_inline_result
+ return (None, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.ShippingQuery): # shipping_query
+ return (None, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.PreCheckoutQuery): # pre_checkout_query
+ return (None, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.PollAnswer): # poll_answer
+ return (None, message.user.id, None, bot_id, None)
+ elif isinstance(message, types.ChatMemberUpdated): # chat_member # my_chat_member
+ return (message.chat.id, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.ChatJoinRequest): # chat_join_request
+ return (message.chat.id, message.from_user.id, None, bot_id, None)
+ elif isinstance(message, types.ChatBoostRemoved): # removed_chat_boost
+ return (
+ message.chat.id,
+ message.source.user.id if message.source else None,
+ None,
+ bot_id,
+ None,
+ )
+ elif isinstance(message, types.ChatBoostUpdated): # chat_boost
+ return (
+ message.chat.id,
+ message.boost.source.user.id if message.boost.source else None,
+ None,
+ bot_id,
+ None,
+ )
+ else:
+ pass # not yet supported :(
diff --git a/telebot/states/asyncio/__init__.py b/telebot/states/asyncio/__init__.py
new file mode 100644
index 000000000..14ef19536
--- /dev/null
+++ b/telebot/states/asyncio/__init__.py
@@ -0,0 +1,7 @@
+from .context import StateContext
+from .middleware import StateMiddleware
+
+__all__ = [
+ "StateContext",
+ "StateMiddleware",
+]
diff --git a/telebot/states/asyncio/context.py b/telebot/states/asyncio/context.py
new file mode 100644
index 000000000..4c9ad61a3
--- /dev/null
+++ b/telebot/states/asyncio/context.py
@@ -0,0 +1,153 @@
+from telebot.states import State, StatesGroup
+from telebot.types import CallbackQuery, Message
+from telebot.async_telebot import AsyncTeleBot
+from telebot.states import resolve_context
+
+from typing import Union
+
+
+class StateContext:
+ """
+ Class representing a state context.
+
+ Passed through a middleware to provide easy way to set states.
+
+ .. code-block:: python3
+
+ @bot.message_handler(commands=['start'])
+ async def start_ex(message: types.Message, state_context: StateContext):
+ await state_context.set(MyStates.name)
+ await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id)
+ # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available.
+ """
+
+ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None:
+ self.message: Union[Message, CallbackQuery] = message
+ self.bot: AsyncTeleBot = bot
+ self.bot_id = self.bot.bot_id
+
+ async def set(self, state: Union[State, str]) -> bool:
+ """
+ Set state for current user.
+
+ :param state: State object or state name.
+ :type state: Union[State, str]
+
+ .. code-block:: python3
+
+ @bot.message_handler(commands=['start'])
+ async def start_ex(message: types.Message, state_context: StateContext):
+ await state_context.set(MyStates.name)
+ await bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id)
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ if isinstance(state, State):
+ state = state.name
+ return await self.bot.set_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ state=state,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ )
+
+ async def get(self) -> str:
+ """
+ Get current state for current user.
+
+ :return: Current state name.
+ :rtype: str
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ return await self.bot.get_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ )
+
+ async def delete(self) -> bool:
+ """
+ Deletes state and data for current user.
+
+ .. warning::
+
+ This method deletes state and associated data for current user.
+ """
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ return await self.bot.delete_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ )
+
+ async def reset_data(self) -> bool:
+ """
+ Reset data for current user.
+ State will not be changed.
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ return await self.bot.reset_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ )
+
+ def data(self) -> dict:
+ """
+ Get data for current user.
+
+ .. code-block:: python3
+
+ with state_context.data() as data:
+ print(data)
+ data['name'] = 'John'
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ return self.bot.retrieve_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ )
+
+ async def add_data(self, **kwargs) -> None:
+ """
+ Add data for current user.
+
+ :param kwargs: Data to add.
+ :type kwargs: dict
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = (
+ resolve_context(self.message, self.bot.bot_id)
+ )
+ return await self.bot.add_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ **kwargs
+ )
diff --git a/telebot/states/asyncio/middleware.py b/telebot/states/asyncio/middleware.py
new file mode 100644
index 000000000..675b7b462
--- /dev/null
+++ b/telebot/states/asyncio/middleware.py
@@ -0,0 +1,21 @@
+from telebot.asyncio_handler_backends import BaseMiddleware
+from telebot.async_telebot import AsyncTeleBot
+from telebot.states.sync.context import StateContext
+from telebot.util import update_types
+from telebot import types
+
+
+class StateMiddleware(BaseMiddleware):
+
+ def __init__(self, bot: AsyncTeleBot) -> None:
+ self.update_sensitive = False
+ self.update_types = update_types
+ self.bot: AsyncTeleBot = bot
+
+ async def pre_process(self, message, data):
+ state_context = StateContext(message, self.bot)
+ data["state_context"] = state_context
+ data["state"] = state_context # 2 ways to access state context
+
+ async def post_process(self, message, data, exception):
+ pass
diff --git a/telebot/states/sync/__init__.py b/telebot/states/sync/__init__.py
new file mode 100644
index 000000000..60f9ea2b7
--- /dev/null
+++ b/telebot/states/sync/__init__.py
@@ -0,0 +1,7 @@
+from .context import StateContext
+from .middleware import StateMiddleware
+
+__all__ = [
+ 'StateContext',
+ 'StateMiddleware',
+]
\ No newline at end of file
diff --git a/telebot/states/sync/context.py b/telebot/states/sync/context.py
new file mode 100644
index 000000000..c0611f6bb
--- /dev/null
+++ b/telebot/states/sync/context.py
@@ -0,0 +1,143 @@
+from telebot.states import State, StatesGroup
+from telebot.types import CallbackQuery, Message
+from telebot import TeleBot, types
+from telebot.states import resolve_context
+
+from typing import Union
+
+
+
+class StateContext():
+ """
+ Class representing a state context.
+
+ Passed through a middleware to provide easy way to set states.
+
+ .. code-block:: python3
+
+ @bot.message_handler(commands=['start'])
+ def start_ex(message: types.Message, state_context: StateContext):
+ state_context.set(MyStates.name)
+ bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id)
+ # also, state_context.data(), .add_data(), .reset_data(), .delete() methods available.
+ """
+
+ def __init__(self, message: Union[Message, CallbackQuery], bot: str) -> None:
+ self.message: Union[Message, CallbackQuery] = message
+ self.bot: TeleBot = bot
+ self.bot_id = self.bot.bot_id
+
+ def set(self, state: Union[State, str]) -> bool:
+ """
+ Set state for current user.
+
+ :param state: State object or state name.
+ :type state: Union[State, str]
+
+ .. code-block:: python3
+
+ @bot.message_handler(commands=['start'])
+ def start_ex(message: types.Message, state_context: StateContext):
+ state_context.set(MyStates.name)
+ bot.send_message(message.chat.id, 'Hi, write me a name', reply_to_message_id=message.message_id)
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ if isinstance(state, State):
+ state = state.name
+ return self.bot.set_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ state=state,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ def get(self) -> str:
+ """
+ Get current state for current user.
+
+ :return: Current state name.
+ :rtype: str
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ return self.bot.get_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ def delete(self) -> bool:
+ """
+ Deletes state and data for current user.
+
+ .. warning::
+
+ This method deletes state and associated data for current user.
+ """
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ return self.bot.delete_state(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ def reset_data(self) -> bool:
+ """
+ Reset data for current user.
+ State will not be changed.
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ return self.bot.reset_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ def data(self) -> dict:
+ """
+ Get data for current user.
+
+ .. code-block:: python3
+
+ with state_context.data() as data:
+ print(data)
+ data['name'] = 'John'
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ return self.bot.retrieve_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id
+ )
+
+ def add_data(self, **kwargs) -> None:
+ """
+ Add data for current user.
+
+ :param kwargs: Data to add.
+ :type kwargs: dict
+ """
+
+ chat_id, user_id, business_connection_id, bot_id, message_thread_id = resolve_context(self.message, self.bot.bot_id)
+ return self.bot.add_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ bot_id=bot_id,
+ message_thread_id=message_thread_id,
+ **kwargs
+ )
+
\ No newline at end of file
diff --git a/telebot/states/sync/middleware.py b/telebot/states/sync/middleware.py
new file mode 100644
index 000000000..4a252158b
--- /dev/null
+++ b/telebot/states/sync/middleware.py
@@ -0,0 +1,21 @@
+from telebot.handler_backends import BaseMiddleware
+from telebot import TeleBot
+from telebot.states.sync.context import StateContext
+from telebot.util import update_types
+from telebot import types
+
+
+class StateMiddleware(BaseMiddleware):
+
+ def __init__(self, bot: TeleBot) -> None:
+ self.update_sensitive = False
+ self.update_types = update_types
+ self.bot: TeleBot = bot
+
+ def pre_process(self, message, data):
+ state_context = StateContext(message, self.bot)
+ data['state_context'] = state_context
+ data['state'] = state_context # 2 ways to access state context
+
+ def post_process(self, message, data, exception):
+ pass
diff --git a/telebot/storage/__init__.py b/telebot/storage/__init__.py
index 59e2b058c..bb3dd7f07 100644
--- a/telebot/storage/__init__.py
+++ b/telebot/storage/__init__.py
@@ -1,13 +1,13 @@
from telebot.storage.memory_storage import StateMemoryStorage
from telebot.storage.redis_storage import StateRedisStorage
from telebot.storage.pickle_storage import StatePickleStorage
-from telebot.storage.base_storage import StateContext,StateStorageBase
-
-
-
+from telebot.storage.base_storage import StateDataContext, StateStorageBase
__all__ = [
- 'StateStorageBase', 'StateContext',
- 'StateMemoryStorage', 'StateRedisStorage', 'StatePickleStorage'
-]
\ No newline at end of file
+ "StateStorageBase",
+ "StateDataContext",
+ "StateMemoryStorage",
+ "StateRedisStorage",
+ "StatePickleStorage",
+]
diff --git a/telebot/storage/base_storage.py b/telebot/storage/base_storage.py
index 92b31ba85..9956c7ab9 100644
--- a/telebot/storage/base_storage.py
+++ b/telebot/storage/base_storage.py
@@ -1,5 +1,6 @@
import copy
+
class StateStorageBase:
def __init__(self) -> None:
pass
@@ -15,30 +16,30 @@ def get_data(self, chat_id, user_id):
Get data for a user in a particular chat.
"""
raise NotImplementedError
-
+
def set_state(self, chat_id, user_id, state):
"""
Set state for a particular user.
- ! Note that you should create a
- record if it does not exist, and
+ ! Note that you should create a
+ record if it does not exist, and
if a record with state already exists,
you need to update a record.
"""
raise NotImplementedError
-
+
def delete_state(self, chat_id, user_id):
"""
Delete state for a particular user.
"""
raise NotImplementedError
-
+
def reset_data(self, chat_id, user_id):
"""
Reset data for a particular user in a chat.
"""
raise NotImplementedError
-
+
def get_state(self, chat_id, user_id):
raise NotImplementedError
@@ -48,21 +49,70 @@ def get_interactive_data(self, chat_id, user_id):
def save(self, chat_id, user_id, data):
raise NotImplementedError
+ def _get_key(
+ self,
+ chat_id: int,
+ user_id: int,
+ prefix: str,
+ separator: str,
+ business_connection_id: str = None,
+ message_thread_id: int = None,
+ bot_id: int = None,
+ ) -> str:
+ """
+ Convert parameters to a key.
+ """
+ params = [prefix]
+ if bot_id:
+ params.append(str(bot_id))
+ if business_connection_id:
+ params.append(business_connection_id)
+ if message_thread_id:
+ params.append(str(message_thread_id))
+ params.append(str(chat_id))
+ params.append(str(user_id))
+
+ return separator.join(params)
-class StateContext:
+class StateDataContext:
"""
Class for data.
"""
- def __init__(self , obj, chat_id, user_id) -> None:
+
+ def __init__(
+ self,
+ obj,
+ chat_id,
+ user_id,
+ business_connection_id=None,
+ message_thread_id=None,
+ bot_id=None,
+ ):
self.obj = obj
- self.data = copy.deepcopy(obj.get_data(chat_id, user_id))
+ res = obj.get_data(
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
+ self.data = copy.deepcopy(res)
self.chat_id = chat_id
self.user_id = user_id
-
+ self.bot_id = bot_id
+ self.business_connection_id = business_connection_id
+ self.message_thread_id = message_thread_id
def __enter__(self):
return self.data
def __exit__(self, exc_type, exc_val, exc_tb):
- return self.obj.save(self.chat_id, self.user_id, self.data)
\ No newline at end of file
+ return self.obj.save(
+ self.chat_id,
+ self.user_id,
+ self.data,
+ self.business_connection_id,
+ self.message_thread_id,
+ self.bot_id,
+ )
diff --git a/telebot/storage/memory_storage.py b/telebot/storage/memory_storage.py
index 7d71c7ccd..21266632c 100644
--- a/telebot/storage/memory_storage.py
+++ b/telebot/storage/memory_storage.py
@@ -1,69 +1,225 @@
-from telebot.storage.base_storage import StateStorageBase, StateContext
+from telebot.storage.base_storage import StateStorageBase, StateDataContext
+from typing import Optional, Union
class StateMemoryStorage(StateStorageBase):
- def __init__(self) -> None:
- super().__init__()
- self.data = {}
- #
- # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
-
-
- def set_state(self, chat_id, user_id, state):
- if hasattr(state, 'name'):
+ """
+ Memory storage for states.
+
+ Stores states in memory as a dictionary.
+
+ .. code-block:: python3
+
+ storage = StateMemoryStorage()
+ bot = TeleBot(token, storage=storage)
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+ """
+
+ def __init__(
+ self, separator: Optional[str] = ":", prefix: Optional[str] = "telebot"
+ ) -> None:
+ self.separator = separator
+ self.prefix = prefix
+ if not self.prefix:
+ raise ValueError("Prefix cannot be empty")
+
+ self.data = (
+ {}
+ ) # key: telebot:bot_id:business_connection_id:message_thread_id:chat_id:user_id
+
+ def set_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ if hasattr(state, "name"):
state = state.name
- if chat_id in self.data:
- if user_id in self.data[chat_id]:
- self.data[chat_id][user_id]['state'] = state
- return True
- else:
- self.data[chat_id][user_id] = {'state': state, 'data': {}}
- return True
- self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ self.data[_key] = {"state": state, "data": {}}
+ else:
+ self.data[_key]["state"] = state
+
return True
-
- def delete_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- del self.data[chat_id][user_id]
- if chat_id == user_id:
- del self.data[chat_id]
-
- return True
-
- return False
-
-
- def get_state(self, chat_id, user_id):
-
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['state']
-
- return None
- def get_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['data']
-
- return None
-
- def reset_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'] = {}
- return True
- return False
-
- def set_data(self, chat_id, user_id, key, value):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'][key] = value
- return True
- raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
-
- def get_interactive_data(self, chat_id, user_id):
- return StateContext(self, chat_id, user_id)
-
- def save(self, chat_id, user_id, data):
- self.data[chat_id][user_id]['data'] = data
\ No newline at end of file
+
+ def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return None
+
+ return self.data[_key]["state"]
+
+ def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+
+ del self.data[_key]
+ return True
+
+ def set_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ raise RuntimeError(f"StateMemoryStorage: key {_key} does not exist.")
+ self.data[_key]["data"][key] = value
+ return True
+
+ def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ return self.data.get(_key, {}).get("data", {})
+
+ def reset_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+ self.data[_key]["data"] = {}
+ return True
+
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
+
+ def save(
+ self,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ if self.data.get(_key) is None:
+ return False
+ self.data[_key]["data"] = data
+ return True
+
+ def __str__(self) -> str:
+ return f""
diff --git a/telebot/storage/pickle_storage.py b/telebot/storage/pickle_storage.py
index 68c9fbed5..e11a05043 100644
--- a/telebot/storage/pickle_storage.py
+++ b/telebot/storage/pickle_storage.py
@@ -1,34 +1,63 @@
-from telebot.storage.base_storage import StateStorageBase, StateContext
import os
-
import pickle
+import threading
+from typing import Optional, Union, Callable
+from telebot.storage.base_storage import StateStorageBase, StateDataContext
+
+
+def with_lock(func: Callable) -> Callable:
+ def wrapper(self, *args, **kwargs):
+ with self.lock:
+ return func(self, *args, **kwargs)
+
+ return wrapper
class StatePickleStorage(StateStorageBase):
- def __init__(self, file_path="./.state-save/states.pkl") -> None:
- super().__init__()
+ """
+ State storage based on pickle file.
+
+ .. warning::
+
+ This storage is not recommended for production use.
+ Data may be corrupted. If you face a case where states do not work as expected,
+ try to use another storage.
+
+ .. code-block:: python3
+
+ storage = StatePickleStorage()
+ bot = TeleBot(token, storage=storage)
+
+ :param file_path: Path to file where states will be stored.
+ :type file_path: str
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+ """
+
+ def __init__(
+ self,
+ file_path: str = "./.state-save/states.pkl",
+ prefix="telebot",
+ separator: Optional[str] = ":",
+ ) -> None:
self.file_path = file_path
+ self.prefix = prefix
+ self.separator = separator
+ self.lock = threading.Lock()
+
self.create_dir()
- self.data = self.read()
- def convert_old_to_new(self):
- """
- Use this function to convert old storage to new storage.
- This function is for people who was using pickle storage
- that was in version <=4.3.1.
- """
- # old looks like:
- # {1: {'state': 'start', 'data': {'name': 'John'}}
- # we should update old version pickle to new.
- # new looks like:
- # {1: {2: {'state': 'start', 'data': {'name': 'John'}}}}
- new_data = {}
- for key, value in self.data.items():
- # this returns us id and dict with data and state
- new_data[key] = {key: value} # convert this to new
- # pass it to global data
- self.data = new_data
- self.update_data() # update data in file
+ def _read_from_file(self) -> dict:
+ with open(self.file_path, "rb") as f:
+ return pickle.load(f)
+
+ def _write_to_file(self, data: dict) -> None:
+ with open(self.file_path, "wb") as f:
+ pickle.dump(data, f)
def create_dir(self):
"""
@@ -37,80 +66,198 @@ def create_dir(self):
dirs, filename = os.path.split(self.file_path)
os.makedirs(dirs, exist_ok=True)
if not os.path.isfile(self.file_path):
- with open(self.file_path,'wb') as file:
+ with open(self.file_path, "wb") as file:
pickle.dump({}, file)
- def read(self):
- file = open(self.file_path, 'rb')
- data = pickle.load(file)
- file.close()
- return data
-
- def update_data(self):
- file = open(self.file_path, 'wb+')
- pickle.dump(self.data, file, protocol=pickle.HIGHEST_PROTOCOL)
- file.close()
-
- def set_state(self, chat_id, user_id, state):
- if hasattr(state, 'name'):
- state = state.name
- if chat_id in self.data:
- if user_id in self.data[chat_id]:
- self.data[chat_id][user_id]['state'] = state
- self.update_data()
- return True
- else:
- self.data[chat_id][user_id] = {'state': state, 'data': {}}
- self.update_data()
- return True
- self.data[chat_id] = {user_id: {'state': state, 'data': {}}}
- self.update_data()
+ @with_lock
+ def set_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ if _key not in data:
+ data[_key] = {"state": state, "data": {}}
+ else:
+ data[_key]["state"] = state
+ self._write_to_file(data)
return True
-
- def delete_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- del self.data[chat_id][user_id]
- if chat_id == user_id:
- del self.data[chat_id]
- self.update_data()
- return True
+ @with_lock
+ def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ return data.get(_key, {}).get("state")
+
+ @with_lock
+ def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ if _key in data:
+ del data[_key]
+ self._write_to_file(data)
+ return True
return False
-
- def get_state(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['state']
-
- return None
- def get_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- return self.data[chat_id][user_id]['data']
-
- return None
-
- def reset_data(self, chat_id, user_id):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'] = {}
- self.update_data()
- return True
+ @with_lock
+ def set_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ state_data = data.get(_key, {})
+ state_data["data"][key] = value
+
+ if _key not in data:
+ raise RuntimeError(f"PickleStorage: key {_key} does not exist.")
+
+ self._write_to_file(data)
+ return True
+
+ @with_lock
+ def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ return data.get(_key, {}).get("data", {})
+
+ @with_lock
+ def reset_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ if _key in data:
+ data[_key]["data"] = {}
+ self._write_to_file(data)
+ return True
return False
- def set_data(self, chat_id, user_id, key, value):
- if self.data.get(chat_id):
- if self.data[chat_id].get(user_id):
- self.data[chat_id][user_id]['data'][key] = value
- self.update_data()
- return True
- raise RuntimeError('chat_id {} and user_id {} does not exist'.format(chat_id, user_id))
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
- def get_interactive_data(self, chat_id, user_id):
- return StateContext(self, chat_id, user_id)
+ @with_lock
+ def save(
+ self,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self._read_from_file()
+ data[_key]["data"] = data
+ self._write_to_file(data)
+ return True
- def save(self, chat_id, user_id, data):
- self.data[chat_id][user_id]['data'] = data
- self.update_data()
+ def __str__(self) -> str:
+ return f"StatePickleStorage({self.file_path}, {self.prefix})"
diff --git a/telebot/storage/redis_storage.py b/telebot/storage/redis_storage.py
index 3fac57c46..c9935ac5e 100644
--- a/telebot/storage/redis_storage.py
+++ b/telebot/storage/redis_storage.py
@@ -1,183 +1,324 @@
-from telebot.storage.base_storage import StateStorageBase, StateContext
import json
+from telebot.storage.base_storage import StateStorageBase, StateDataContext
+from typing import Optional, Union
redis_installed = True
try:
- from redis import Redis, ConnectionPool
-
-except:
+ import redis
+except ImportError:
redis_installed = False
+
class StateRedisStorage(StateStorageBase):
"""
- This class is for Redis storage.
- This will work only for states.
- To use it, just pass this class to:
- TeleBot(storage=StateRedisStorage())
+ State storage based on Redis.
+
+ .. code-block:: python3
+
+ storage = StateRedisStorage(...)
+ bot = TeleBot(token, storage=storage)
+
+ :param host: Redis host, default is "localhost".
+ :type host: str
+
+ :param port: Redis port, default is 6379.
+ :type port: int
+
+ :param db: Redis database, default is 0.
+ :type db: int
+
+ :param password: Redis password, default is None.
+ :type password: Optional[str]
+
+ :param prefix: Prefix for keys, default is "telebot".
+ :type prefix: Optional[str]
+
+ :param redis_url: Redis URL, default is None.
+ :type redis_url: Optional[str]
+
+ :param connection_pool: Redis connection pool, default is None.
+ :type connection_pool: Optional[ConnectionPool]
+
+ :param separator: Separator for keys, default is ":".
+ :type separator: Optional[str]
+
"""
- def __init__(self, host='localhost', port=6379, db=0, password=None, prefix='telebot_', redis_url=None):
- super().__init__()
+
+ def __init__(
+ self,
+ host="localhost",
+ port=6379,
+ db=0,
+ password=None,
+ prefix="telebot",
+ redis_url=None,
+ connection_pool: "redis.ConnectionPool" = None,
+ separator: Optional[str] = ":",
+ ) -> None:
+
+ if not redis_installed:
+ raise ImportError(
+ "Redis is not installed. Please install it via pip install redis"
+ )
+
+ self.separator = separator
+ self.prefix = prefix
+ if not self.prefix:
+ raise ValueError("Prefix cannot be empty")
+
if redis_url:
- self.redis = ConnectionPool.from_url(redis_url)
+ self.redis = redis.Redis.from_url(redis_url)
+ elif connection_pool:
+ self.redis = redis.Redis(connection_pool=connection_pool)
else:
- self.redis = ConnectionPool(host=host, port=port, db=db, password=password)
- #self.con = Redis(connection_pool=self.redis) -> use this when necessary
- #
- # {chat_id: {user_id: {'state': None, 'data': {}}, ...}, ...}
- self.prefix = prefix
- if not redis_installed:
- raise Exception("Redis is not installed. Install it via 'pip install redis'")
-
- def get_record(self, key):
- """
- Function to get record from database.
- It has nothing to do with states.
- Made for backward compatibility
- """
- connection = Redis(connection_pool=self.redis)
- result = connection.get(self.prefix+str(key))
- connection.close()
- if result: return json.loads(result)
- return
+ self.redis = redis.Redis(host=host, port=port, db=db, password=password)
- def set_record(self, key, value):
- """
- Function to set record to database.
- It has nothing to do with states.
- Made for backward compatibility
- """
- connection = Redis(connection_pool=self.redis)
- connection.set(self.prefix+str(key), json.dumps(value))
- connection.close()
+ def set_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ state: str,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ if hasattr(state, "name"):
+ state = state.name
+
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ def set_state_action(pipe):
+ pipe.multi()
+
+ data = pipe.hget(_key, "data")
+ result = pipe.execute()
+ data = result[0]
+ if data is None:
+ # If data is None, set it to an empty dictionary
+ data = {}
+ pipe.hset(_key, "data", json.dumps(data))
+
+ pipe.hset(_key, "state", state)
+
+ self.redis.transaction(set_state_action, _key)
return True
- def delete_record(self, key):
- """
- Function to delete record from database.
- It has nothing to do with states.
- Made for backward compatibility
- """
- connection = Redis(connection_pool=self.redis)
- connection.delete(self.prefix+str(key))
- connection.close()
+ def get_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Union[str, None]:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ state_bytes = self.redis.hget(_key, "state")
+ return state_bytes.decode("utf-8") if state_bytes else None
+
+ def delete_state(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ return self.redis.delete(_key) > 0
+
+ def set_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ key: str,
+ value: Union[str, int, float, dict],
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+
+ def set_data_action(pipe):
+ pipe.multi()
+ data = pipe.hget(_key, "data")
+ data = data.execute()[0]
+ if data is None:
+ raise RuntimeError(f"RedisStorage: key {_key} does not exist.")
+ else:
+ data = json.loads(data)
+ data[key] = value
+ pipe.hset(_key, "data", json.dumps(data))
+
+ self.redis.transaction(set_data_action, _key)
return True
- def set_state(self, chat_id, user_id, state):
- """
- Set state for a particular user in a chat.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if hasattr(state, 'name'):
- state = state.name
+ def get_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> dict:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
+ data = self.redis.hget(_key, "data")
+ return json.loads(data) if data else {}
+
+ def reset_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
- if response:
- if user_id in response:
- response[user_id]['state'] = state
+ def reset_data_action(pipe):
+ pipe.multi()
+ if pipe.exists(_key):
+ pipe.hset(_key, "data", "{}")
else:
- response[user_id] = {'state': state, 'data': {}}
- else:
- response = {user_id: {'state': state, 'data': {}}}
- self.set_record(chat_id, response)
+ return False
+ self.redis.transaction(reset_data_action, _key)
return True
-
- def delete_state(self, chat_id, user_id):
- """
- Delete state for a particular user in a chat.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- del response[user_id]
- if user_id == str(chat_id):
- self.delete_record(chat_id)
- return True
- else: self.set_record(chat_id, response)
- return True
- return False
-
-
- def get_value(self, chat_id, user_id, key):
- """
- Get value for a data of a user in a chat.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- if key in response[user_id]['data']:
- return response[user_id]['data'][key]
- return None
-
-
- def get_state(self, chat_id, user_id):
- """
- Get state of a user in a chat.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- return response[user_id]['state']
- return None
+ def get_interactive_data(
+ self,
+ chat_id: int,
+ user_id: int,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> Optional[dict]:
+ return StateDataContext(
+ self,
+ chat_id=chat_id,
+ user_id=user_id,
+ business_connection_id=business_connection_id,
+ message_thread_id=message_thread_id,
+ bot_id=bot_id,
+ )
+ def save(
+ self,
+ chat_id: int,
+ user_id: int,
+ data: dict,
+ business_connection_id: Optional[str] = None,
+ message_thread_id: Optional[int] = None,
+ bot_id: Optional[int] = None,
+ ) -> bool:
+ _key = self._get_key(
+ chat_id,
+ user_id,
+ self.prefix,
+ self.separator,
+ business_connection_id,
+ message_thread_id,
+ bot_id,
+ )
- def get_data(self, chat_id, user_id):
- """
- Get data of particular user in a particular chat.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- return response[user_id]['data']
- return None
+ def save_action(pipe):
+ pipe.multi()
+ if pipe.exists(_key):
+ pipe.hset(_key, "data", json.dumps(data))
+ else:
+ return False
+ self.redis.transaction(save_action, _key)
+ return True
- def reset_data(self, chat_id, user_id):
- """
- Reset data of a user in a chat.
+ def migrate_format(self, bot_id: int, prefix: Optional[str] = "telebot_"):
"""
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'] = {}
- self.set_record(chat_id, response)
- return True
+ Migrate from old to new format of keys.
+ Run this function once to migrate all redis existing keys to new format.
-
+ Starting from version 4.23.0, the format of keys has been changed:
+ :value
+ - Old format: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...}
+ - New format:
+ {prefix}{separator}{bot_id}{separator}{business_connection_id}{separator}{message_thread_id}{separator}{chat_id}{separator}{user_id}: {'state': ..., 'data': {}}
+ This function will help you to migrate from the old format to the new one in order to avoid data loss.
- def set_data(self, chat_id, user_id, key, value):
- """
- Set data without interactive data.
- """
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'][key] = value
- self.set_record(chat_id, response)
- return True
- return False
-
- def get_interactive_data(self, chat_id, user_id):
- """
- Get Data in interactive way.
- You can use with() with this function.
+ :param bot_id: Bot ID; To get it, call a getMe request and grab the id from the response.
+ :type bot_id: int
+
+ :param prefix: Prefix for keys, default is "telebot_"(old default value)
+ :type prefix: Optional[str]
"""
- return StateContext(self, chat_id, user_id)
-
- def save(self, chat_id, user_id, data):
- response = self.get_record(chat_id)
- user_id = str(user_id)
- if response:
- if user_id in response:
- response[user_id]['data'] = data
- self.set_record(chat_id, response)
- return True
-
+ keys = self.redis.keys(f"{prefix}*")
+
+ for key in keys:
+ old_key = key.decode("utf-8")
+ # old: {prefix}chat_id: {user_id: {'state': None, 'data': {}}, ...}
+ value = self.redis.get(old_key)
+ value = json.loads(value)
+
+ chat_id = old_key[len(prefix) :]
+ user_id = list(value.keys())[0]
+ state = value[user_id]["state"]
+ state_data = value[user_id]["data"]
+
+ # set new format
+ new_key = self._get_key(
+ int(chat_id), int(user_id), self.prefix, self.separator, bot_id=bot_id
+ )
+ self.redis.hset(new_key, "state", state)
+ self.redis.hset(new_key, "data", json.dumps(state_data))
+
+ # delete old key
+ self.redis.delete(old_key)
+
+ def __str__(self) -> str:
+ return f"StateRedisStorage({self.redis})"
diff --git a/telebot/util.py b/telebot/util.py
index 295c0d1aa..0448893e1 100644
--- a/telebot/util.py
+++ b/telebot/util.py
@@ -686,6 +686,26 @@ def validate_web_app_data(token: str, raw_init_data: str):
return hmac.new(secret_key.digest(), data_check_string.encode(), sha256).hexdigest() == init_data_hash
+def validate_token(token) -> bool:
+ if any(char.isspace() for char in token):
+ raise ValueError('Token must not contain spaces')
+
+ if ':' not in token:
+ raise ValueError('Token must contain a colon')
+
+ if len(token.split(':')) != 2:
+ raise ValueError('Token must contain exactly 2 parts separated by a colon')
+
+ return True
+
+def extract_bot_id(token) -> Union[int, None]:
+ try:
+ validate_token(token)
+ except ValueError:
+ return None
+ return int(token.split(':')[0])
+
+
__all__ = (
"content_type_media", "content_type_service", "update_types",
"WorkerThread", "AsyncTask", "CustomRequestResponse",
@@ -696,5 +716,5 @@ def validate_web_app_data(token: str, raw_init_data: str):
"split_string", "smart_split", "escape", "user_link", "quick_markup",
"antiflood", "parse_web_app_data", "validate_web_app_data",
"or_set", "or_clear", "orify", "OrEvent", "per_thread",
- "webhook_google_functions"
+ "webhook_google_functions", "validate_token", "extract_bot_id"
)
diff --git a/tests/test_handler_backends.py b/tests/test_handler_backends.py
index bb541bf8a..b74d6fad0 100644
--- a/tests/test_handler_backends.py
+++ b/tests/test_handler_backends.py
@@ -19,7 +19,7 @@
@pytest.fixture()
def telegram_bot():
- return telebot.TeleBot('', threaded=False)
+ return telebot.TeleBot('1234:test', threaded=False)
@pytest.fixture