Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add owner check for team work #2892

Merged
merged 3 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions agent/component/exesql.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def check(self):
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql": raise ValueError("The host is not accessible.")
if self.password == "infini_rag_flow": raise ValueError("The host is not accessible.")


class ExeSQL(ComponentBase, ABC):
Expand Down
45 changes: 44 additions & 1 deletion api/apps/document_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,17 @@ def list_docs():


@manager.route('/infos', methods=['POST'])
@login_required
def docinfos():
req = request.json
doc_ids = req["doc_ids"]
for doc_id in doc_ids:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))

Expand Down Expand Up @@ -242,11 +250,17 @@ def thumbnails():
def change_status():
req = request.json
if str(req["status"]) not in ["0", "1"]:
get_json_result(
return get_json_result(
data=False,
retmsg='"Status" must be either 0 or 1!',
retcode=RetCode.ARGUMENT_ERROR)

if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR)

try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down Expand Up @@ -285,6 +299,15 @@ def rm():
req = request.json
doc_ids = req["doc_id"]
if isinstance(doc_ids, str): doc_ids = [doc_ids]

for doc_id in doc_ids:
if not DocumentService.accessible4deletion(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)

root_folder = FileService.get_root_folder(current_user.id)
pf_id = root_folder["id"]
FileService.init_knowledgebase_docs(pf_id, current_user.id)
Expand Down Expand Up @@ -323,6 +346,13 @@ def rm():
@validate_request("doc_ids", "run")
def run():
req = request.json
for doc_id in req["doc_ids"]:
if not DocumentService.accessible(doc_id, current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
for id in req["doc_ids"]:
info = {"run": str(req["run"]), "progress": 0}
Expand Down Expand Up @@ -356,6 +386,12 @@ def run():
@validate_request("doc_id", "name")
def rename():
req = request.json
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down Expand Up @@ -416,6 +452,13 @@ def get(doc_id):
@validate_request("doc_id", "parser_id")
def change_parser():
req = request.json

if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
Expand Down
23 changes: 16 additions & 7 deletions api/apps/kb_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user

Expand All @@ -23,14 +22,12 @@
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid, get_format_time
from api.db import StatusEnum, UserTenantRole, FileSource
from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import Knowledgebase, File
from api.settings import stat_logger, RetCode
from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from rag.nlp import search
from rag.utils.es_conn import ELASTICSEARCH


@manager.route('/create', methods=['post'])
Expand Down Expand Up @@ -65,6 +62,12 @@ def create():
def update():
req = request.json
req["name"] = req["name"].strip()
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]):
Expand Down Expand Up @@ -139,6 +142,12 @@ def list_kbs():
@validate_request("kb_id")
def rm():
req = request.json
if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id):
return get_json_result(
data=False,
retmsg='No authorization.',
retcode=RetCode.AUTHENTICATION_ERROR
)
try:
kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"])
Expand Down
29 changes: 28 additions & 1 deletion api/db/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from rag.nlp import search, rag_tokenizer

from api.db import FileType, TaskStatus, ParserType, LLMType
from api.db.db_models import DB, Knowledgebase, Tenant, Task
from api.db.db_models import DB, Knowledgebase, Tenant, Task, UserTenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
Expand Down Expand Up @@ -263,6 +263,33 @@ def get_tenant_id_by_name(cls, name):
return
return docs[0]["tenant_id"]

@classmethod
@DB.connection_context()
def accessible(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
).where(cls.model.id == doc_id, Knowledgebase.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def get_embd_id(cls, doc_id):
Expand Down
24 changes: 23 additions & 1 deletion api/db/services/knowledgebase_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
from api.db import StatusEnum, TenantPermission
from api.db.db_models import Knowledgebase, DB, Tenant, User
from api.db.db_models import Knowledgebase, DB, Tenant, User, UserTenant
from api.db.services.common_service import CommonService


Expand Down Expand Up @@ -182,3 +182,25 @@ def get_list(cls, joined_tenant_ids, user_id,
kbs = kbs.paginate(page_number, items_per_page)

return list(kbs.dicts())

@classmethod
@DB.connection_context()
def accessible(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

@classmethod
@DB.connection_context()
def accessible4deletion(cls, kb_id, user_id):
docs = cls.model.select(
cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
docs = docs.dicts()
if not docs:
return False
return True

6 changes: 3 additions & 3 deletions api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
import os
from datetime import date
from enum import IntEnum, Enum
from api.utils.file_utils import get_project_base_directory
from api.utils.log_utils import LoggerFactory, getLogger
Expand Down Expand Up @@ -143,9 +144,8 @@

SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get(
"secret_key",
"infiniflow")
{}).get("secret_key", str(date.today()))

TOKEN_EXPIRE_IN = get_base_config(
RAG_FLOW_SERVICE_NAME, {}).get(
"token_expires_in", 3600)
Expand Down