diff --git a/pkg/platform/ratelim/__init__.py b/pkg/pipeline/ratelimit/__init__.py similarity index 100% rename from pkg/platform/ratelim/__init__.py rename to pkg/pipeline/ratelimit/__init__.py diff --git a/pkg/platform/ratelim/algo.py b/pkg/pipeline/ratelimit/algo.py similarity index 100% rename from pkg/platform/ratelim/algo.py rename to pkg/pipeline/ratelimit/algo.py diff --git a/pkg/platform/ratelim/algos/__init__.py b/pkg/pipeline/ratelimit/algos/__init__.py similarity index 100% rename from pkg/platform/ratelim/algos/__init__.py rename to pkg/pipeline/ratelimit/algos/__init__.py diff --git a/pkg/platform/ratelim/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py similarity index 100% rename from pkg/platform/ratelim/algos/fixedwin.py rename to pkg/pipeline/ratelimit/algos/fixedwin.py diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py new file mode 100644 index 00000000..2e3c6706 --- /dev/null +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing + +from .. import entities, stagemgr, stage +from . import algo +from .algos import fixedwin +from ...core import entities as core_entities + + +@stage.stage_class("RequireRateLimitOccupancy") +@stage.stage_class("ReleaseRateLimitOccupancy") +class RateLimit(stage.PipelineStage): + + algo: algo.ReteLimitAlgo + + async def initialize(self): + self.algo = fixedwin.FixedWindowAlgo(self.ap) + await self.algo.initialize() + + async def process( + self, + query: core_entities.Query, + stage_inst_name: str, + ) -> typing.Union[ + entities.StageProcessResult, + typing.AsyncGenerator[entities.StageProcessResult, None], + ]: + """处理 + """ + if stage_inst_name == "RequireRateLimitOccupancy": + if await self.algo.require_access( + query.launcher_type.value, + query.launcher_id, + ): + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) + else: + return entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query, + console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息", + user_notice=self.ap.tips_mgr.data['rate_limit_drop_tip'] + ) + elif stage_inst_name == "ReleaseRateLimitOccupancy": + await self.algo.release_access( + query.launcher_type, + query.launcher_id, + ) + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query, + ) diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index b3faaf9f..c855d816 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -12,6 +12,7 @@ from .respback import respback from .wrapper import wrapper from .preproc import preproc +from .ratelimit import ratelimit stage_order = [ @@ -19,7 +20,9 @@ "BanSessionCheckStage", "PreContentFilterStage", "PreProcessor", + "RequireRateLimitOccupancy", "MessageProcessor", + "ReleaseRateLimitOccupancy", "PostContentFilterStage", "ResponseWrapper", "LongTextProcessStage", diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index 384432f9..277d8ff5 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -9,13 +9,7 @@ from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ FriendMessage, Image, MessageChain, Plain import mirai -import func_timeout - -from ..provider import session as openai_session - -import tips as tips_custom from ..platform import adapter as msadapter -from .ratelim import ratelim from ..core import app, entities as core_entities from ..plugin import events @@ -31,15 +25,11 @@ class PlatformManager: # modern ap: app.Application = None - ratelimiter: ratelim.RateLimiter = None - def __init__(self, ap: app.Application = None): self.ap = ap - self.ratelimiter = ratelim.RateLimiter(ap) async def initialize(self): - await self.ratelimiter.initialize() config = self.ap.cfg_mgr.data diff --git a/pkg/platform/ratelim/ratelim.py b/pkg/platform/ratelim/ratelim.py deleted file mode 100644 index 68fe0316..00000000 --- a/pkg/platform/ratelim/ratelim.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -from . import algo -from .algos import fixedwin -from ...core import app - - -class RateLimiter: - """限速器 - """ - - ap: app.Application - - algo: algo.ReteLimitAlgo - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - self.algo = fixedwin.FixedWindowAlgo(self.ap) - await self.algo.initialize() - - async def require(self, launcher_type: str, launcher_id: int) -> bool: - """请求访问 - """ - return await self.algo.require_access(launcher_type, launcher_id) - - async def release(self, launcher_type: str, launcher_id: int): - """释放访问 - """ - return await self.algo.release_access(launcher_type, launcher_id) \ No newline at end of file