Skip to content

Commit

Permalink
Merge pull request #921 from RockChinQ/feat/authenticating
Browse files Browse the repository at this point in the history
Feat: 用户鉴权
  • Loading branch information
RockChinQ authored Nov 17, 2024
2 parents 036c218 + 20e3edb commit 1a457be
Show file tree
Hide file tree
Showing 23 changed files with 549 additions and 108 deletions.
37 changes: 29 additions & 8 deletions pkg/api/http/controller/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import typing
import enum
import quart
from quart.typing import RouteCallable

Expand All @@ -23,6 +24,12 @@ def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
return decorator


class AuthType(enum.Enum):
"""认证类型"""
NONE = 'none'
USER_TOKEN = 'user-token'


class RouterGroup(abc.ABC):

name: str
Expand All @@ -41,13 +48,30 @@ def __init__(self, ap: app.Application, quart_app: quart.Quart) -> None:
async def initialize(self) -> None:
pass

def route(self, rule: str, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
def route(self, rule: str, auth_type: AuthType = AuthType.USER_TOKEN, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule

async def handler_error(*args, **kwargs):

if auth_type == AuthType.USER_TOKEN:
# 从Authorization头中获取token
token = quart.request.headers.get('Authorization', '').replace('Bearer ', '')

if not token:
return self.http_status(401, -1, '未提供有效的用户令牌')

try:
user_email = await self.ap.user_service.verify_jwt_token(token)

# 检查f是否接受user_email参数
if 'user_email' in f.__code__.co_varnames:
kwargs['user_email'] = user_email
except Exception as e:
return self.http_status(401, -1, str(e))

try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
Expand All @@ -61,25 +85,22 @@ async def handler_error(*args, **kwargs):
return f

return decorator

def _cors(self, response: quart.Response) -> quart.Response:
return response

def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return self._cors(quart.jsonify({
return quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
}))
})

def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""

return self._cors(quart.jsonify({
return quart.jsonify({
'code': code,
'msg': msg,
}))
})

def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/http/controller/groups/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class SystemRouterGroup(group.RouterGroup):

async def initialize(self) -> None:
@self.route('/info', methods=['GET'])
@self.route('/info', methods=['GET'], auth_type=group.AuthType.NONE)
async def _() -> str:
return self.success(
data={
Expand Down
43 changes: 43 additions & 0 deletions pkg/api/http/controller/groups/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import quart
import sqlalchemy

from .. import group
from .....persistence.entities import user


@group.group_class('user', '/api/v1/user')
class UserRouterGroup(group.RouterGroup):

async def initialize(self) -> None:
@self.route('/init', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
if quart.request.method == 'GET':
return self.success(data={
'initialized': await self.ap.user_service.is_initialized()
})

if await self.ap.user_service.is_initialized():
return self.fail(1, '系统已初始化')

json_data = await quart.request.json

user_email = json_data['user']
password = json_data['password']

await self.ap.user_service.create_user(user_email, password)

return self.success()

@self.route('/auth', methods=['POST'], auth_type=group.AuthType.NONE)
async def _() -> str:
json_data = await quart.request.json

token = await self.ap.user_service.authenticate(json_data['user'], json_data['password'])

return self.success(data={
'token': token
})

@self.route('/check-token', methods=['GET'])
async def _() -> str:
return self.success()
2 changes: 1 addition & 1 deletion pkg/api/http/controller/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import quart_cors

from ....core import app, entities as core_entities
from .groups import logs, system, settings, plugins, stats
from .groups import logs, system, settings, plugins, stats, user
from . import group


Expand Down
74 changes: 74 additions & 0 deletions pkg/api/http/service/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

import sqlalchemy
import argon2
import jwt
import datetime

from ....core import app
from ....persistence.entities import user
from ....utils import constants


class UserService:

ap: app.Application

def __init__(self, ap: app.Application) -> None:
self.ap = ap

async def is_initialized(self) -> bool:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).limit(1)
)

result_list = result.all()
return result_list is not None and len(result_list) > 0

async def create_user(self, user_email: str, password: str) -> None:
ph = argon2.PasswordHasher()

hashed_password = ph.hash(password)

await self.ap.persistence_mgr.execute_async(
sqlalchemy.insert(user.User).values(
user=user_email,
password=hashed_password
)
)

async def authenticate(self, user_email: str, password: str) -> str | None:
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(user.User).where(user.User.user == user_email)
)

result_list = result.all()

if result_list is None or len(result_list) == 0:
raise ValueError('用户不存在')

user_obj = result_list[0]

ph = argon2.PasswordHasher()

if not ph.verify(user_obj.password, password):
raise ValueError('密码错误')

return await self.generate_jwt_token(user_email)

async def generate_jwt_token(self, user_email: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']
jwt_expire = self.ap.system_cfg.data['http-api']['jwt-expire']

payload = {
'user': user_email,
'iss': 'LangBot-'+constants.edition,
'exp': datetime.datetime.now() + datetime.timedelta(seconds=jwt_expire)
}

return jwt.encode(payload, jwt_secret, algorithm='HS256')

async def verify_jwt_token(self, token: str) -> str:
jwt_secret = self.ap.instance_secret_meta.data['jwt_secret']

return jwt.decode(token, jwt_secret, algorithms=['HS256'])['user']
7 changes: 7 additions & 0 deletions pkg/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..api.http.service import user as user_service
from ..utils import logcache, ip
from . import taskmgr
from . import entities as core_entities
Expand Down Expand Up @@ -74,6 +75,8 @@ class Application:

llm_models_meta: config_mgr.ConfigManager = None

instance_secret_meta: config_mgr.ConfigManager = None

# =========================

ctr_mgr: center_mgr.V2CenterAPI = None
Expand All @@ -100,6 +103,10 @@ class Application:

log_cache: logcache.LogCache = None

# ========= HTTP Services =========

user_service: user_service.UserService = None

def __init__(self):
pass

Expand Down
2 changes: 2 additions & 0 deletions pkg/core/bootutils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
"argon2": "argon2-cffi",
"jwt": "pyjwt",
}


Expand Down
3 changes: 2 additions & 1 deletion pkg/core/migrations/m013_http_api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ async def run(self):
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300
"port": 5300,
"jwt-expire": 604800
}

self.ap.system_cfg.data['persistence'] = {
Expand Down
4 changes: 4 additions & 0 deletions pkg/core/stages/build_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
from ...api.http.service import user as user_service
from ...utils import logcache
from .. import taskmgr

Expand Down Expand Up @@ -112,5 +113,8 @@ async def run(self, ap: app.Application):
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl

user_service_inst = user_service.UserService(ap)
ap.user_service = user_service_inst

ctrl = controller.Controller(ap)
ap.ctrl = ctrl
7 changes: 7 additions & 0 deletions pkg/core/stages/load_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import secrets

from .. import stage, app
from ..bootutils import config
from ...config import settings as settings_mgr
Expand Down Expand Up @@ -75,3 +77,8 @@ async def run(self, ap: app.Application):

ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
await ap.llm_models_meta.dump_config()

ap.instance_secret_meta = await config.load_json_config("data/metadata/instance-secret.json", template_data={
'jwt_secret': secrets.token_hex(16)
})
await ap.instance_secret_meta.dump_config()
Empty file.
5 changes: 5 additions & 0 deletions pkg/persistence/entities/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sqlalchemy.orm


class Base(sqlalchemy.orm.DeclarativeBase):
pass
11 changes: 11 additions & 0 deletions pkg/persistence/entities/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sqlalchemy

from .base import Base


class User(Base):
__tablename__ = 'users'

id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
user = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
password = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
8 changes: 5 additions & 3 deletions pkg/persistence/mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sqlalchemy

from . import database
from .entities import user, base
from ..core import app
from .databases import sqlite

Expand All @@ -23,7 +24,7 @@ class PersistenceManager:

def __init__(self, ap: app.Application):
self.ap = ap
self.meta = sqlalchemy.MetaData()
self.meta = base.Base.metadata

async def initialize(self):

Expand All @@ -46,10 +47,11 @@ async def execute_async(
self,
*args,
**kwargs
):
) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn:
await conn.execute(*args, **kwargs)
result = await conn.execute(*args, **kwargs)
await conn.commit()
return result

def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()
4 changes: 3 additions & 1 deletion pkg/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
semantic_version = "v3.3.1.1"

debug_mode = False
debug_mode = False

edition = 'community'
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ sqlalchemy[asyncio]
aiosqlite
quart-cors
aiofiles
aioshutil
aioshutil
argon2-cffi
pyjwt
5 changes: 5 additions & 0 deletions templates/schema/system.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@
},
"port": {
"type": "integer"
},
"jwt-expire": {
"type": "integer",
"title": "JWT 过期时间",
"description": "单位:秒"
}
}
},
Expand Down
3 changes: 2 additions & 1 deletion templates/system.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"http-api": {
"enable": true,
"host": "0.0.0.0",
"port": 5300
"port": 5300,
"jwt-expire": 604800
},
"persistence": {
"sqlite": {
Expand Down
Loading

0 comments on commit 1a457be

Please sign in to comment.