From 13393b66242519ab385075264d923ee98bb3df49 Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Tue, 12 Mar 2024 16:31:54 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=99=90=E9=80=9F=E7=AE=97=E6=B3=95?= =?UTF-8?q?=E7=9A=84=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/ratelimit/algo.py | 15 +++++++++++++++ pkg/pipeline/ratelimit/algos/fixedwin.py | 1 + pkg/pipeline/ratelimit/ratelimit.py | 14 +++++++++++++- pkg/provider/sysprompt/sysprompt.py | 6 ++++-- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/pkg/pipeline/ratelimit/algo.py b/pkg/pipeline/ratelimit/algo.py index b6d9ba7b..448ae384 100644 --- a/pkg/pipeline/ratelimit/algo.py +++ b/pkg/pipeline/ratelimit/algo.py @@ -1,11 +1,26 @@ from __future__ import annotations import abc +import typing from ...core import app +preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] + +def algo_class(name: str): + + def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]: + cls.name = name + preregistered_algos.append(cls) + return cls + + return decorator + + class ReteLimitAlgo(metaclass=abc.ABCMeta): + name: str = None + ap: app.Application def __init__(self, ap: app.Application): diff --git a/pkg/pipeline/ratelimit/algos/fixedwin.py b/pkg/pipeline/ratelimit/algos/fixedwin.py index bb69b0dd..aa380291 100644 --- a/pkg/pipeline/ratelimit/algos/fixedwin.py +++ b/pkg/pipeline/ratelimit/algos/fixedwin.py @@ -19,6 +19,7 @@ def __init__(self): self.records = {} +@algo.algo_class("fixwin") class FixedWindowAlgo(algo.ReteLimitAlgo): containers_lock: asyncio.Lock diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index a9e29799..f43c8b06 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage): algo: algo.ReteLimitAlgo async def initialize(self): - self.algo = fixedwin.FixedWindowAlgo(self.ap) + + algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo'] + + algo_class = None + + for algo_cls in algo.preregistered_algos: + if algo_cls.name == algo_name: + algo_class = algo_cls + break + else: + raise ValueError(f'未知的限速算法: {algo_name}') + + self.algo = algo_class(self.ap) await self.algo.initialize() async def process( diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index 61c598ed..c7695f5a 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -22,14 +22,16 @@ async def initialize(self): mode_name = self.ap.provider_cfg.data['prompt-mode'] + loader_class = None + for loader_cls in loader.preregistered_loaders: if loader_cls.name == mode_name: - loader_cls = loader_cls + loader_class = loader_cls break else: raise ValueError(f'未知的 Prompt 加载器: {mode_name}') - self.loader_inst: loader.PromptLoader = loader_cls(self.ap) + self.loader_inst: loader.PromptLoader = loader_class(self.ap) await self.loader_inst.initialize() await self.loader_inst.load()