Skip to content

Commit

Permalink
feat: 限速算法的扩展性
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Mar 12, 2024
1 parent b9fa11c commit 13393b6
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
15 changes: 15 additions & 0 deletions pkg/pipeline/ratelimit/algo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from __future__ import annotations
import abc
import typing

from ...core import app


preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []

def algo_class(name: str):

def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls

return decorator


class ReteLimitAlgo(metaclass=abc.ABCMeta):

name: str = None

ap: app.Application

def __init__(self, ap: app.Application):
Expand Down
1 change: 1 addition & 0 deletions pkg/pipeline/ratelimit/algos/fixedwin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self):
self.records = {}


@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):

containers_lock: asyncio.Lock
Expand Down
14 changes: 13 additions & 1 deletion pkg/pipeline/ratelimit/ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo

async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)

algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']

algo_class = None

for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')

self.algo = algo_class(self.ap)
await self.algo.initialize()

async def process(
Expand Down
6 changes: 4 additions & 2 deletions pkg/provider/sysprompt/sysprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ async def initialize(self):

mode_name = self.ap.provider_cfg.data['prompt-mode']

loader_class = None

for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_cls = loader_cls
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')

self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
self.loader_inst: loader.PromptLoader = loader_class(self.ap)

await self.loader_inst.initialize()
await self.loader_inst.load()
Expand Down

0 comments on commit 13393b6

Please sign in to comment.