Skip to content

Commit

Permalink
bot, rules: fix rate limiting rules without a rate limit
Browse files Browse the repository at this point in the history
Co-authored-by: dgw <[email protected]>
  • Loading branch information
Exirel and dgw committed Oct 18, 2024
1 parent 0e5c976 commit 94e2c92
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 77 deletions.
18 changes: 7 additions & 11 deletions sopel/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,30 +602,26 @@ def rate_limit_info(
if trigger.admin or rule.is_unblockable():
return False, None

nick = trigger.nick
is_channel = trigger.sender and not trigger.sender.is_nick()
channel = trigger.sender if is_channel else None

at_time = trigger.time

user_metrics = rule.get_user_metrics(trigger.nick)
channel_metrics = rule.get_channel_metrics(channel)
global_metrics = rule.get_global_metrics()

if user_metrics.is_limited(at_time - rule.user_rate_limit):
if rule.is_user_rate_limited(nick, at_time):
template = rule.user_rate_template
rate_limit_type = "user"
rate_limit = rule.user_rate_limit
metrics = user_metrics
elif is_channel and channel_metrics.is_limited(at_time - rule.channel_rate_limit):
metrics = rule.get_user_metrics(nick)
elif channel and rule.is_channel_rate_limited(channel, at_time):
template = rule.channel_rate_template
rate_limit_type = "channel"
rate_limit = rule.channel_rate_limit
metrics = channel_metrics
elif global_metrics.is_limited(at_time - rule.global_rate_limit):
metrics = rule.get_channel_metrics(channel)
elif rule.is_global_rate_limited(at_time):
template = rule.global_rate_template
rate_limit_type = "global"
rate_limit = rule.global_rate_limit
metrics = global_metrics
metrics = rule.get_global_metrics()
else:
return False, None

Expand Down
57 changes: 33 additions & 24 deletions sopel/plugins/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,40 +765,49 @@ def global_rate_limit(self) -> datetime.timedelta:
def is_user_rate_limited(
self,
nick: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
"""Tell when the rule reached the ``nick``'s rate limit.
:param nick: the nick associated with this check
:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.
.. versionchanged:: 8.0.1
Parameter ``at_time`` is now required.
"""

@abc.abstractmethod
def is_channel_rate_limited(
self,
channel: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
"""Tell when the rule reached the ``channel``'s rate limit.
:param channel: the channel associated with this check
:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.
.. versionchanged:: 8.0.1
Parameter ``at_time`` is now required.
"""

@abc.abstractmethod
def is_global_rate_limited(
self,
at_time: Optional[datetime.datetime] = None,
) -> bool:
def is_global_rate_limited(self, at_time: datetime.datetime) -> bool:
"""Tell when the rule reached the global rate limit.
:param at_time: optional aware datetime for the rate limit check;
if not given, ``utcnow`` will be used
:param at_time: aware datetime for the rate limit check
:return: ``True`` when the rule reached the limit, ``False`` otherwise.
.. versionchanged:: 8.0.1
Parameter ``at_time`` is now required.
"""

@property
Expand Down Expand Up @@ -1209,29 +1218,29 @@ def global_rate_limit(self) -> datetime.timedelta:
def is_user_rate_limited(
self,
nick: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
if self._user_rate_limit <= 0:
return False

metrics = self.get_user_metrics(nick)
return metrics.is_limited(at_time - self.user_rate_limit)

def is_channel_rate_limited(
self,
channel: Identifier,
at_time: Optional[datetime.datetime] = None,
at_time: datetime.datetime,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
if self._channel_rate_limit <= 0:
return False

metrics = self.get_channel_metrics(channel)
return metrics.is_limited(at_time - self.channel_rate_limit)

def is_global_rate_limited(
self,
at_time: Optional[datetime.datetime] = None,
) -> bool:
if at_time is None:
at_time = datetime.datetime.now(datetime.timezone.utc)
def is_global_rate_limited(self, at_time: datetime.datetime) -> bool:
if self._global_rate_limit <= 0:
return False

metrics = self.get_global_metrics()
return metrics.is_limited(at_time - self.global_rate_limit)

Expand Down
42 changes: 24 additions & 18 deletions test/plugins/test_plugins_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,14 +1566,16 @@ def handler(bot, trigger):
global_rate_limit=20,
channel_rate_limit=20,
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is True
assert rule.is_channel_rate_limited(mocktrigger.sender) is True
assert rule.is_global_rate_limited() is True
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is True
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is True
assert rule.is_global_rate_limited(at_time) is True


def test_rule_rate_limit_no_limit(mockbot, triggerfactory):
Expand All @@ -1592,14 +1594,16 @@ def handler(bot, trigger):
global_rate_limit=0,
channel_rate_limit=0,
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False


def test_rule_rate_limit_ignore_rate_limit(mockbot, triggerfactory):
Expand All @@ -1619,14 +1623,16 @@ def handler(bot, trigger):
channel_rate_limit=20,
threaded=False, # make sure there is no race-condition here
)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False

rule.execute(mockbot, mocktrigger)
assert rule.is_user_rate_limited(mocktrigger.nick) is False
assert rule.is_channel_rate_limited(mocktrigger.sender) is False
assert rule.is_global_rate_limited() is False
at_time = datetime.datetime.now(datetime.timezone.utc)
assert rule.is_user_rate_limited(mocktrigger.nick, at_time) is False
assert rule.is_channel_rate_limited(mocktrigger.sender, at_time) is False
assert rule.is_global_rate_limited(at_time) is False


def test_rule_rate_limit_messages(mockbot, triggerfactory):
Expand Down
104 changes: 80 additions & 24 deletions test/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

if typing.TYPE_CHECKING:
from sopel.config import Config
from sopel.tests.factories import BotFactory, IRCFactory, UserFactory
from sopel.tests.factories import (
BotFactory, ConfigFactory, IRCFactory, TriggerFactory, UserFactory,
)
from sopel.tests.mocks import MockIRCServer


Expand Down Expand Up @@ -81,17 +83,17 @@ def ignored():


@pytest.fixture
def tmpconfig(configfactory):
def tmpconfig(configfactory: ConfigFactory) -> Config:
return configfactory('test.cfg', TMP_CONFIG)


@pytest.fixture
def mockbot(tmpconfig, botfactory):
def mockbot(tmpconfig: Config, botfactory: BotFactory) -> bot.Sopel:
return botfactory(tmpconfig)


@pytest.fixture
def mockplugin(tmpdir):
def mockplugin(tmpdir) -> plugins.handlers.PyFilePlugin:
root = tmpdir.mkdir('loader_mods')
mod_file = root.join('mockplugin.py')
mod_file.write(MOCK_MODULE_CONTENT)
Expand Down Expand Up @@ -676,7 +678,7 @@ def url_callback_http(bot, trigger, match):
# call_rule

@pytest.fixture
def match_hello_rule(mockbot, triggerfactory):
def match_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory):
"""Helper for generating matches to each `Rule` in the following tests"""
def _factory(rule_hello):
# trigger
Expand All @@ -694,7 +696,25 @@ def _factory(rule_hello):
return _factory


def test_call_rule(mockbot, match_hello_rule):
@pytest.fixture
def multimatch_hello_rule(mockbot: bot.Sopel, triggerfactory: TriggerFactory):
def _factory(rule_hello):
# trigger
line = ':[email protected] PRIVMSG #channel :hello hello hello'

trigger = triggerfactory(mockbot, line)
pretrigger = trigger._pretrigger

for match in rule_hello.match(mockbot, pretrigger):
wrapper = bot.SopelWrapper(mockbot, trigger)
yield match, trigger, wrapper
return _factory


def test_call_rule(
mockbot: bot.Sopel,
match_hello_rule: typing.Callable,
) -> None:
# setup
items = []

Expand All @@ -721,9 +741,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is not rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand All @@ -738,6 +759,36 @@ def testrule(bot, trigger):
assert items == [1, 1]


def test_call_rule_multiple_matches(
mockbot: bot.Sopel,
multimatch_hello_rule: typing.Callable,
) -> None:
# setup
items = []

def testrule(bot, trigger):
bot.say('hi')
items.append(1)
return "Return Value"

find_hello = rules.FindRule(
[re.compile(r'(hi|hello|hey|sup)')],
plugin='testplugin',
label='testrule',
handler=testrule)

for match, rule_trigger, wrapper in multimatch_hello_rule(find_hello):
mockbot.call_rule(find_hello, wrapper, rule_trigger)

# assert the rule has been executed three times now
assert mockbot.backend.message_sent == rawlist(
'PRIVMSG #channel :hi',
'PRIVMSG #channel :hi',
'PRIVMSG #channel :hi',
)
assert items == [1, 1, 1]


def test_call_rule_rate_limited_user(mockbot, match_hello_rule):
items = []

Expand Down Expand Up @@ -767,9 +818,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -852,9 +904,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -897,9 +950,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert rule_hello.is_channel_rate_limited('#channel')
assert not rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert rule_hello.is_channel_rate_limited('#channel', at_time)
assert not rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -942,9 +996,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down Expand Up @@ -987,9 +1042,10 @@ def testrule(bot, trigger):
assert items == [1]

# assert the rule is now rate limited
assert not rule_hello.is_user_rate_limited(Identifier('Test'))
assert not rule_hello.is_channel_rate_limited('#channel')
assert rule_hello.is_global_rate_limited()
at_time = datetime.now(timezone.utc)
assert not rule_hello.is_user_rate_limited(Identifier('Test'), at_time)
assert not rule_hello.is_channel_rate_limited('#channel', at_time)
assert rule_hello.is_global_rate_limited(at_time)

match, rule_trigger, wrapper = match_hello_rule(rule_hello)

Expand Down

0 comments on commit 94e2c92

Please sign in to comment.