From a19d69da9f55e3acc9d1e46bfe4a57fbe2e0767a Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Fri, 6 Sep 2024 16:16:45 +0800
Subject: [PATCH] Passed knowledge base retrieval testing

Renamed deleteByQuery to delete
Renamed bulk to upsertBulk
getHighlight
Replaced ELASTICSEARCH with dataStoreConn
Fixed KGSearch.search
Moved Dealer.sql_retrieval to es_conn.py
getAggregation
---
 README.md | 2 +-
 api/apps/chunk_app.py | 25 +-
 api/apps/document_app.py | 15 +-
 api/apps/file2document_app.py | 4 -
 api/apps/file_app.py | 3 -
 api/apps/kb_app.py | 6 +-
 api/apps/sdk/doc.py | 58 +--
 api/apps/system_app.py | 5 +-
 api/db/services/document_service.py | 17 +-
 api/settings.py | 8 +-
 docker/.env | 2 +-
 docker/docker-compose.yml | 1 -
 docker/entrypoint.sh | 2 +-
 docker/service_conf.yaml | 6 +-
 graphrag/search.py | 111 +++--
 poetry.lock | 51 ++-
 pyproject.toml | 1 +
 rag/benchmark.py | 16 +-
 rag/nlp/query.py | 109 +++--
 rag/nlp/search.py | 253 +++--------
 rag/settings.py | 7 +-
 rag/svr/task_executor.py | 16 +-
 rag/utils/data_store_conn.py | 244 +++++++++++
 rag/utils/es_conn.py | 632 ++++++++++++----------------
 rag/utils/infinity_conn.py | 255 +++++++++++
 25 files changed, 1080 insertions(+), 769 deletions(-)
 create mode 100644 rag/utils/data_store_conn.py
 create mode 100644 rag/utils/infinity_conn.py
diff --git a/README.md b/README.md index 027420a4d91..0acbf21da5a 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ docker build -f Dockerfile -t infiniflow/ragflow:dev . git clone https://github.com/infiniflow/ragflow.git cd ragflow/ export POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true - ~/.local/bin/poetry install --sync --no-root # install RAGFlow dependent python modules + ~/.local/bin/poetry install --sync --no-root --with=full # install RAGFlow dependent python modules ``` 3.
Launch the dependent services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose: diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index aca9a2ddf71..371d25142df 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -19,11 +19,9 @@ from flask import request from flask_login import login_required, current_user -from elasticsearch_dsl import Q from rag.app.qa import rmPrefix, beAdoc from rag.nlp import search, rag_tokenizer, keyword_extraction -from rag.utils.es_conn import ELASTICSEARCH from rag.utils import rmSpace from api.db import LLMType, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService @@ -31,7 +29,7 @@ from api.db.services.user_service import UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.db.services.document_service import DocumentService -from api.settings import RetCode, retrievaler, kg_retrievaler +from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn from api.utils.api_utils import get_json_result import hashlib import re @@ -83,7 +81,7 @@ def list_chunk(): return get_json_result(data=res) except Exception as e: if str(e).find("not_found") > 0: - return get_json_result(data=False, retmsg=f'No chunk found!', + return get_json_result(data=False, retmsg='No chunk found!', retcode=RetCode.DATA_ERROR) return server_error_response(e) @@ -96,7 +94,7 @@ def get(): tenants = UserTenantService.query(user_id=current_user.id) if not tenants: return get_data_error_result(retmsg="Tenant not found!") - res = ELASTICSEARCH.get( + res = docStoreConn.get( chunk_id, search.index_name( tenants[0].tenant_id)) if not res.get("found"): @@ -114,7 +112,7 @@ def get(): return get_json_result(data=res) except Exception as e: if str(e).find("NotFoundError") >= 0: - return get_json_result(data=False, retmsg=f'Chunk not found!', + return get_json_result(data=False, retmsg='Chunk not found!', retcode=RetCode.DATA_ERROR) return server_error_response(e) @@ -162,7 +160,7 @@ def set(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) + docStoreConn.upsert([d], search.index_name(tenant_id)) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -177,7 +175,7 @@ def switch(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], + if not docStoreConn.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], search.index_name(tenant_id)): return get_data_error_result(retmsg="Index updating failure") return get_json_result(data=True) @@ -191,8 +189,7 @@ def switch(): def rm(): req = request.json try: - if not ELASTICSEARCH.deleteByQuery( - Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)): + if not docStoreConn.delete({"_id": req["chunk_ids"]}, search.index_name(current_user.id)): return get_data_error_result(retmsg="Index updating failure") e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: @@ -239,7 +236,7 @@ def create(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - ELASTICSEARCH.upsert([d], 
search.index_name(tenant_id)) + docStoreConn.upsert([d], search.index_name(tenant_id)) DocumentService.increment_chunk_num( doc.id, doc.kb_id, c, 1, 0) @@ -272,7 +269,7 @@ def retrieval_test(): break else: return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + data=False, retmsg='Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id[0]) @@ -300,7 +297,7 @@ def retrieval_test(): return get_json_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: - return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', + return get_json_result(data=False, retmsg='No chunk found! Check the chunk status please!', retcode=RetCode.DATA_ERROR) return server_error_response(e) @@ -320,7 +317,7 @@ def knowledge_graph(): ty = sres.field[id]["knowledge_graph_kwd"] try: obj[ty] = json.loads(sres.field[id]["content_with_weight"]) - except Exception as e: + except Exception: print(traceback.format_exc(), flush=True) return get_json_result(data=obj) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 0e4af144786..8cdfcbf4905 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -27,14 +27,13 @@ from api.db.services.task_service import TaskService, queue_tasks from api.db.services.user_service import UserTenantService from rag.nlp import search -from rag.utils.es_conn import ELASTICSEARCH from api.db.services import duplicate_name from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.db import FileType, TaskStatus, ParserType, FileSource from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.settings import RetCode +from api.settings import RetCode, docStoreConn from api.utils.api_utils import get_json_result from rag.utils.storage_factory import STORAGE_IMPL from api.utils.file_utils import filename_type, thumbnail @@ -187,7 +186,7 @@ def list_docs(): break else: return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + data=False, retmsg='Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) keywords = request.args.get("keywords", "") @@ -276,13 +275,13 @@ def change_status(): retmsg="Database error (Document update)!") if str(req["status"]) == "0": - ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), + docStoreConn.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), scripts="ctx._source.available_int=0;", idxnm=search.index_name( kb.tenant_id) ) else: - ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), + docStoreConn.updateScriptByQuery(Q("term", doc_id=req["doc_id"]), scripts="ctx._source.available_int=1;", idxnm=search.index_name( kb.tenant_id) @@ -365,8 +364,7 @@ def run(): tenant_id = DocumentService.get_tenant_id(id) if not tenant_id: return get_data_error_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id)) if str(req["run"]) == TaskStatus.RUNNING.value: TaskService.filter_delete([Task.doc_id == id]) @@ -490,8 +488,7 @@ def change_parser(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return 
get_data_error_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id)) return get_json_result(data=True) except Exception as e: diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index 1e4b2c9ad3f..173e7209fe3 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License # -from elasticsearch_dsl import Q -from api.db.db_models import File2Document from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService @@ -28,8 +26,6 @@ from api.db.services.document_service import DocumentService from api.settings import RetCode from api.utils.api_utils import get_json_result -from rag.nlp import search -from rag.utils.es_conn import ELASTICSEARCH @manager.route('/convert', methods=['POST']) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 4f4b44af98f..c34cff51cf7 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -18,7 +18,6 @@ import re import flask -from elasticsearch_dsl import Q from flask import request from flask_login import login_required, current_user @@ -32,8 +31,6 @@ from api.settings import RetCode from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type -from rag.nlp import search -from rag.utils.es_conn import ELASTICSEARCH from rag.utils.storage_factory import STORAGE_IMPL diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 551e7867ed6..ffa6cbb1db4 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -72,7 +72,7 @@ def update(): if not KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]): return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + data=False, retmsg='Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) if not e: @@ -110,7 +110,7 @@ def detail(): break else: return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', + data=False, retmsg='Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) kb = KnowledgebaseService.get_detail(kb_id) if not kb: @@ -153,7 +153,7 @@ def rm(): created_by=current_user.id, id=req["kb_id"]) if not kbs: return get_json_result( - data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) + data=False, retmsg='Only owner of knowledgebase authorized for this operation.', retcode=RetCode.OPERATING_ERROR) for doc in DocumentService.query(kb_id=req["kb_id"]): if not DocumentService.remove_document(doc, kbs[0].tenant_id): diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 840994b38aa..f45e152056e 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1,60 +1,36 @@ import pathlib import re import datetime -import json -import traceback -from botocore.docs.method import document_model_driven_method from flask import request -from flask_login import login_required, current_user -from elasticsearch_dsl import Q -from pygments import highlight -from sphinx.addnodes import document from rag.app.qa import rmPrefix, beAdoc from rag.nlp import search, 
rag_tokenizer, keyword_extraction -from rag.utils.es_conn import ELASTICSEARCH from rag.utils import rmSpace from api.db import LLMType, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import TenantLLMService -from api.db.services.user_service import UserTenantService -from api.utils.api_utils import server_error_response, get_error_data_result, validate_request +from api.utils.api_utils import server_error_response, get_error_data_result from api.db.services.document_service import DocumentService -from api.settings import RetCode, retrievaler, kg_retrievaler +from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn from api.utils.api_utils import get_result import hashlib -import re -from api.utils.api_utils import get_result, token_required, get_error_data_result +from api.utils.api_utils import token_required from api.db.db_models import Task, File from api.db.services.task_service import TaskService, queue_tasks -from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import server_error_response, get_error_data_result, validate_request -from api.utils.api_utils import get_result, get_result, get_error_data_result -from functools import partial from io import BytesIO -from elasticsearch_dsl import Q -from flask import request, send_file -from flask_login import login_required +from flask import send_file from api.db import FileSource, TaskStatus, FileType -from api.db.db_models import File -from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.settings import RetCode, retrievaler -from api.utils.api_utils import construct_json_result, construct_error_response -from rag.app import book, laws, manual, naive, one, paper, presentation, qa, resume, table, picture, audio, email -from rag.nlp import search -from rag.utils import rmSpace -from rag.utils.es_conn import ELASTICSEARCH +from api.utils.api_utils import construct_json_result from rag.utils.storage_factory import STORAGE_IMPL MAXIMUM_OF_UPLOADING_FILES = 256 @@ -142,8 +118,7 @@ def update_doc(tenant_id, dataset_id, document_id): tenant_id = DocumentService.get_tenant_id(req["id"]) if not tenant_id: return get_error_data_result(retmsg="Tenant not found!") - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id)) return get_result() @@ -265,8 +240,7 @@ def parse(tenant_id,dataset_id): info["token_num"] = 0 DocumentService.update_by_id(id, info) # if str(req["run"]) == TaskStatus.CANCEL.value: - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id)) TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() @@ -293,8 +267,7 @@ def stop_parsing(tenant_id,dataset_id): DocumentService.update_by_id(id, info) # if str(req["run"]) == TaskStatus.CANCEL.value: tenant_id = DocumentService.get_tenant_id(id) - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id)) return get_result() @@ -402,7 +375,7 @@ def create(tenant_id,dataset_id,document_id): v, c 
= embd_mdl.encode([doc.name, req["content"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) + docStoreConn.upsert([d], search.index_name(tenant_id)) DocumentService.increment_chunk_num( doc.id, doc.kb_id, c, 1, 0) @@ -445,8 +418,7 @@ def rm_chunk(tenant_id,dataset_id,document_id): for chunk_id in req.get("chunk_ids"): if chunk_id not in sres.ids: return get_error_data_result(f"Chunk {chunk_id} not found") - if not ELASTICSEARCH.deleteByQuery( - Q("ids", values=req["chunk_ids"]), search.index_name(tenant_id)): + if not docStoreConn.delete({"_id": req["chunk_ids"]}, search.index_name(tenant_id)): return get_error_data_result(retmsg="Index updating failure") deleted_chunk_ids = req["chunk_ids"] chunk_number = len(deleted_chunk_ids) @@ -459,10 +431,8 @@ def rm_chunk(tenant_id,dataset_id,document_id): @token_required def set(tenant_id,dataset_id,document_id,chunk_id): try: - res = ELASTICSEARCH.get( - chunk_id, search.index_name( - tenant_id)) - except Exception as e: + res = docStoreConn.get(chunk_id, search.index_name(tenant_id)) + except Exception: return get_error_data_result(f"Can't find this chunk {chunk_id}") if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(retmsg=f"You don't own the dataset {dataset_id}.") @@ -508,7 +478,7 @@ def set(tenant_id,dataset_id,document_id,chunk_id): v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) + docStoreConn.upsert([d], search.index_name(tenant_id)) return get_result() @@ -580,6 +550,6 @@ def retrieval_test(tenant_id): return get_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: - return get_result(retmsg=f'No chunk found! Check the chunk statu s please!', + return get_result(retmsg='No chunk found! Check the chunk status please!', retcode=RetCode.DATA_ERROR) return server_error_response(e) \ No newline at end of file diff --git a/api/apps/system_app.py b/api/apps/system_app.py index 28df3d688d6..daacc1cc761 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -22,12 +22,11 @@ from api.db.services.api_service import APITokenService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import UserTenantService -from api.settings import DATABASE_TYPE +from api.settings import DATABASE_TYPE, docStoreConn from api.utils import current_timestamp, datetime_format from api.utils.api_utils import get_json_result, get_data_error_result, server_error_response, \ generate_confirmation_token, request, validate_request from api.versions import get_rag_version -from rag.utils.es_conn import ELASTICSEARCH from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from timeit import default_timer as timer @@ -46,7 +45,7 @@ def status(): res = {} st = timer() try: - res["es"] = ELASTICSEARCH.health() + res["es"] = docStoreConn.health() res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.) 
except Exception as e: res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)} diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index ed6220aed24..2c87b9aa421 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -24,16 +24,14 @@ from datetime import datetime from io import BytesIO -from elasticsearch_dsl import Q from peewee import fn from api.db.db_utils import bulk_insert_into_db -from api.settings import stat_logger +from api.settings import stat_logger, docStoreConn from api.utils import current_timestamp, get_format_time, get_uuid from api.utils.file_utils import get_project_base_directory from graphrag.mind_map_extractor import MindMapExtractor from rag.settings import SVR_QUEUE_NAME -from rag.utils.es_conn import ELASTICSEARCH from rag.utils.storage_factory import STORAGE_IMPL from rag.nlp import search, rag_tokenizer @@ -138,8 +136,7 @@ def insert(cls, doc): @classmethod @DB.connection_context() def remove_document(cls, doc, tenant_id): - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=doc.id), idxnm=search.index_name(tenant_id)) + docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id)) cls.clear_chunk_num(doc.id) return cls.delete_by_id(doc.id) @@ -317,7 +314,7 @@ def get_doc_id_by_doc_name(cls, doc_name): @classmethod @DB.connection_context() def get_thumbnails(cls, docids): - fields = [cls.model.id, cls.model.thumbnail] + fields = [cls.model.id, cls.model.thumbnail, cls.model.kb_id] return list(cls.model.select( *fields).where(cls.model.id.in_(docids)).dicts()) @@ -421,7 +418,7 @@ def do_cancel(cls, doc_id): try: _, doc = DocumentService.get_by_id(doc_id) return doc.run == TaskStatus.CANCEL.value or doc.progress < 0 - except Exception as e: + except Exception: pass return False @@ -463,8 +460,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): raise LookupError("Can't find this knowledgebase!") idxnm = search.index_name(kb.tenant_id) - if not ELASTICSEARCH.indexExist(idxnm): - ELASTICSEARCH.createIdx(idxnm, json.load( + if not docStoreConn.indexExist(idxnm): + docStoreConn.createIdx(idxnm, json.load( open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) @@ -574,7 +571,7 @@ def embedding(doc_id, cnts, batch_size=16): v = vects[i] d["q_%d_vec" % len(v)] = v for b in range(0, len(cks), es_bulk_size): - ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) + docStoreConn.upsertBulk(cks[b:b + es_bulk_size], idxnm) DocumentService.increment_chunk_num( doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) diff --git a/api/settings.py b/api/settings.py index f48a5fe7a59..9a089cb68ac 100644 --- a/api/settings.py +++ b/api/settings.py @@ -33,7 +33,8 @@ database_logger = getLogger("database") chat_logger = getLogger("chat") -from rag.utils.es_conn import ELASTICSEARCH +from rag.utils.es_conn import ESConnection +#from rag.utils.infinity_conn import InfinityConnection from rag.nlp import search from graphrag import search as kg_search from api.utils import get_base_config, decrypt_database_config @@ -205,8 +206,9 @@ PRIVILEGE_COMMAND_WHITELIST = [] CHECK_NODES_IDENTITY = False -retrievaler = search.Dealer(ELASTICSEARCH) -kg_retrievaler = kg_search.KGSearch(ELASTICSEARCH) +docStoreConn = ESConnection() +retrievaler = search.Dealer(docStoreConn) +kg_retrievaler = kg_search.KGSearch(docStoreConn) class CustomEnum(Enum): 
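The api/settings.py hunk above wires both retrievers to a DocStoreConnection instead of the global ELASTICSEARCH object. Below is a minimal sketch of how the backend could be selected at startup, assuming a hypothetical DOC_ENGINE setting; the patch itself hard-codes ESConnection and only leaves the Infinity import commented out, and InfinityConnection() taking no arguments is likewise an assumption.

```python
# Minimal sketch (not part of this patch): selecting the document-store
# backend behind the DocStoreConnection interface. DOC_ENGINE is a made-up
# setting and InfinityConnection() taking no arguments is an assumption;
# the patch itself always instantiates ESConnection.
import os

from graphrag import search as kg_search
from rag.nlp import search
from rag.utils.es_conn import ESConnection
from rag.utils.infinity_conn import InfinityConnection

if os.environ.get("DOC_ENGINE", "elasticsearch") == "infinity":
    docStoreConn = InfinityConnection()
else:
    docStoreConn = ESConnection()

# Dealer and KGSearch only see the abstract interface, so retrieval code
# does not change when the backend does.
retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
```

Keeping Dealer and KGSearch behind the abstract interface is what lets the new rag/utils/infinity_conn.py drop in without touching the retrieval code.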
diff --git a/docker/.env b/docker/.env index 8cb40201508..31bed31e3e7 100644 --- a/docker/.env +++ b/docker/.env @@ -29,7 +29,7 @@ MINIO_PASSWORD=infini_rag_flow REDIS_PORT=6379 REDIS_PASSWORD=infini_rag_flow -SVR_HTTP_PORT=9380 +SVR_HTTP_PORT=9456 # the Docker image for the slim version RAGFLOW_IMAGE=infiniflow/ragflow:dev-slim diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 2b8de187a8e..1e6aa1c3db1 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -15,7 +15,6 @@ services: - ${SVR_HTTP_PORT}:9380 - 80:80 - 443:443 - - 5678:5678 volumes: - ./service_conf.yaml:/ragflow/conf/service_conf.yaml - ./ragflow-logs:/ragflow/logs diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 1c2c3bc35a2..e1a1e31a6ff 100755 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -3,7 +3,7 @@ # unset http proxy which maybe set by docker daemon export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" -/usr/sbin/nginx +# /usr/sbin/nginx export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/ diff --git a/docker/service_conf.yaml b/docker/service_conf.yaml index a92cfb212e6..d94ddc983de 100644 --- a/docker/service_conf.yaml +++ b/docker/service_conf.yaml @@ -1,12 +1,12 @@ ragflow: host: 0.0.0.0 - http_port: 9380 + http_port: 9456 mysql: name: 'rag_flow' user: 'root' password: 'infini_rag_flow' host: 'mysql' - port: 3306 + port: 5455 max_connections: 100 stale_timeout: 30 minio: @@ -14,7 +14,7 @@ minio: password: 'infini_rag_flow' host: 'minio:9000' es: - hosts: 'http://es01:9200' + hosts: 'http://es01:1200' username: 'elastic' password: 'infini_rag_flow' redis: diff --git a/graphrag/search.py b/graphrag/search.py index 85ba0698a3b..e8a761ad1a2 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -15,32 +15,39 @@ # import json from copy import deepcopy +from typing import Dict import pandas as pd -from elasticsearch_dsl import Q, Search +from rag.utils.data_store_conn import OrderByExpr, FusionExpr from rag.nlp.search import Dealer class KGSearch(Dealer): def search(self, req, idxnm, emb_mdl=None, highlight=False): - def merge_into_first(sres, title=""): - df,texts = [],[] - for d in sres["hits"]["hits"]: + def merge_into_first(sres, title="") -> Dict[str, str]: + if not sres: + return {} + content_with_weight = "" + df, texts = [],[] + for d in sres.values(): try: - df.append(json.loads(d["_source"]["content_with_weight"])) - except Exception as e: - texts.append(d["_source"]["content_with_weight"]) - pass - if not df and not texts: return False + df.append(json.loads(d["content_with_weight"])) + except Exception: + texts.append(d["content_with_weight"]) if df: - try: - sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv() - except Exception as e: - pass + content_with_weight = title + "\n" + pd.DataFrame(df).to_csv() else: - sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts) - return True + content_with_weight = title + "\n" + "\n".join(texts) + first_id = "" + first_source = {} + for k, v in sres.items(): + first_id = id + first_source = deepcopy(v) + break + first_source["content_with_weight"] = content_with_weight + first_id = next(iter(sres)) + return {first_id: first_source} src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd", @@ -49,61 +56,49 @@ def 
merge_into_first(sres, title=""): ]) qst = req.get("question", "") - binary_query, keywords = self.qryr.question(qst, min_match="5%") - binary_query = self._add_filters(binary_query, req) + matchText, keywords = self.qryr.question(qst, min_match=0.05) + condition = self.get_filters(req) ## Entity retrieval - bqry = deepcopy(binary_query) - bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"])) - s = Search() - s = s.query(bqry)[0: 32] - - s = s.to_dict() - q_vec = [] - if req.get("vector"): - assert emb_mdl, "No embedding model selected" - s["knn"] = self._vector( - qst, emb_mdl, req.get( - "similarity", 0.1), 1024) - s["knn"]["filter"] = bqry.to_dict() - q_vec = s["knn"]["query_vector"] + condition.update({"knowledge_graph_kwd": ["entity"]}) + assert emb_mdl, "No embedding model selected" + matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1)) + q_vec = matchDense.embedding_data + fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"}) - ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) - entities = [d["name_kwd"] for d in self.es.getSource(ent_res)] - ent_ids = self.es.getDocIds(ent_res) - if merge_into_first(ent_res, "-Entities-"): - ent_ids = ent_ids[0:1] + ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm) + ent_res_fields = self.dataStore.getFields(ent_res, src) + entities = [d["name_kwd"] for d in ent_res_fields.values()] + ent_ids = self.dataStore.getChunkIds(ent_res) + ent_content = merge_into_first(ent_res_fields, "-Entities-") + if ent_content: + ent_ids = list(ent_content.keys()) ## Community retrieval - bqry = deepcopy(binary_query) - bqry.filter.append(Q("terms", entities_kwd=entities)) - bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"])) - s = Search() - s = s.query(bqry)[0: 32] - s = s.to_dict() - comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) - comm_ids = self.es.getDocIds(comm_res) - if merge_into_first(comm_res, "-Community Report-"): - comm_ids = comm_ids[0:1] + condition = self.get_filters(req) + condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]}) + comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm) + comm_res_fields = self.dataStore.getFields(comm_res, src) + comm_ids = self.dataStore.getChunkIds(comm_res) + comm_content = merge_into_first(comm_res_fields, "-Community Report-") + if comm_content: + comm_ids = list(comm_content.keys()) ## Text content retrieval - bqry = deepcopy(binary_query) - bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"])) - s = Search() - s = s.query(bqry)[0: 6] - s = s.to_dict() - txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) - txt_ids = self.es.getDocIds(txt_res) - if merge_into_first(txt_res, "-Original Content-"): - txt_ids = txt_ids[0:1] + condition = self.get_filters(req) + condition.update({"knowledge_graph_kwd": ["text"]}) + txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm) + txt_res_fields = self.dataStore.getFields(txt_res, src) + txt_ids = self.dataStore.getChunkIds(txt_res) + txt_content = merge_into_first(txt_res_fields, "-Original Content-") + if txt_content: + txt_ids = list(txt_content.keys()) return self.SearchResult( total=len(ent_ids) + len(comm_ids) + len(txt_ids), ids=[*ent_ids, *comm_ids, *txt_ids], query_vector=q_vec, 
- aggregation=None, highlight=None, - field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)}, + field={**ent_content, **comm_content, **txt_content}, keywords=[] ) - diff --git a/poetry.lock b/poetry.lock index 74b2d4b40de..cd6f89e6489 100644 --- a/poetry.lock +++ b/poetry.lock @@ -550,7 +550,7 @@ name = "bce-python-sdk" version = "0.9.23" description = "BCE SDK for python" optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7" +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"}, {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"}, @@ -1854,7 +1854,7 @@ name = "fastembed" version = "0.3.6" description = "Fast, light, accurate library built for retrieval embedding generation" optional = false -python-versions = "<3.13,>=3.8.0" +python-versions = ">=3.8.0,<3.13" files = [ {file = "fastembed-0.3.6-py3-none-any.whl", hash = "sha256:2bf70edae28bb4ccd9e01617098c2075b0ba35b88025a3d22b0e1e85b2c488ce"}, {file = "fastembed-0.3.6.tar.gz", hash = "sha256:c93c8ec99b8c008c2d192d6297866b8d70ec7ac8f5696b34eb5ea91f85efd15f"}, @@ -2726,7 +2726,7 @@ name = "graspologic" version = "3.4.1" description = "A set of Python modules for graph statistics" optional = false -python-versions = "<3.13,>=3.9" +python-versions = ">=3.9,<3.13" files = [ {file = "graspologic-3.4.1-py3-none-any.whl", hash = "sha256:c6563e087eda599bad1de831d4b7321c0daa7a82f4e85a7d7737ff67e07cdda2"}, {file = "graspologic-3.4.1.tar.gz", hash = "sha256:7561f0b852a2bccd351bff77e8db07d9892f9dfa35a420fdec01690e4fdc8075"}, @@ -3326,7 +3326,7 @@ web-service = ["fastapi (>=0.109.0,<0.110.0)", "uvicorn (>=0.25.0,<0.26.0)"] [[package]] name = "intel-openmp" version = "2021.4.0" -description = "Intel® OpenMP* Runtime Library" +description = "Intel OpenMP* Runtime Library" optional = false python-versions = "*" files = [ @@ -5343,6 +5343,47 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "1.9.0" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.9" +files = [ + {file = "polars-1.9.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:a471d2ce96f6fa5dd0ef16bcdb227f3dbe3af8acb776ca52f9e64ef40c7489a0"}, + {file = "polars-1.9.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94b12d731cd200d2c50b13fc070d6353f708e632bca6529c5a72aa6a69e5285d"}, + {file = "polars-1.9.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f85f132732aa63c6f3b502b0fdfc3ba9f0b78cc6330059b5a2d6f9fd78508acb"}, + {file = "polars-1.9.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:f753c8941a3b3249d59262d68a856714a96a7d4e16977aefbb196be0c192e151"}, + {file = "polars-1.9.0-cp38-abi3-win_amd64.whl", hash = "sha256:95de07066cd797dd940fa2783708a7bef93c827a57be0f4dfad3575a6144212b"}, + {file = "polars-1.9.0.tar.gz", hash = "sha256:8e1206ef876f61c1d50a81e102611ea92ee34631cb135b46ad314bfefd3cb122"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", 
"polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.15.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +gpu = ["cudf-polars-cu12"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["altair (>=5.4.0)"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] +timezone = ["backports-zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "pooch" version = "1.8.2" @@ -9008,4 +9049,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.12,<3.13" -content-hash = "9bb9779964942b51070e5499058a7e00f64a370f901db366ae94abc8f0ade600" +content-hash = "fbd7b37a15bca5beb077ee24e85eea1a9e93e6ead2cc730a33597930d8094842" diff --git a/pyproject.toml b/pyproject.toml index 22176016345..c29208d18c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ graspologic = "^3.4.1" pymysql = "^1.1.1" mini-racer = "^0.12.4" pyicu = "^2.13.1" +polars = "^1.9.0" [tool.poetry.group.full] diff --git a/rag/benchmark.py b/rag/benchmark.py index 490c031f97c..35107c9a3ee 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -13,16 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json import os from collections import defaultdict from api.db import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.knowledgebase_service import KnowledgebaseService -from api.settings import retrievaler +from api.settings import retrievaler, docStoreConn from api.utils import get_uuid from rag.nlp import tokenize, search -from rag.utils.es_conn import ELASTICSEARCH from ranx import evaluate import pandas as pd from tqdm import tqdm @@ -83,11 +81,11 @@ def ms_marco_index(self, file_path, index_name): qrels[query][d["id"]] = int(rel) if len(docs) >= 32: docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) docs = [] docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) return qrels, texts def trivia_qa_index(self, file_path, index_name): @@ -110,11 +108,11 @@ def trivia_qa_index(self, file_path, index_name): qrels[query][d["id"]] = int(rel) if len(docs) >= 32: docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) docs = [] docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) return qrels, texts def miracl_index(self, file_path, corpus_path, index_name): @@ -155,11 +153,11 @@ def miracl_index(self, file_path, corpus_path, index_name): qrels[query][d["id"]] = int(rel) if len(docs) >= 32: docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) docs = [] docs = self.embedding(docs) - ELASTICSEARCH.bulk(docs, search.index_name(index_name)) + docStoreConn.upsertBulk(docs, search.index_name(index_name)) return qrels, texts diff --git a/rag/nlp/query.py b/rag/nlp/query.py index c58c99c4cfc..77bae5a2077 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py 
@@ -15,20 +15,25 @@ # import json -import math import re import logging -import copy -from elasticsearch_dsl import Q +from rag.utils.data_store_conn import MatchTextExpr from rag.nlp import rag_tokenizer, term_weight, synonym -class EsQueryer: - def __init__(self, es): + +class FulltextQueryer: + def __init__(self): self.tw = term_weight.Dealer() - self.es = es self.syn = synonym.Dealer() - self.flds = ["ask_tks^10", "ask_small_tks"] + self.query_fields = [ + "title_tks^10", + "title_sm_tks^5", + "important_kwd^30", + "important_tks^20", + "content_ltks^2", + "content_sm_ltks", + ] @staticmethod def subSpecialChar(line): @@ -43,12 +48,15 @@ def isChinese(line): for t in arr: if not re.match(r"[a-zA-Z]+$", t): e += 1 - return e * 1. / len(arr) >= 0.7 + return e * 1.0 / len(arr) >= 0.7 @staticmethod def rmWWW(txt): patts = [ - (r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""), + ( + r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", + "", + ), (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "), (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of) ", " ") ] @@ -56,14 +64,13 @@ def rmWWW(txt): txt = re.sub(r, p, txt, flags=re.IGNORECASE) return txt - def question(self, txt, tbl="qa", min_match="60%"): + def question(self, txt, tbl="qa", min_match:float=0.6): txt = re.sub( r"[ :\r\n\t,,。??/`!!&\^%%]+", " ", - rag_tokenizer.tradi2simp( - rag_tokenizer.strQ2B( - txt.lower()))).strip() - txt = EsQueryer.rmWWW(txt) + rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), + ).strip() + txt = FulltextQueryer.rmWWW(txt) if not self.isChinese(txt): tks = rag_tokenizer.tokenize(txt).split(" ") @@ -73,14 +80,20 @@ def question(self, txt, tbl="qa", min_match="60%"): tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk] for i in range(1, len(tks_w)): - q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2)) + q.append( + '"%s %s"^%.4f' + % ( + tks_w[i - 1][0], + tks_w[i][0], + max(tks_w[i - 1][1], tks_w[i][1]) * 2, + ) + ) if not q: q.append(txt) - return Q("bool", - must=Q("query_string", fields=self.flds, - type="best_fields", query=" ".join(q), - boost=1)#, minimum_should_match=min_match) - ), list(set([t for t in txt.split(" ") if t])) + query = " ".join(q) + return MatchTextExpr( + self.query_fields, query, 100, {"minimum_should_match": min_match} + ), tks def need_fine_grained_tokenize(tk): if len(tk) < 3: @@ -100,65 +113,71 @@ def need_fine_grained_tokenize(tk): logging.info(json.dumps(twts, ensure_ascii=False)) tms = [] for tk, w in sorted(twts, key=lambda x: x[1] * -1): - sm = rag_tokenizer.fine_grained_tokenize(tk).split(" ") if need_fine_grained_tokenize(tk) else [] + sm = ( + rag_tokenizer.fine_grained_tokenize(tk).split(" ") + if need_fine_grained_tokenize(tk) + else [] + ) sm = [ re.sub( r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+", "", - m) for m in sm] - sm = [EsQueryer.subSpecialChar(m) for m in sm if len(m) > 1] + m, + ) + for m in sm + ] + sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1] sm = [m for m in sm if len(m) > 1] keywords.append(re.sub(r"[ \\\"']+", "", tk)) keywords.extend(sm) - if len(keywords) >= 12: 
break + if len(keywords) >= 12: + break tk_syns = self.syn.lookup(tk) - tk = EsQueryer.subSpecialChar(tk) + tk = FulltextQueryer.subSpecialChar(tk) if tk.find(" ") > 0: - tk = "\"%s\"" % tk + tk = '"%s"' % tk if tk_syns: tk = f"({tk} %s)" % " ".join(tk_syns) if sm: - tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % ( - " ".join(sm), " ".join(sm)) + tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm)) if tk.strip(): tms.append((tk, w)) tms = " ".join([f"({t})^{w}" for t, w in tms]) if len(twts) > 1: - tms += f" (\"%s\"~4)^1.5" % (" ".join([t for t, _ in twts])) + tms += ' ("%s"~4)^1.5' % (" ".join([t for t, _ in twts])) if re.match(r"[0-9a-z ]+$", tt): - tms = f"(\"{tt}\" OR \"%s\")" % rag_tokenizer.tokenize(tt) + tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt) syns = " OR ".join( - ["\"%s\"^0.7" % EsQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) for s in syns]) + [ + '"%s"^0.7' + % FulltextQueryer.subSpecialChar(rag_tokenizer.tokenize(s)) + for s in syns + ] + ) if syns: tms = f"({tms})^5 OR ({syns})^0.7" qs.append(tms) - flds = copy.deepcopy(self.flds) - mst = [] if qs: - mst.append( - Q("query_string", fields=flds, type="best_fields", - query=" OR ".join([f"({t})" for t in qs if t]), boost=1, minimum_should_match=min_match) - ) - - return Q("bool", - must=mst, - ), list(set(keywords)) + query = " OR ".join([f"({t})" for t in qs if t]) + return MatchTextExpr( + self.query_fields, query, 100, {"minimum_should_match": min_match} + ), keywords + return None, keywords - def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, - vtweight=0.7): + def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7): from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity import numpy as np + sims = CosineSimilarity([avec], bvecs) tksim = self.token_similarity(atks, btkss) - return np.array(sims[0]) * vtweight + \ - np.array(tksim) * tkweight, tksim, sims[0] + return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0] def token_similarity(self, atks, btkss): def toDict(tks): diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 89a2592ece0..6f4ab8f0a66 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -14,34 +14,25 @@ # limitations under the License. 
# -import json import re -from copy import deepcopy -from elasticsearch_dsl import Q, Search from typing import List, Optional, Dict, Union from dataclasses import dataclass -from rag.settings import es_logger +from rag.settings import doc_store_logger from rag.utils import rmSpace -from rag.nlp import rag_tokenizer, query, is_english +from rag.nlp import rag_tokenizer, query import numpy as np +from rag.utils.data_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr def index_name(uid): return f"ragflow_{uid}" class Dealer: - def __init__(self, es): - self.qryr = query.EsQueryer(es) - self.qryr.flds = [ - "title_tks^10", - "title_sm_tks^5", - "important_kwd^30", - "important_tks^20", - "content_ltks^2", - "content_sm_ltks"] - self.es = es + def __init__(self, dataStore: DocStoreConnection): + self.qryr = query.FulltextQueryer() + self.dataStore = dataStore @dataclass class SearchResult: @@ -54,98 +45,63 @@ class SearchResult: keywords: Optional[List[str]] = None group_docs: List[List] = None - def _vector(self, txt, emb_mdl, sim=0.8, topk=10): - qv, c = emb_mdl.encode_queries(txt) - return { - "field": "q_%d_vec" % len(qv), - "k": topk, - "similarity": sim, - "num_candidates": topk * 2, - "query_vector": [float(v) for v in qv] - } - - def _add_filters(self, bqry, req): - if req.get("kb_ids"): - bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) - if req.get("doc_ids"): - bqry.filter.append(Q("terms", doc_id=req["doc_ids"])) - if req.get("knowledge_graph_kwd"): - bqry.filter.append(Q("terms", knowledge_graph_kwd=req["knowledge_graph_kwd"])) - if "available_int" in req: - if req["available_int"] == 0: - bqry.filter.append(Q("range", available_int={"lt": 1})) - else: - bqry.filter.append( - Q("bool", must_not=Q("range", available_int={"lt": 1}))) - return bqry - - def search(self, req, idxnm, emb_mdl=None, highlight=False): + def get_vector(self, txt, emb_mdl, topk=10, similarity=0.1): + qv, _ = emb_mdl.encode_queries(txt) + embedding_data = [float(v) for v in qv] + vector_column_name = f"q_{len(embedding_data)}_vec" + return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity}) + + def get_filters(self, req): + condition = dict() + for key, field in {"kb_ids": "kb_id", "doc_ids": "doc_id"}.items(): + if key in req and req[key] is not None: + condition[field] = req[key] + # TODO(yzc): `available_int` is nullable however infinity doesn't support nullable columns. 
+ for key in ["knowledge_graph_kwd"]: + if key in req and req[key] is not None: + condition[key] = req[key] + return condition + + def search(self, req, idxnm: str, emb_mdl: str, highlight = False): qst = req.get("question", "") - bqry, keywords = self.qryr.question(qst, min_match="30%") - bqry = self._add_filters(bqry, req) - bqry.boost = 0.05 + matchText, keywords = self.qryr.question(qst, min_match=0.3) + filters = self.get_filters(req) + orderBy = OrderByExpr() - s = Search() pg = int(req.get("page", 1)) - 1 topk = int(req.get("topk", 1024)) ps = int(req.get("size", topk)) + offset, limit = pg * ps, (pg + 1) * ps src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "knowledge_graph_kwd", "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"]) - s = s.query(bqry)[pg * ps:(pg + 1) * ps] - s = s.highlight("content_ltks") - s = s.highlight("title_ltks") + q_vec = [] + assert emb_mdl, "No embedding model selected" + matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) + q_vec = matchDense.embedding_data + + fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"}) if not qst: - if not req.get("sort"): - s = s.sort( - #{"create_time": {"order": "desc", "unmapped_type": "date"}}, - {"create_timestamp_flt": { - "order": "desc", "unmapped_type": "float"}} - ) + if req.get("sort"): + orderBy.asc("page_num_int").asc("top_int").desc("create_timestamp_flt") else: - s = s.sort( - {"page_num_int": {"order": "asc", "unmapped_type": "float", - "mode": "avg", "numeric_type": "double"}}, - {"top_int": {"order": "asc", "unmapped_type": "float", - "mode": "avg", "numeric_type": "double"}}, - #{"create_time": {"order": "desc", "unmapped_type": "date"}}, - {"create_timestamp_flt": { - "order": "desc", "unmapped_type": "float"}} - ) - - if qst: - s = s.highlight_options( - fragment_size=120, - number_of_fragments=5, - boundary_scanner_locale="zh-CN", - boundary_scanner="SENTENCE", - boundary_chars=",./;:\\!(),。?:!……()——、" - ) - s = s.to_dict() - q_vec = [] - if req.get("vector"): - assert emb_mdl, "No embedding model selected" - s["knn"] = self._vector( - qst, emb_mdl, req.get( - "similarity", 0.1), topk) - s["knn"]["filter"] = bqry.to_dict() - if not highlight and "highlight" in s: - del s["highlight"] - q_vec = s["knn"]["query_vector"] - es_logger.info("【Q】: {}".format(json.dumps(s))) - res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src) - es_logger.info("TOTAL: {}".format(self.es.getTotal(res))) - if self.es.getTotal(res) == 0 and "knn" in s: - bqry, _ = self.qryr.question(qst, min_match="10%") - if req.get("doc_ids"): - bqry = Q("bool", must=[]) - bqry = self._add_filters(bqry, req) - s["query"] = bqry.to_dict() - s["knn"]["filter"] = bqry.to_dict() - s["knn"]["similarity"] = 0.17 - res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src) - es_logger.info("【Q】: {}".format(json.dumps(s))) + orderBy.desc("create_timestamp_flt") + + highlightFields = ["content_ltks", "title_tks"] if highlight else [] + + res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idxnm) + total=self.dataStore.getTotal(res) + + doc_store_logger.info(f"TOTAL: {total}") + + # If result is empty, try again with lower min_match + if total == 0: + matchText, _ = self.qryr.question(qst, min_match=0.1) + del filters["doc_ids"] + matchDense.extra_options["similarity"] = 0.17 + res = 
self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idxnm) + total=self.dataStore.getTotal(res) kwds = set([]) for k in keywords: @@ -157,67 +113,19 @@ def search(self, req, idxnm, emb_mdl=None, highlight=False): continue kwds.add(kk) - aggs = self.getAggregation(res, "docnm_kwd") - + ids=self.dataStore.getChunkIds(res) + highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight") + aggs = self.dataStore.getAggregation(res, "docnm_kwd") return self.SearchResult( - total=self.es.getTotal(res), - ids=self.es.getDocIds(res), + total=total, + ids=ids, query_vector=q_vec, aggregation=aggs, - highlight=self.getHighlight(res, keywords, "content_with_weight"), - field=self.getFields(res, src), + highlight=highlight, + field=self.dataStore.getFields(res, src), keywords=list(kwds) ) - def getAggregation(self, res, g): - if not "aggregations" in res or "aggs_" + g not in res["aggregations"]: - return - bkts = res["aggregations"]["aggs_" + g]["buckets"] - return [(b["key"], b["doc_count"]) for b in bkts] - - def getHighlight(self, res, keywords, fieldnm): - ans = {} - for d in res["hits"]["hits"]: - hlts = d.get("highlight") - if not hlts: - continue - txt = "...".join([a for a in list(hlts.items())[0][1]]) - if not is_english(txt.split(" ")): - ans[d["_id"]] = txt - continue - - txt = d["_source"][fieldnm] - txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE) - txts = [] - for t in re.split(r"[.?!;\n]", txt): - for w in keywords: - t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1\2\3", t, flags=re.IGNORECASE|re.MULTILINE) - if not re.search(r"[^<>]+", t, flags=re.IGNORECASE|re.MULTILINE): continue - txts.append(t) - ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]]) - - return ans - - def getFields(self, sres, flds): - res = {} - if not flds: - return {} - for d in self.es.getSource(sres): - m = {n: d.get(n) for n in flds if d.get(n) is not None} - for n, v in m.items(): - if isinstance(v, type([])): - m[n] = "\t".join([str(vv) if not isinstance( - vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) - continue - if not isinstance(v, type("")): - m[n] = str(m[n]) - if n.find("tks") > 0: - m[n] = rmSpace(m[n]) - - if m: - res[d["id"]] = m - return res - @staticmethod def trans2floats(txt): return [float(t) for t in txt.split("\t")] @@ -260,7 +168,7 @@ def insert_citations(self, answer, chunks, chunk_v, continue idx.append(i) pieces_.append(t) - es_logger.info("{} => {}".format(answer, pieces_)) + doc_store_logger.info("{} => {}".format(answer, pieces_)) if not pieces_: return answer, set([]) @@ -281,7 +189,7 @@ def insert_citations(self, answer, chunks, chunk_v, chunks_tks, tkweight, vtweight) mx = np.max(sim) * 0.99 - es_logger.info("{} SIM: {}".format(pieces_[i], mx)) + doc_store_logger.info("{} SIM: {}".format(pieces_[i], mx)) if mx < thr: continue cites[idx[i]] = list( @@ -436,39 +344,12 @@ def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, simi return ranks def sql_retrieval(self, sql, fetch_size=128, format="json"): - from api.settings import chat_logger - sql = re.sub(r"[ `]+", " ", sql) - sql = sql.replace("%", "") - es_logger.info(f"Get es sql: {sql}") - replaces = [] - for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): - fld, v = r.group(1), r.group(3) - match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( - fld, 
rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) - replaces.append( - ("{}{}'{}'".format( - r.group(1), - r.group(2), - r.group(3)), - match)) - - for p, r in replaces: - sql = sql.replace(p, r, 1) - chat_logger.info(f"To es: {sql}") - - try: - tbl = self.es.sql(sql, fetch_size, format) - return tbl - except Exception as e: - chat_logger.error(f"SQL failure: {sql} =>" + str(e)) - return {"error": str(e)} + tbl = self.dataStore.sql(sql, fetch_size, format) + return tbl def chunk_list(self, doc_id, tenant_id, max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]): - s = Search() - s = s.query(Q("match", doc_id=doc_id))[0:max_count] - s = s.to_dict() - es_res = self.es.search(s, idxnm=index_name(tenant_id), timeout="600s", src=fields) - res = [] - for index, chunk in enumerate(es_res['hits']['hits']): - res.append({fld: chunk['_source'].get(fld) for fld in fields}) + condition = {"doc_id": doc_id} + res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), 0, 0, index_name(tenant_id)) + if len(res) > max_count: + res = res[0:max_count] return res diff --git a/rag/settings.py b/rag/settings.py index 8c88c4067b1..a61d66d3967 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -27,10 +27,11 @@ ES = get_base_config("es", {}) AZURE = get_base_config("azure", {}) S3 = get_base_config("s3", {}) +INFINITY = get_base_config("infinity", {"uri": "127.0.0.1:23817"}) MINIO = decrypt_database_config(name="minio") try: REDIS = decrypt_database_config(name="redis") -except Exception as e: +except Exception: REDIS = {} pass DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024)) @@ -44,7 +45,7 @@ # {CRITICAL: 50, FATAL:50, ERROR:40, WARNING:30, WARN:30, INFO:20, DEBUG:10, NOTSET:0} LoggerFactory.LEVEL = 30 -es_logger = getLogger("es") +doc_store_logger = getLogger("doc_store") minio_logger = getLogger("minio") s3_logger = getLogger("s3") azure_logger = getLogger("azure") @@ -53,7 +54,7 @@ database_logger = getLogger("database") formatter = logging.Formatter("%(asctime)-15s %(levelname)-8s (%(process)d) %(message)s") -for logger in [es_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]: +for logger in [doc_store_logger, minio_logger, s3_logger, azure_logger, cron_logger, chunk_logger, database_logger]: logger.setLevel(logging.INFO) for handler in logger.handlers: handler.setFormatter(fmt=formatter) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 48d00b13e1d..5614135bf7d 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -31,14 +31,13 @@ import numpy as np import pandas as pd -from elasticsearch_dsl import Q from api.db import LLMType, ParserType from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService from api.db.services.file2document_service import File2DocumentService -from api.settings import retrievaler +from api.settings import retrievaler, docStoreConn from api.utils.file_utils import get_project_base_directory from api.db.db_models import close_connection from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email @@ -47,7 +46,6 @@ from rag.settings import database_logger, SVR_QUEUE_NAME from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from rag.utils import rmSpace, num_tokens_from_string -from rag.utils.es_conn import ELASTICSEARCH from rag.utils.redis_conn import 
REDIS_CONN, Payload from rag.utils.storage_factory import STORAGE_IMPL @@ -226,9 +224,9 @@ def build(row): def init_kb(row): idxnm = search.index_name(row["tenant_id"]) - if ELASTICSEARCH.indexExist(idxnm): + if docStoreConn.indexExist(idxnm): return - return ELASTICSEARCH.createIdx(idxnm, json.load( + return docStoreConn.createIdx(idxnm, json.load( open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) @@ -366,20 +364,18 @@ def main(): es_r = "" es_bulk_size = 4 for b in range(0, len(cks), es_bulk_size): - es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"])) + es_r = docStoreConn.upsertBulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"])) if b % 128 == 0: callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) if es_r: callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!") - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) + docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"])) cron_logger.error(str(es_r)) else: if TaskService.do_cancel(r["id"]): - ELASTICSEARCH.deleteByQuery( - Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) + docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"])) continue callback(1., "Done!") DocumentService.increment_chunk_num( diff --git a/rag/utils/data_store_conn.py b/rag/utils/data_store_conn.py new file mode 100644 index 00000000000..e3d26734193 --- /dev/null +++ b/rag/utils/data_store_conn.py @@ -0,0 +1,244 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union +from dataclasses import dataclass +import numpy as np +import polars as pl +from typing import List, Dict + +DEFAULT_MATCH_VECTOR_TOPN = 10 +DEFAULT_MATCH_SPARSE_TOPN = 10 +VEC = Union[list, np.ndarray] + + +@dataclass +class SparseVector: + indices: list[int] + values: Union[list[float], list[int], None] = None + + def __post_init__(self): + assert (self.values is None) or (len(self.indices) == len(self.values)) + + def to_dict_old(self): + d = {"indices": self.indices} + if self.values is not None: + d["values"] = self.values + return d + + def to_dict(self): + if self.values is None: + raise ValueError("SparseVector.values is None") + result = {} + for i, v in zip(self.indices, self.values): + result[str(i)] = v + return result + + @staticmethod + def from_dict(d): + return SparseVector(d["indices"], d.get("values")) + + def __str__(self): + return f"SparseVector(indices={self.indices}{'' if self.values is None else f', values={self.values}'})" + + def __repr__(self): + return str(self) + + +class MatchTextExpr(ABC): + def __init__( + self, + fields: str, + matching_text: str, + topn: int, + extra_options: dict = dict(), + ): + self.fields = fields + self.matching_text = matching_text + self.topn = topn + self.extra_options = extra_options + + +class MatchDenseExpr(ABC): + def __init__( + self, + vector_column_name: str, + embedding_data: VEC, + embedding_data_type: str, + distance_type: str, + topn: int = DEFAULT_MATCH_VECTOR_TOPN, + extra_options: dict = dict(), + ): + self.vector_column_name = vector_column_name + self.embedding_data = embedding_data + self.embedding_data_type = embedding_data_type + self.distance_type = distance_type + self.topn = topn + self.extra_options = extra_options + + +class MatchSparseExpr(ABC): + def 
__init__( + self, + vector_column_name: str, + sparse_data: SparseVector | dict, + distance_type: str, + topn: int, + opt_params: Optional[dict] = None, + ): + self.vector_column_name = vector_column_name + self.sparse_data = sparse_data + self.distance_type = distance_type + self.topn = topn + self.opt_params = opt_params + + +class MatchTensorExpr(ABC): + def __init__( + self, + column_name: str, + query_data: VEC, + query_data_type: str, + topn: int, + extra_option: Optional[dict] = None, + ): + self.column_name = column_name + self.query_data = query_data + self.query_data_type = query_data_type + self.topn = topn + self.extra_option = extra_option + + +class FusionExpr(ABC): + def __init__(self, method: str, topn: int, fusion_params: Optional[dict] = None): + self.method = method + self.topn = topn + self.fusion_params = fusion_params + + +MatchExpr = Union[ + MatchTextExpr, MatchDenseExpr, MatchSparseExpr, MatchTensorExpr, FusionExpr +] + + +class OrderByExpr(ABC): + def __init__(self): + self.fields = list() + def asc(self, field: str): + self.fields.append((field, 0)) + return self + def desc(self, field: str): + self.fields.append((field, 1)) + return self + def fields(self): + return self.fields + +class DocStoreConnection(ABC): + """ + Database operations + """ + + @abstractmethod + def health(self) -> dict: + """ + Return the health status of the database. + """ + raise NotImplementedError("Not implemented") + + """ + Table operations + """ + + @abstractmethod + def createIdx(self, vectorSize: int, indexName: str): + """ + Create an index with given name + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def deleteIdx(self, indexName: str): + """ + Delete an index with given name + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def indexExist(self, indexName: str) -> bool: + """ + Check if an index with given name exists + """ + raise NotImplementedError("Not implemented") + + """ + CRUD operations + """ + + @abstractmethod + def search( + self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexName: str + ) -> list[dict] | pl.DataFrame: + """ + Search with given conjunctive equivalent filtering condition and return all fields of matched documents + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def get(self, docId: str, indexName: str) -> dict | pl.DataFrame: + """ + Get single doc with given doc_id + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def upsertBulk(self, rows: list[dict], idexName: str): + """ + Update or insert a bulk of rows + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def update(self, condition: dict, newValue: dict, indexName: str): + """ + Update rows with given conjunctive equivalent filtering condition + """ + raise NotImplementedError("Not implemented") + + @abstractmethod + def delete(self, condition: dict, indexName: str): + """ + Delete rows with given conjunctive equivalent filtering condition + """ + raise NotImplementedError("Not implemented") + + """ + Helper functions for search result + """ + + @abstractmethod + def getTotal(self, res): + raise NotImplementedError("Not implemented") + + @abstractmethod + def getChunkIds(self, res): + raise NotImplementedError("Not implemented") + + @abstractmethod + def getFields(self, res, fields: List[str]) -> Dict[str, dict]: + raise NotImplementedError("Not implemented") + + 
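+    # The two helpers declared below mirror the Elasticsearch response handling
+    # implemented in es_conn.py: getHighlight returns {chunk_id: highlighted_snippet}
+    # and getAggregation returns a list of (bucket_key, doc_count) pairs taken from
+    # the "aggs_<fieldnm>" aggregation. Other backends are expected to return the
+    # same shapes so callers stay backend-agnostic.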
@abstractmethod + def getHighlight(self, res, keywords: List[str], fieldnm: str): + raise NotImplementedError("Not implemented") + + @abstractmethod + def getAggregation(self, res, fieldnm: str): + raise NotImplementedError("Not implemented") + + """ + SQL + """ + @abstractmethod + def sql(sql: str, fetch_size: int, format: str): + """ + Run the sql generated by text-to-sql + """ + raise NotImplementedError("Not implemented") diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 8b07be312c3..8c63c84c553 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -1,29 +1,30 @@ import re import json import time -import copy +import os +from typing import List, Dict import elasticsearch -from elastic_transport import ConnectionTimeout +import copy from elasticsearch import Elasticsearch -from elasticsearch_dsl import UpdateByQuery, Search, Index -from rag.settings import es_logger +from elasticsearch_dsl import Q, Search, Index +from elastic_transport import ConnectionTimeout +from rag.settings import doc_store_logger from rag import settings from rag.utils import singleton +from api.utils.file_utils import get_project_base_directory +import polars as pl +from rag.utils.data_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr +from rag.nlp import is_english, rag_tokenizer +from . import rmSpace -es_logger.info("Elasticsearch version: "+str(elasticsearch.__version__)) +doc_store_logger.info("Elasticsearch sdk version: "+str(elasticsearch.__version__)) @singleton -class ESConnection: +class ESConnection(DocStoreConnection): def __init__(self): self.info = {} - self.conn() - self.idxnm = settings.ES.get("index_name", "") - if not self.es.ping(): - raise Exception("Can't connect to ES cluster") - - def conn(self): for _ in range(10): try: self.es = Elasticsearch( @@ -34,209 +35,137 @@ def conn(self): ) if self.es: self.info = self.es.info() - es_logger.info("Connect to es.") + doc_store_logger.info("Connect to es.") break except Exception as e: - es_logger.error("Fail to connect to es: " + str(e)) + doc_store_logger.error("Fail to connect to es: " + str(e)) time.sleep(1) - - def version(self): + if not self.es.ping(): + raise Exception("Can't connect to ES cluster") v = self.info.get("version", {"number": "5.6"}) v = v["number"].split(".")[0] - return int(v) >= 7 - - def health(self): + if int(v) < 8: + raise Exception(f"ES version must be greater than or equal to 8, current version: {v}") + fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json") + if not os.path.exists(fp_mapping): + raise Exception(f"Mapping file not found at {fp_mapping}") + self.mapping = json.load(open(fp_mapping, "r")) + + """ + Database operations + """ + def health(self) -> dict: return dict(self.es.cluster.health()) - def upsert(self, df, idxnm=""): - res = [] - for d in df: - id = d["id"] - del d["id"] - d = {"doc": d, "doc_as_upsert": "true"} - T = False - for _ in range(10): - try: - if not self.version(): - r = self.es.update( - index=( - self.idxnm if not idxnm else idxnm), - body=d, - id=id, - doc_type="doc", - refresh=True, - retry_on_conflict=100) - else: - r = self.es.update( - index=( - self.idxnm if not idxnm else idxnm), - body=d, - id=id, - refresh=True, - retry_on_conflict=100) - es_logger.info("Successfully upsert: %s" % id) - T = True - break - except Exception as e: - es_logger.warning("Fail to index: " + - json.dumps(d, ensure_ascii=False) + str(e)) - if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): - 
time.sleep(3) - continue - self.conn() - T = False - - if not T: - res.append(d) - es_logger.error( - "Fail to index: " + - re.sub( - "[\r\n]", - "", - json.dumps( - d, - ensure_ascii=False))) - d["id"] = id - d["_index"] = self.idxnm - - if not res: - return True - return False - - def bulk(self, df, idx_nm=None): - ids, acts = {}, [] - for d in df: - id = d["id"] if "id" in d else d["_id"] - ids[id] = copy.deepcopy(d) - ids[id]["_index"] = self.idxnm if not idx_nm else idx_nm - if "id" in d: - del d["id"] - if "_id" in d: - del d["_id"] - acts.append( - {"update": {"_id": id, "_index": ids[id]["_index"]}, "retry_on_conflict": 100}) - acts.append({"doc": d, "doc_as_upsert": "true"}) - - res = [] - for _ in range(100): - try: - if elasticsearch.__version__[0] < 8: - r = self.es.bulk( - index=( - self.idxnm if not idx_nm else idx_nm), - body=acts, - refresh=False, - timeout="600s") - else: - r = self.es.bulk(index=(self.idxnm if not idx_nm else - idx_nm), operations=acts, - refresh=False, timeout="600s") - if re.search(r"False", str(r["errors"]), re.IGNORECASE): - return res - - for it in r["items"]: - if "error" in it["update"]: - res.append(str(it["update"]["_id"]) + - ":" + str(it["update"]["error"])) - - return res - except Exception as e: - es_logger.warn("Fail to bulk: " + str(e)) - if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): - time.sleep(3) - continue - self.conn() - - return res - - def bulk4script(self, df): - ids, acts = {}, [] - for d in df: - id = d["id"] - ids[id] = copy.deepcopy(d["raw"]) - acts.append({"update": {"_id": id, "_index": self.idxnm}}) - acts.append(d["script"]) - es_logger.info("bulk upsert: %s" % id) + """ + Table operations + """ + def createIdx(self, indexName: str): + try: + from elasticsearch.client import IndicesClient + return IndicesClient(self.es).create(index=indexName, + settings=self.mapping["settings"], + mappings=self.mapping["mappings"]) + except Exception as e: + doc_store_logger.error("ES create index error %s ----%s" % (indexName, str(e))) - res = [] - for _ in range(10): + def deleteIdx(self, indexName: str): + try: + return self.es.indices.delete(indexName, allow_no_indices=True) + except Exception as e: + doc_store_logger.error("ES delete index error %s ----%s" % (indexName, str(e))) + + def indexExist(self, indexName: str) -> bool: + s = Index(indexName, self.es) + for i in range(3): try: - if not self.version(): - r = self.es.bulk( - index=self.idxnm, - body=acts, - refresh=False, - timeout="600s", - doc_type="doc") - else: - r = self.es.bulk( - index=self.idxnm, - body=acts, - refresh=False, - timeout="600s") - if re.search(r"False", str(r["errors"]), re.IGNORECASE): - return res - - for it in r["items"]: - if "error" in it["update"]: - res.append(str(it["update"]["_id"])) - - return res + return s.exists() except Exception as e: - es_logger.warning("Fail to bulk: " + str(e)) - if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): - time.sleep(3) + doc_store_logger.error("ES updateByQuery indexExist: " + str(e)) + if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: continue - self.conn() - - return res + return False + + """ + CRUD operations + """ + def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexName: str) -> list[dict] | pl.DataFrame: + """ + Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html + """ + s = Search() + bqry = None + 
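+        # A "weighted_sum" FusionExpr is expected to carry fusion_params["weights"]
+        # as a comma-separated string such as "0.05,0.95"; only the second number is
+        # used here, as the vector similarity weight. The full-text bool query built
+        # below is then boosted by (1 - vector_similarity_weight).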
vector_similarity_weight = 0.5 + for m in matchExprs: + if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params: + assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr) + weights = m.fusion_params["weights"] + vector_similarity_weight = float(weights.split(",")[1]) + for m in matchExprs: + if isinstance(m, MatchTextExpr): + minimum_should_match = "0%" + if "minimum_should_match" in m.extra_options: + minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%" + bqry = Q("bool", + must=Q("query_string", fields=m.fields, + type="best_fields", query=m.matching_text, + minimum_should_match = minimum_should_match, + boost=1), + boost = 1.0 - vector_similarity_weight, + ) + if condition: + for k, v in condition.items(): + if isinstance(v, list): + if v: + bqry.filter.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + bqry.filter.append(Q("term", **{k: v})) + else: + raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") + elif isinstance(m, MatchDenseExpr): + assert(bqry is not None) + similarity = 0.0 + if "similarity" in m.extra_options: + similarity = m.extra_options["similarity"] + s = s.knn(m.vector_column_name, + m.topn, + m.topn * 2, + query_vector = list(m.embedding_data), + filter = bqry.to_dict(), + similarity = similarity, + ) - def rm(self, d): - for _ in range(10): - try: - if not self.version(): - r = self.es.delete( - index=self.idxnm, - id=d["id"], - doc_type="doc", - refresh=True) - else: - r = self.es.delete( - index=self.idxnm, - id=d["id"], - refresh=True, - doc_type="_doc") - es_logger.info("Remove %s" % d["id"]) - return True - except Exception as e: - es_logger.warn("Fail to delete: " + str(d) + str(e)) - if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): - time.sleep(3) - continue - if re.search(r"(not_found)", str(e), re.IGNORECASE): - return True - self.conn() + s.query = bqry + for field in highlightFields: + s = s.highlight(field) - es_logger.error("Fail to delete: " + str(d)) + if orderBy: + orders = list() + for field, order in orderBy.fields: + order = "asc" if order == 0 else "desc" + orders.append({field: {"order": order, "unmapped_type": "float", + "mode": "avg", "numeric_type": "double"}}) + s = s.sort(*orders) - return False + if limit!=0: + s = s[offset:limit] + q = s.to_dict() + doc_store_logger.info("ESConnection.search [Q]: " + json.dumps(q)) - def search(self, q, idxnm=None, src=False, timeout="2s"): - if not isinstance(q, dict): - q = Search().query(q).to_dict() for i in range(3): try: - res = self.es.search(index=(self.idxnm if not idxnm else idxnm), + res = self.es.search(index=(indexName), body=q, - timeout=timeout, + timeout="600s", # search_type="dfs_query_then_fetch", track_total_hits=True, - _source=src) + _source=True) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") + doc_store_logger.info("ESConnection.search res: " + str(res)) return res except Exception as e: - es_logger.error( + doc_store_logger.error( "ES search exception: " + str(e) + "【Q】:" + @@ -244,178 +173,126 @@ def search(self, q, idxnm=None, src=False, timeout="2s"): if str(e).find("Timeout") > 0: continue raise e - es_logger.error("ES search timeout for 3 times!") + doc_store_logger.error("ES search timeout for 3 times!") raise Exception("ES search timeout.") - def sql(self, sql, 
fetch_size=128, format="json", timeout="2s"): - for i in range(3): - try: - res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout=timeout) - return res - except ConnectionTimeout as e: - es_logger.error("Timeout【Q】:" + sql) - continue - except Exception as e: - raise e - es_logger.error("ES search timeout for 3 times!") - raise ConnectionTimeout() - - - def get(self, doc_id, idxnm=None): + def get(self, docId: str, indexName: str) -> dict: for i in range(3): try: - res = self.es.get(index=(self.idxnm if not idxnm else idxnm), - id=doc_id) + res = self.es.get(index=(indexName), + id=docId) if str(res.get("timed_out", "")).lower() == "true": raise Exception("Es Timeout.") return res except Exception as e: - es_logger.error( + doc_store_logger.error( "ES get exception: " + str(e) + "【Q】:" + - doc_id) + docId) if str(e).find("Timeout") > 0: continue raise e - es_logger.error("ES search timeout for 3 times!") + doc_store_logger.error("ES search timeout for 3 times!") raise Exception("ES search timeout.") - def updateByQuery(self, q, d): - ubq = UpdateByQuery(index=self.idxnm).using(self.es).query(q) - scripts = "" - for k, v in d.items(): - scripts += "ctx._source.%s = params.%s;" % (str(k), str(k)) - ubq = ubq.script(source=scripts, params=d) - ubq = ubq.params(refresh=False) - ubq = ubq.params(slices=5) - ubq = ubq.params(conflicts="proceed") - for i in range(3): - try: - r = ubq.execute() - return True - except Exception as e: - es_logger.error("ES updateByQuery exception: " + - str(e) + "【Q】:" + str(q.to_dict())) - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: - continue - self.conn() - - return False + def upsertBulk(self, documents: list[dict], indexName: str): + # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html + acts = [] + for d in documents: + d_copy = copy.deepcopy(d) + meta_id = d_copy["_id"] + del d_copy["_id"] + acts.append( + {"update": {"_id": meta_id, "_index": indexName}, "retry_on_conflict": 100}) + acts.append({"doc": d_copy, "doc_as_upsert": "true"}) - def updateScriptByQuery(self, q, scripts, idxnm=None): - ubq = UpdateByQuery( - index=self.idxnm if not idxnm else idxnm).using( - self.es).query(q) - ubq = ubq.script(source=scripts) - ubq = ubq.params(refresh=True) - ubq = ubq.params(slices=5) - ubq = ubq.params(conflicts="proceed") - for i in range(3): + res = [] + for _ in range(100): try: - r = ubq.execute() - return True - except Exception as e: - es_logger.error("ES updateByQuery exception: " + - str(e) + "【Q】:" + str(q.to_dict())) - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: - continue - self.conn() - - return False + r = self.es.bulk(index=(indexName), operations=acts, + refresh=False, timeout="600s") + if re.search(r"False", str(r["errors"]), re.IGNORECASE): + return res - def deleteByQuery(self, query, idxnm=""): - for i in range(3): - try: - r = self.es.delete_by_query( - index=idxnm if idxnm else self.idxnm, - refresh = True, - body=Search().query(query).to_dict()) - return True + for it in r["items"]: + if "error" in it["update"]: + res.append(str(it["update"]["_id"]) + + ":" + str(it["update"]["error"])) + return res except Exception as e: - es_logger.error("ES updateByQuery deleteByQuery: " + - str(e) + "【Q】:" + str(query.to_dict())) - if str(e).find("NotFoundError") > 0: return True - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: + doc_store_logger.warn("Fail to bulk: " + str(e)) + if re.search(r"(Timeout|time out)", 
str(e), re.IGNORECASE): + time.sleep(3) continue + self.conn() + return res - return False - - def update(self, id, script, routing=None): + def update(self, condition: dict, newValue: dict, indexName: str): + if 'id' not in condition: + raise Exception("Condition must contain id.") + doc = copy.deepcopy(condition) + id = doc['id'] + del doc['id'] for i in range(3): try: - if not self.version(): - r = self.es.update( - index=self.idxnm, - id=id, - body=json.dumps( - script, - ensure_ascii=False), - doc_type="doc", - routing=routing, - refresh=False) - else: - r = self.es.update(index=self.idxnm, id=id, body=json.dumps(script, ensure_ascii=False), - routing=routing, refresh=False) # , doc_type="_doc") + self.es.update(index=indexName, id=id, doc=doc) return True except Exception as e: - es_logger.error( - "ES update exception: " + str(e) + " id:" + str(id) + ", version:" + str(self.version()) + - json.dumps(script, ensure_ascii=False)) + doc_store_logger.error( + "ES update exception: " + str(e) + " id:" + str(id) + + json.dumps(condition, ensure_ascii=False)) if str(e).find("Timeout") > 0: continue - return False - def indexExist(self, idxnm): - s = Index(idxnm if idxnm else self.idxnm, self.es) - for i in range(3): + def delete(self, condition: dict, indexName: str): + qry = None + if "_id" in condition: + chunk_ids = condition["_id"] + if not isinstance(chunk_ids, list): + chunk_ids = [chunk_ids] + qry = Q("ids", values=chunk_ids) + else: + qry = Q("bool") + for k, v in condition.items(): + if isinstance(v, list): + qry.must.append(Q("terms", **{k: v})) + elif isinstance(v, str) or isinstance(v, int): + qry.must.append(Q("term", **{k: v})) + else: + raise Exception("Condition value must be int, str or list.") + doc_store_logger.info("ESConnection.delete [Q]: " + json.dumps(qry.to_dict())) + for _ in range(10): try: - return s.exists() + self.es.delete_by_query( + index=indexName, + body = Search().query(qry).to_dict(), + refresh=True) + return True except Exception as e: - es_logger.error("ES updateByQuery indexExist: " + str(e)) - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: + doc_store_logger.warn("Fail to delete: " + str(filter) + str(e)) + if re.search(r"(Timeout|time out)", str(e), re.IGNORECASE): + time.sleep(3) continue - + if re.search(r"(not_found)", str(e), re.IGNORECASE): + return True + self.conn() return False - def docExist(self, docid, idxnm=None): - for i in range(3): - try: - return self.es.exists(index=(idxnm if idxnm else self.idxnm), - id=docid) - except Exception as e: - es_logger.error("ES Doc Exist: " + str(e)) - if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0: - continue - return False - - def createIdx(self, idxnm, mapping): - try: - if elasticsearch.__version__[0] < 8: - return self.es.indices.create(idxnm, body=mapping) - from elasticsearch.client import IndicesClient - return IndicesClient(self.es).create(index=idxnm, - settings=mapping["settings"], - mappings=mapping["mappings"]) - except Exception as e: - es_logger.error("ES create index error %s ----%s" % (idxnm, str(e))) - - def deleteIdx(self, idxnm): - try: - return self.es.indices.delete(idxnm, allow_no_indices=True) - except Exception as e: - es_logger.error("ES delete index error %s ----%s" % (idxnm, str(e))) + """ + Helper functions for search result + """ def getTotal(self, res): if isinstance(res["hits"]["total"], type({})): return res["hits"]["total"]["value"] return res["hits"]["total"] - def getDocIds(self, res): + def getChunkIds(self, res): return [d["_id"] for d 
in res["hits"]["hits"]] - def getSource(self, res): + def __getSource(self, res): rr = [] for d in res["hits"]["hits"]: d["_source"]["id"] = d["_id"] @@ -423,40 +300,89 @@ def getSource(self, res): rr.append(d["_source"]) return rr - def scrollIter(self, pagesize=100, scroll_time='2m', q={ - "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}): - for _ in range(100): - try: - page = self.es.search( - index=self.idxnm, - scroll=scroll_time, - size=pagesize, - body=q, - _source=None - ) - break - except Exception as e: - es_logger.error("ES scrolling fail. " + str(e)) - time.sleep(3) + def getFields(self, res, fields: List[str]) -> Dict[str, dict]: + res_fields = {} + if not fields: + return {} + for d in self.__getSource(res): + m = {n: d.get(n) for n in fields if d.get(n) is not None} + for n, v in m.items(): + if isinstance(v, list): + m[n] = "\t".join([str(vv) if not isinstance( + vv, list) else "\t".join([str(vvv) for vvv in vv]) for vv in v]) + continue + if not isinstance(v, str): + m[n] = str(m[n]) + if n.find("tks") > 0: + m[n] = rmSpace(m[n]) - sid = page['_scroll_id'] - scroll_size = page['hits']['total']["value"] - es_logger.info("[TOTAL]%d" % scroll_size) - # Start scrolling - while scroll_size > 0: - yield page["hits"]["hits"] - for _ in range(100): - try: - page = self.es.scroll(scroll_id=sid, scroll=scroll_time) - break - except Exception as e: - es_logger.error("ES scrolling fail. " + str(e)) - time.sleep(3) + if m: + res_fields[d["id"]] = m + return res_fields - # Update the scroll ID - sid = page['_scroll_id'] - # Get the number of results that we returned in the last scroll - scroll_size = len(page['hits']['hits']) + def getHighlight(self, res, keywords: List[str], fieldnm: str): + ans = {} + for d in res["hits"]["hits"]: + hlts = d.get("highlight") + if not hlts: + continue + txt = "...".join([a for a in list(hlts.items())[0][1]]) + if not is_english(txt.split(" ")): + ans[d["_id"]] = txt + continue + txt = d["_source"][fieldnm] + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE) + txts = [] + for t in re.split(r"[.?!;\n]", txt): + for w in keywords: + t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1\2\3", t, flags=re.IGNORECASE|re.MULTILINE) + if not re.search(r"[^<>]+", t, flags=re.IGNORECASE|re.MULTILINE): continue + txts.append(t) + ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]]) + + return ans + + def getAggregation(self, res, fieldnm: str): + agg_field = "aggs_" + fieldnm + if "aggregations" not in res or agg_field not in res["aggregations"]: + return list() + bkts = res["aggregations"][agg_field]["buckets"] + return [(b["key"], b["doc_count"]) for b in bkts] + + + """ + SQL + """ + def sql(self, sql: str, fetch_size: int, format: str): + doc_store_logger.info(f"ESConnection.sql get sql: {sql}") + sql = re.sub(r"[ `]+", " ", sql) + sql = sql.replace("%", "") + replaces = [] + for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql): + fld, v = r.group(1), r.group(3) + match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format( + fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v))) + replaces.append( + ("{}{}'{}'".format( + r.group(1), + r.group(2), + r.group(3)), + match)) + + for p, r in replaces: + sql = sql.replace(p, r, 1) + doc_store_logger.info(f"ESConnection.sql to es: {sql}") -ELASTICSEARCH = ESConnection() + for i in range(3): + try: + res = self.es.sql.query(body={"query": sql, "fetch_size": 
fetch_size}, format=format, request_timeout="2s")
+                return res
+            except ConnectionTimeout:
+                doc_store_logger.error("ESConnection.sql timeout [Q]: " + sql)
+                continue
+            except Exception as e:
+                doc_store_logger.error(f"ESConnection.sql failure: {sql} => " + str(e))
+                return None
+        doc_store_logger.error("ESConnection.sql timeout for 3 times!")
+        return None
diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py
new file mode 100644
index 00000000000..8a9904757c1
--- /dev/null
+++ b/rag/utils/infinity_conn.py
@@ -0,0 +1,255 @@
+import infinity
+from infinity.common import ConflictType
+from infinity.index import IndexInfo, IndexType
+from infinity.connection_pool import ConnectionPool
+from rag import settings
+from rag.settings import doc_store_logger
+from rag.utils import singleton
+import polars as pl
+
+from rag.utils.data_store_conn import (
+    DocStoreConnection,
+    MatchExpr,
+    MatchTextExpr,
+    MatchDenseExpr,
+    MatchSparseExpr,
+    MatchTensorExpr,
+    FusionExpr,
+    OrderByExpr,
+)
+
+
+def equivalent_condition_to_str(condition: dict) -> str:
+    cond = list()
+    for k, v in condition.items():
+        if not isinstance(k, str):
+            continue
+        if isinstance(v, list):
+            inCond = list()
+            for item in v:
+                if isinstance(item, str):
+                    inCond.append(f'{k}="{item}"')
+                else:
+                    inCond.append(f'{k}={str(item)}')
+            if inCond:
+                strInCond = f'({" OR ".join(inCond)})'
+                cond.append(strInCond)
+        elif isinstance(v, str):
+            cond.append(f'{k}="{v}"')
+        else:
+            cond.append(f"{k}={str(v)}")
+    return " AND ".join(cond)
+
+
+@singleton
+class InfinityConnection(DocStoreConnection):
+    def __init__(self):
+        self.dbName = settings.INFINITY.get("db_name", "default_db")
+        infinity_uri = settings.INFINITY["uri"]
+        if ":" in infinity_uri:
+            host, port = infinity_uri.split(":")
+            infinity_uri = infinity.common.NetworkAddress(host, int(port))
+        self.inf = ConnectionPool(infinity_uri)
+        doc_store_logger.info(f"Connected to infinity {infinity_uri}.")
+
+    """
+    Database operations
+    """
+
+    def health(self) -> dict:
+        """
+        Return the health status of the database.
+ TODO: Infinity-sdk provides health() to wrap `show global variables` and `show tables` + """ + return dict() + + """ + Table operations + """ + + def createIdx(self, vectorSize: int, indexName: str): + inf_conn = self.inf.get_conn() + inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore) + inf_table = inf_db.create_table( + indexName, + { + "_id": { + "type": "varchar", + "default": "", + }, # ES has this meta field as the primary key for every index + "doc_id": {"type": "varchar", "default": ""}, + "kb_id": {"type": "varchar", "default": ""}, + "create_time": {"type": "varchar", "default": ""}, + "create_timestamp_flt": {"type": "float", "default": 0.0}, + "img_id": {"type": "varchar", "default": ""}, + "docnm_kwd": {"type": "varchar", "default": ""}, + "title_tks": {"type": "varchar", "default": ""}, + "title_sm_tks": {"type": "varchar", "default": ""}, + "name_kwd": {"type": "varchar", "default": ""}, + "important_kwd": {"type": "varchar", "default": ""}, + "important_tks": {"type": "varchar", "default": ""}, + "content_with_weight": { + "type": "varchar", + "default": "", + }, # The raw chunk text + "content_ltks": {"type": "varchar", "default": ""}, + "content_sm_ltks": {"type": "varchar", "default": ""}, + "q_vec": {"type": f"vector,{vectorSize},float"}, + "page_num_int": {"type": "varchar", "default": 0}, + "top_int": {"type": "varchar", "default": 0}, + "position_int": {"type": "varchar", "default": 0}, + "weight_int": {"type": "integer", "default": 0}, + "weight_flt": {"type": "float", "default": 0.0}, + "rank_int": {"type": "integer", "default": 0}, + "available_int": {"type": "integer", "default": 0}, + }, + ConflictType.Ignore, + ) + inf_table.create_index( + "q_vec_idx", + IndexInfo( + "q_vec", + IndexType.Hnsw, + { + "M": "16", + "ef_construction": "50", + "metric": "cosine", + "encode": "lvq", + }, + ), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx0", + IndexInfo("title_tks", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx1", + IndexInfo("title_sm_tks", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx2", + IndexInfo("important_kwd", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx3", + IndexInfo("important_tks", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx4", + IndexInfo("content_ltks", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + inf_table.create_index( + "text_idx5", + IndexInfo("content_sm_ltks", IndexType.FullText, {"ANALYZER": "standard"}), + ConflictType.Ignore, + ) + self.inf.release_conn(inf_conn) + + def deleteIdx(self, indexName: str): + inf_conn = self.inf.get_conn() + db = inf_conn.get_database(self.dbName) + db.drop_table(indexName, ConflictType.Ignore) + self.inf.release_conn(inf_conn) + + def indexExist(self, indexName: str) -> bool: + try: + inf_conn = self.inf.get_conn() + _ = inf_conn.get_table(self.dbName, indexName) + self.inf.release_conn(inf_conn) + return True + except Exception as e: + doc_store_logger.error("INFINITY indexExist: " + str(e)) + return False + + """ + CRUD operations + """ + + def search( + self, selectFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexName: str + ) -> list[dict] | pl.DataFrame: + """ + TODO: convert result to dict? 
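+        Note: orderBy is currently accepted but not applied by this backend; the
+        result is returned as the polars DataFrame produced by to_pl().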
+        """
+        inf_conn = self.inf.get_conn()
+        table_instance = inf_conn.get_table(self.dbName, indexName)
+        builder = table_instance.output(selectFields)
+        if condition:
+            builder = builder.filter(equivalent_condition_to_str(condition))
+        for matchExpr in matchExprs:
+            if isinstance(matchExpr, MatchTextExpr):
+                builder = builder.match_text(
+                    matchExpr.fields,
+                    matchExpr.matching_text,
+                    matchExpr.topn,
+                    matchExpr.extra_options,
+                )
+            elif isinstance(matchExpr, MatchDenseExpr):
+                builder = builder.match_dense(
+                    matchExpr.vector_column_name,
+                    matchExpr.embedding_data,
+                    matchExpr.embedding_data_type,
+                    matchExpr.distance_type,
+                    matchExpr.topn,
+                    matchExpr.extra_options,
+                )
+            elif isinstance(matchExpr, MatchSparseExpr):
+                builder = builder.match_sparse(
+                    matchExpr.vector_column_name,
+                    matchExpr.sparse_data,
+                    matchExpr.distance_type,
+                    matchExpr.topn,
+                    matchExpr.opt_params,
+                )
+            elif isinstance(matchExpr, MatchTensorExpr):
+                builder = builder.match_tensor(
+                    matchExpr.column_name,
+                    matchExpr.query_data,
+                    matchExpr.query_data_type,
+                    matchExpr.topn,
+                    matchExpr.extra_option,
+                )
+            elif isinstance(matchExpr, FusionExpr):
+                builder = builder.fusion(
+                    matchExpr.method, matchExpr.topn, matchExpr.fusion_params
+                )
+        builder = builder.offset(offset).limit(limit)
+        res = builder.to_pl()
+        self.inf.release_conn(inf_conn)
+        return res
+
+    def get(self, docId: str, indexName: str) -> dict | pl.DataFrame:
+        inf_conn = self.inf.get_conn()
+        table_instance = inf_conn.get_table(self.dbName, indexName)
+        res = table_instance.output(["*"]).filter(f"doc_id = '{docId}'").to_pl()
+        self.inf.release_conn(inf_conn)
+        return res
+
+    def upsertBulk(self, documents: list[dict], indexName: str):
+        ids = [f"_id = '{d['_id']}'" for d in documents]
+        del_filter = " OR ".join(ids)
+        inf_conn = self.inf.get_conn()
+        table_instance = inf_conn.get_table(self.dbName, indexName)
+        table_instance.delete(del_filter)
+        table_instance.insert(documents)
+        self.inf.release_conn(inf_conn)
+
+    def update(self, condition: dict, newValue: dict, indexName: str):
+        inf_conn = self.inf.get_conn()
+        table_instance = inf_conn.get_table(self.dbName, indexName)
+        filter = equivalent_condition_to_str(condition)
+        table_instance.update(filter, newValue)
+        self.inf.release_conn(inf_conn)
+
+    def delete(self, condition: dict, indexName: str):
+        inf_conn = self.inf.get_conn()
+        table_instance = inf_conn.get_table(self.dbName, indexName)
+        filter = equivalent_condition_to_str(condition)
+        table_instance.delete(filter)
+        self.inf.release_conn(inf_conn)
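
The snippet below is a minimal usage sketch of the DocStoreConnection abstraction introduced by this patch, assuming docStoreConn from api.settings resolves to the Elasticsearch-backed connection; the hybrid_search helper, the field boosts, the topn values, and the fusion weights are illustrative only and not part of the patch.

# Hedged sketch: hybrid (full-text + dense) retrieval through the new
# backend-neutral interface. question_tokens and query_vector are assumed to
# come from rag_tokenizer and an embedding model elsewhere in the pipeline.
from api.settings import docStoreConn
from rag.nlp import search
from rag.utils.data_store_conn import MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr


def hybrid_search(tenant_id: str, kb_ids: list[str],
                  question_tokens: str, query_vector: list[float]):
    idxnm = search.index_name(tenant_id)
    # Conjunctive equivalent filtering condition: scalar -> equality, list -> OR over values.
    condition = {"kb_id": kb_ids, "available_int": 1}
    match_text = MatchTextExpr(
        ["title_tks^10", "content_ltks"],  # fed to query_string as weighted fields by the ES backend
        question_tokens, 1024,
        {"minimum_should_match": 0.3},     # converted to "30%" by the ES backend
    )
    match_dense = MatchDenseExpr(
        f"q_{len(query_vector)}_vec",      # assumed ES-style vector column name
        query_vector, "float", "cosine", 1024,
        {"similarity": 0.1},
    )
    fusion = FusionExpr("weighted_sum", 1024, {"weights": "0.05,0.95"})
    order_by = OrderByExpr().desc("create_timestamp_flt")
    res = docStoreConn.search(
        ["doc_id", "docnm_kwd", "content_with_weight"],  # selectFields
        ["content_ltks"],                                # highlight fields
        condition, [match_text, match_dense, fusion],
        order_by, 0, 30, idxnm,
    )
    total = docStoreConn.getTotal(res)
    chunks = docStoreConn.getFields(res, ["doc_id", "docnm_kwd", "content_with_weight"])
    return total, chunks

The getTotal, getChunkIds, getFields, getHighlight, and getAggregation helpers are what keep callers such as rag/nlp/search.py independent of the backend-specific response format.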