From c6bf3817268ddb9a4b9c8cef2bd8acf82c647354 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Wed, 26 Jun 2024 13:48:07 -0700
Subject: [PATCH 1/9] fix gradio displaying user message twice

---
 src/agentscope/agents/user_agent.py | 2 +-
 src/agentscope/logging.py           | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/agentscope/agents/user_agent.py b/src/agentscope/agents/user_agent.py
index c18889229..3d0b1cf3e 100644
--- a/src/agentscope/agents/user_agent.py
+++ b/src/agentscope/agents/user_agent.py
@@ -149,4 +149,4 @@ def speak(
                 f"object, got {type(content)} instead.",
             )
 
-        logger.chat(msg)
+        logger.chat(msg, disable_gradio=True)
diff --git a/src/agentscope/logging.py b/src/agentscope/logging.py
index 61a89e43d..eff7f21b4 100644
--- a/src/agentscope/logging.py
+++ b/src/agentscope/logging.py
@@ -73,7 +73,7 @@ def _get_speaker_color(speaker: str) -> tuple[str, str]:
 def _chat(
     message: dict,
     *args: Any,
-    disable_studio: bool = False,
+    disable_gradio: bool = False,
     **kwargs: Any,
 ) -> None:
     """
@@ -142,15 +142,15 @@ def _chat(
         **kwargs,
     )
 
-    if hasattr(thread_local_data, "uid") and not disable_studio:
-        log_studio(message, thread_local_data.uid, **kwargs)
+    if hasattr(thread_local_data, "uid") and not disable_gradio:
+        log_gradio(message, thread_local_data.uid, **kwargs)
         return
 
     logger.log(LEVEL_DISPLAY_MSG, message, *args, **kwargs)
     logger.log(LEVEL_SAVE_LOG, message, *args, **kwargs)
 
 
-def log_studio(message: dict, uid: str, **kwargs: Any) -> None:
+def log_gradio(message: dict, uid: str, **kwargs: Any) -> None:
     """Send chat message to studio.
 
     Args:
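The duplication happened because the Gradio input box already renders what the user typed, so routing the user's message through the frontend logger a second time displayed it twice. Below is a self-contained sketch of the routing pattern this patch leaves in `_chat`; it is not AgentScope's actual implementation: `log_gradio` is stubbed with a `print`, and the thread-local `uid` stands in for however the real runner marks a Gradio session.

```python
import threading

thread_local_data = threading.local()  # in this sketch, `uid` marks a Gradio session


def log_gradio(message: dict, uid: str) -> None:
    """Stub for the frontend hook: pushes one chat bubble to the UI."""
    print(f"[gradio:{uid}] {message['name']}: {message['content']}")


def chat(message: dict, disable_gradio: bool = False) -> None:
    """Mirror of the dispatch above: forward to the UI at most once."""
    if getattr(thread_local_data, "uid", None) and not disable_gradio:
        log_gradio(message, thread_local_data.uid)
        return
    print(f"{message['name']}: {message['content']}")  # terminal fallback


thread_local_data.uid = "run-1"
# The user's input was already rendered by the input box, so the UserAgent
# suppresses the second rendering -- the one-line fix in the patch above.
chat({"name": "user", "content": "hi"}, disable_gradio=True)
chat({"name": "assistant", "content": "hello"})  # still reaches the UI once
```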
From acef15e758489134eeb23d1a812eed3894db4f60 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Mon, 1 Jul 2024 00:10:58 -0700
Subject: [PATCH 2/9] add local embedding model example/instruction in tutorial.

---
 docs/sphinx_doc/en/source/tutorial/210-rag.md | 101 +++++++++++++++++
 .../zh_CN/source/tutorial/210-rag.md          | 104 ++++++++++++++++++
 2 files changed, 205 insertions(+)

diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md
index 867fdb2ec..b1c3c7e2f 100644
--- a/docs/sphinx_doc/en/source/tutorial/210-rag.md
+++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md
@@ -190,6 +190,107 @@ RAG agent is an agent that can generate answers based on the retrieved knowledge
 Your agent will be equipped with a list of knowledge according to the `knowledge_id_list`.
 You can decide how to use the retrieved content and even update and refresh the index in your agent's `reply` function.
 
+## (Optional) Setting up a local embedding model service
+
+For those who are interested in setting up a local embedding service, we provide the following example based on the
+`sentence_transformers` package, which is a popular specialized package for embedding models (based on the `transformers` package and compatible with both HuggingFace and ModelScope models).
+In this example, we will use one of the SOTA embedding models, `gte-Qwen2-7B-instruct`.
+
+* Step 1: follow the instructions on [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) or [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct) to download the embedding model.
+* Step 2: Set up the server. The following code is for reference.

```python
import datetime
import argparse

from flask import Flask
from flask import request
from sentence_transformers import SentenceTransformer

def create_timestamp(format_: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Get current timestamp."""
    return datetime.datetime.now().strftime(format_)

app = Flask(__name__)

@app.route("/embedding/", methods=["POST"])
def get_embedding() -> dict:
    """Receive post request and return response"""
    json = request.get_json()

    inputs = json.pop("inputs")

    global model

    if isinstance(inputs, str):
        inputs = [inputs]

    embeddings = model.encode(inputs)

    return {
        "data": {
            "completion_tokens": 0,
            "messages": {},
            "prompt_tokens": 0,
            "response": {
                "data": [
                    {
                        "embedding": emb.astype(float).tolist(),
                    }
                    for emb in embeddings
                ],
                "created": "",
                "id": create_timestamp(),
                "model": "flask_model",
                "object": "text_completion",
                "usage": {
                    "completion_tokens": 0,
                    "prompt_tokens": 0,
                    "total_tokens": 0,
                },
            },
            "total_tokens": 0,
            "username": "",
        },
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    print("setting up for embedding model....")
    model = SentenceTransformer(
        args.model_name_or_path
    )

    app.run(port=args.port)
```

* Step 3: Start the server.
```bash
python setup_ms_service.py --model_name_or_path {$PATH_TO_gte_Qwen2_7B_instruct}
```


Test whether the model is running successfully:
```python
from agentscope.models.post_model import PostAPIEmbeddingWrapper


model = PostAPIEmbeddingWrapper(
    config_name="test_config",
    api_url="http://127.0.0.1:8000/embedding/",
    json_args={
        "max_length": 4096,
        "temperature": 0.5
    }
)

print(model("testing"))
```

 [[Back to the top]](#210-rag-en)
diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
index 7a0efd7d0..a61c40e72 100644
--- a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
+++ b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
@@ -174,6 +174,110 @@ A RAG agent is an agent that can generate answers based on the retrieved knowledge.
 **Build your own RAG agent.** As long as your agent config has a `knowledge_id_list`, you can pass the agent together with this list to `KnowledgeBank.equip`; the agent is then equipped with the corresponding `knowledge_id`s.
 You can decide in your own `reply` function how to extract and use the information from the `Knowledge` objects, and even update the knowledge base through `Knowledge`.
 
+
+## (Advanced) Setting up your own embedding model service
+
+Here we also provide an example for users who are interested in hosting a local embedding model.
+The following example is based on the `sentence_transformers` package, which is very popular in the embedding model domain (built on `transformers` and compatible with both HuggingFace and ModelScope models).
+In this example, we will use `gte-Qwen2-7B-instruct`, one of the best text embedding models currently available.
+
+
+* Step 1: Follow the instructions on [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) or [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct) to download the model.
+* Step 2: Set up the server. The following code is for reference.

```python
import datetime
import argparse

from flask import Flask
from flask import request
from sentence_transformers import SentenceTransformer

def create_timestamp(format_: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Get current timestamp."""
    return datetime.datetime.now().strftime(format_)

app = Flask(__name__)

@app.route("/embedding/", methods=["POST"])
def get_embedding() -> dict:
    """Receive post request and return response"""
    json = request.get_json()

    inputs = json.pop("inputs")

    global model

    if isinstance(inputs, str):
        inputs = [inputs]

    embeddings = model.encode(inputs)

    return {
        "data": {
            "completion_tokens": 0,
            "messages": {},
            "prompt_tokens": 0,
            "response": {
                "data": [
                    {
                        "embedding": emb.astype(float).tolist(),
                    }
                    for emb in embeddings
                ],
                "created": "",
                "id": create_timestamp(),
                "model": "flask_model",
                "object": "text_completion",
                "usage": {
                    "completion_tokens": 0,
                    "prompt_tokens": 0,
                    "total_tokens": 0,
                },
            },
            "total_tokens": 0,
            "username": "",
        },
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    print("setting up for embedding model....")
    model = SentenceTransformer(
        args.model_name_or_path
    )

    app.run(port=args.port)
```

* Step 3: Start the server.
```bash
python setup_ms_service.py --model_name_or_path {$PATH_TO_gte_Qwen2_7B_instruct}
```


Test whether the service has started successfully:
```python
from agentscope.models.post_model import PostAPIEmbeddingWrapper


model = PostAPIEmbeddingWrapper(
    config_name="test_config",
    api_url="http://127.0.0.1:8000/embedding/",
    json_args={
        "max_length": 4096,
        "temperature": 0.5
    }
)

print(model("testing"))
```

 [[Back to the top]](#210-rag-zh)
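Before wrapping the tutorial's Flask service in `PostAPIEmbeddingWrapper`, it can help to hit the raw endpoint once. A minimal smoke test with `requests`; the URL and port assume the defaults above, and the nested keys mirror the handler's return value:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8000/embedding/",
    json={"inputs": ["hello world", "agentscope"]},  # a plain str also works
    timeout=30,
)
payload = resp.json()
# The handler nests the vectors under data -> response -> data.
vectors = [item["embedding"] for item in payload["data"]["response"]["data"]]
print(len(vectors), len(vectors[0]))  # 2 x the model's embedding dimension
```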
"data": { + "completion_tokens": 0, + "messages": {}, + "prompt_tokens": 0, + "response": { + "data": [ + { + "embedding": emb.astype(float).tolist(), + } + for emb in embeddings + ], + "created": "", + "id": create_timestamp(), + "model": "flask_model", + "object": "text_completion", + "usage": { + "completion_tokens": 0, + "prompt_tokens": 0, + "total_tokens": 0, + }, + }, + "total_tokens": 0, + "username": "", + }, + } + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_name_or_path", type=str, required=True) + parser.add_argument("--device", type=str, default="auto") + parser.add_argument("--port", type=int, default=8000) + args = parser.parse_args() + + print("setting up for embedding model....") + model = SentenceTransformer( + args.model_name_or_path + ) + + app.run(port=args.port) +``` + +* 第三部:启动服务器。 +```bash +python setup_ms_service.py --model_name_or_path {$PATH_TO_gte_Qwen2_7B_instruct} +``` + + +测试服务是否成功启动。 +```python +from agentscope.models.post_model import PostAPIEmbeddingWrapper + + +model = PostAPIEmbeddingWrapper( + config_name="test_config", + api_url="http://127.0.0.1:8000/embedding/", + json_args={ + "max_length": 4096, + "temperature": 0.5 + } +) + +print(model("testing")) +``` + [[回到顶部]](#210-rag-zh) From f9afd6ebc950246e7310c317ede39412d8a72b97 Mon Sep 17 00:00:00 2001 From: ZiTao-Li Date: Mon, 1 Jul 2024 00:20:01 -0700 Subject: [PATCH 3/9] minor fix --- docs/sphinx_doc/en/source/tutorial/210-rag.md | 2 ++ docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md index b1c3c7e2f..277b06c58 100644 --- a/docs/sphinx_doc/en/source/tutorial/210-rag.md +++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md @@ -260,6 +260,8 @@ if __name__ == "__main__": parser.add_argument("--device", type=str, default="auto") parser.add_argument("--port", type=int, default=8000) args = parser.parse_args() + + global model print("setting up for embedding model....") model = SentenceTransformer( diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md index a61c40e72..30ebd89ea 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md @@ -246,6 +246,8 @@ if __name__ == "__main__": parser.add_argument("--device", type=str, default="auto") parser.add_argument("--port", type=int, default=8000) args = parser.parse_args() + + global model print("setting up for embedding model....") model = SentenceTransformer( From 9af802b5dcf31c56aa8f3bcd4afddd0633944fd1 Mon Sep 17 00:00:00 2001 From: ZiTao-Li Date: Mon, 1 Jul 2024 00:25:17 -0700 Subject: [PATCH 4/9] fix format --- docs/sphinx_doc/en/source/tutorial/210-rag.md | 6 +++--- docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md index 277b06c58..e0739f047 100644 --- a/docs/sphinx_doc/en/source/tutorial/210-rag.md +++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md @@ -192,7 +192,7 @@ You can decide how to use the retrieved content and even update and refresh the ## (Optional) Setting up a local embedding model service -For those who are interested in setting up a local embedding service, we provide the following example based on the +For those who are interested in setting up a local embedding 
From 9af802b5dcf31c56aa8f3bcd4afddd0633944fd1 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Mon, 1 Jul 2024 00:25:17 -0700
Subject: [PATCH 4/9] fix format

---
 docs/sphinx_doc/en/source/tutorial/210-rag.md    | 6 +++---
 docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md
index 277b06c58..e0739f047 100644
--- a/docs/sphinx_doc/en/source/tutorial/210-rag.md
+++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md
@@ -192,7 +192,7 @@ You can decide how to use the retrieved content and even update and refresh the
 
 ## (Optional) Setting up a local embedding model service
 
-For those who are interested in setting up a local embedding service, we provide the following example based on the 
+For those who are interested in setting up a local embedding service, we provide the following example based on the
 `sentence_transformers` package, which is a popular specialized package for embedding models (based on the `transformers` package and compatible with both HuggingFace and ModelScope models).
 In this example, we will use one of the SOTA embedding models, `gte-Qwen2-7B-instruct`.
 
@@ -224,7 +224,7 @@ def get_embedding() -> dict:
 
     if isinstance(inputs, str):
         inputs = [inputs]
-    
+
     embeddings = model.encode(inputs)
 
     return {
@@ -260,7 +260,7 @@ if __name__ == "__main__":
     parser.add_argument("--device", type=str, default="auto")
     parser.add_argument("--port", type=int, default=8000)
     args = parser.parse_args()
-    
+
     global model
 
     print("setting up for embedding model....")
diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
index 30ebd89ea..268c2e425 100644
--- a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
+++ b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md
@@ -210,7 +210,7 @@ def get_embedding() -> dict:
 
     if isinstance(inputs, str):
         inputs = [inputs]
-    
+
    embeddings = model.encode(inputs)
 
     return {
@@ -246,7 +246,7 @@ if __name__ == "__main__":
     parser.add_argument("--device", type=str, default="auto")
     parser.add_argument("--port", type=int, default=8000)
     args = parser.parse_args()
-    
+
     global model
 
     print("setting up for embedding model....")

From 79bc4b36e9af273c6ac9f9494ce568b9a47d5ad0 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Mon, 1 Jul 2024 20:03:09 -0700
Subject: [PATCH 5/9] update following comments

---
 docs/sphinx_doc/en/source/tutorial/210-rag.md    | 4 +++-
 docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md
index e0739f047..1af8fdaa4 100644
--- a/docs/sphinx_doc/en/source/tutorial/210-rag.md
+++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md
@@ -196,7 +196,9 @@ For those who are interested in setting up a local embedding service, we provide
 `sentence_transformers` package, which is a popular specialized package for embedding models (based on the `transformers` package and compatible with both HuggingFace and ModelScope models).
 In this example, we will use one of the SOTA embedding models, `gte-Qwen2-7B-instruct`.
 
-* Step 1: follow the instructions on [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) or [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct) to download the embedding model.
+* Step 1: Follow the instructions on [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) or [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct) to download the embedding model.
+  (For those who cannot access HuggingFace directly, you may want to use a HuggingFace mirror by running the bash command 
+  `export HF_ENDPOINT=https://hf-mirror.com` or adding the line `os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"` to your Python code.)
 * Step 2: Set up the server. The following code is for reference.
```python diff --git a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md index 268c2e425..7921dd31d 100644 --- a/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md +++ b/docs/sphinx_doc/zh_CN/source/tutorial/210-rag.md @@ -183,6 +183,7 @@ RAG 智能体是可以基于检索到的知识生成答案的智能体。 * 第一步: 遵循在 [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) 或者 [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct )的指示下载模型。 + (如果无法直接从HuggingFace下载模型,也可以考虑使用HuggingFace镜像:bash命令行`export HF_ENDPOINT=https://hf-mirror.com`,或者在Python代码中加入`os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"`) * 第二步: 设置服务器。以下是一段参考代码。 ```python From 551501e1dfe1c9ac8cfb6fed12b98ec44c312025 Mon Sep 17 00:00:00 2001 From: ZiTao-Li Date: Mon, 1 Jul 2024 20:03:55 -0700 Subject: [PATCH 6/9] fix format --- docs/sphinx_doc/en/source/tutorial/210-rag.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sphinx_doc/en/source/tutorial/210-rag.md b/docs/sphinx_doc/en/source/tutorial/210-rag.md index 1af8fdaa4..39c3ecce0 100644 --- a/docs/sphinx_doc/en/source/tutorial/210-rag.md +++ b/docs/sphinx_doc/en/source/tutorial/210-rag.md @@ -197,7 +197,7 @@ For those who are interested in setting up a local embedding service, we provide In this example, we will use one of the SOTA embedding models, `gte-Qwen2-7B-instruct`. * Step 1: Follow the instruction on [HuggingFace](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct) or [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct ) to download the embedding model. - (For those who cannot access HuggingFace directly, you may want to use a HuggingFace mirror by running a bash command + (For those who cannot access HuggingFace directly, you may want to use a HuggingFace mirror by running a bash command `export HF_ENDPOINT=https://hf-mirror.com` or add a line of code `os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"` in your Python code.) * Step 2: Set up the server. The following code is for reference. From ae9cede82c14a7cd4a65af7495630c9557c55bf4 Mon Sep 17 00:00:00 2001 From: ZiTao-Li Date: Thu, 22 Aug 2024 11:17:56 -0700 Subject: [PATCH 7/9] security check to see if a URL is pointing to internal IP --- src/agentscope/service/web/web_digest.py | 36 +++++++++++++++++++++++- tests/web_digest_test.py | 7 ++++- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/agentscope/service/web/web_digest.py b/src/agentscope/service/web/web_digest.py index 5e39fdc4d..588163c5a 100644 --- a/src/agentscope/service/web/web_digest.py +++ b/src/agentscope/service/web/web_digest.py @@ -5,7 +5,8 @@ from typing import Optional, Callable, Sequence, Any import requests from loguru import logger - +import socket +import ipaddress from agentscope.service.service_response import ServiceResponse from agentscope.service.service_status import ServiceExecStatus @@ -36,6 +37,30 @@ def is_valid_url(url: str) -> bool: except ValueError: return False # A ValueError indicates that the URL is not valid. 
From ae9cede82c14a7cd4a65af7495630c9557c55bf4 Mon Sep 17 00:00:00 2001
From: ZiTao-Li
Date: Thu, 22 Aug 2024 11:17:56 -0700
Subject: [PATCH 7/9] security check to see if a URL is pointing to internal
 IP

---
 src/agentscope/service/web/web_digest.py | 36 +++++++++++++++++++++++-
 tests/web_digest_test.py                 |  7 ++++-
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/agentscope/service/web/web_digest.py b/src/agentscope/service/web/web_digest.py
index 5e39fdc4d..588163c5a 100644
--- a/src/agentscope/service/web/web_digest.py
+++ b/src/agentscope/service/web/web_digest.py
@@ -5,7 +5,8 @@
 from typing import Optional, Callable, Sequence, Any
 import requests
 from loguru import logger
-
+import socket
+import ipaddress
 from agentscope.service.service_response import ServiceResponse
 from agentscope.service.service_status import ServiceExecStatus
 
@@ -36,6 +37,30 @@ def is_valid_url(url: str) -> bool:
     except ValueError:
         return False  # A ValueError indicates that the URL is not valid.
 
+def sanitize_url(url: str) -> bool:
+    """
+    Check whether a URL points to an internal IP address.
+    Args:
+        url (str): URL to be checked
+
+    Returns:
+        bool: True if the url does not point to an internal IP address,
+        False otherwise
+    """
+    parsed_url = urlparse(url)
+    hostname = parsed_url.hostname
+    # Resolve the hostname to an IP address
+    ip = socket.gethostbyname(hostname)
+    # Check if it's localhost or within the loopback range
+    if (ip.startswith("127.")
+        or ip == "::1"
+        or ipaddress.ip_address(ip).is_private):
+        logger.warning(f"Access to this URL {url} is "
+                       f"restricted because it is private")
+        return False
+
+    return True
+
 
 def load_web(
     url: str,
@@ -43,6 +68,7 @@ def load_web(
     html_selected_tags: Optional[Sequence[str]] = None,
     self_parse_func: Optional[Callable[[requests.Response], Any]] = None,
     timeout: int = 5,
+    exclude_internal_ips: bool = True,
 ) -> ServiceResponse:
     """Function for parsing and digesting the web page.
 
@@ -62,6 +88,8 @@ def load_web(
             The result is stored with `self_define_func` key
         timeout (int):
             timeout parameter for requests.
+        exclude_internal_ips (bool):
+            whether to prevent the function from accessing internal IPs
 
     Returns:
         `ServiceResponse`: If successful, `ServiceResponse` object is returned
@@ -87,6 +115,12 @@ def load_web(
             "selected_tags_text": xxxxx
         }
     """
+    if exclude_internal_ips and not sanitize_url(url):
+        return ServiceResponse(
+            ServiceExecStatus.ERROR,
+            content=f"Access to this URL {url} is restricted because it is private"
+        )
+
     header = {
         "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6",
         "Cache-Control": "max-age=0",
diff --git a/tests/web_digest_test.py b/tests/web_digest_test.py
index 3b62b0e36..2435d32d5 100644
--- a/tests/web_digest_test.py
+++ b/tests/web_digest_test.py
@@ -58,7 +58,7 @@ def test_web_load(self, mock_get: MagicMock) -> None:
         mock_get.return_value = mock_response
 
         # set parameters
-        fake_url = "fake-url"
+        fake_url = "http://fake-url.com"
 
         results = load_web(
             url=fake_url,
@@ -100,6 +100,11 @@ def format(
             expected_result,
         )
 
+    def test_block_internal_ips(self) -> None:
+        """test whether internal URLs can be blocked successfully"""
+        internal_url = "http://localhost:8080/some/path"
+        response = load_web(internal_url)
+        self.assertEqual(ServiceExecStatus.ERROR, response.status)
 
 # This allows the tests to be run from the command line
 if __name__ == "__main__":
     unittest.main()
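To illustrate what the new guard rejects, here is a standalone version of the check with hypothetical target URLs (running it requires DNS resolution; `ipaddress.ip_address(...).is_private` covers the RFC 1918 ranges as well as loopback):

```python
import ipaddress
import socket
from urllib.parse import urlparse


def points_to_private_ip(url: str) -> bool:
    """Standalone equivalent of the check introduced above."""
    hostname = urlparse(url).hostname
    if hostname is None:
        return False
    ip = socket.gethostbyname(hostname)  # resolves names such as "localhost"
    return ip.startswith("127.") or ipaddress.ip_address(ip).is_private


# Hypothetical targets: loopback and private-range URLs are flagged.
for url in (
    "http://localhost:8080/admin",
    "http://192.168.1.10/metrics",
    "https://example.com/",
):
    print(url, "->", "blocked" if points_to_private_ip(url) else "allowed")
```

Note that the hostname is resolved once, at check time; if DNS answers change between this check and the actual request, a gap remains, which is a standard limitation of resolve-then-fetch SSRF filters.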
-def sanitize_url(url: str) -> bool: + +def is_internal_ip_address(url: str) -> bool: """ Check if a URL is to interal IP addresses Args: @@ -49,17 +51,25 @@ def sanitize_url(url: str) -> bool: """ parsed_url = urlparse(url) hostname = parsed_url.hostname + if hostname is None: + # illegal hostname is ignore in this function + return False + # Resolve the hostname to an IP address ip = socket.gethostbyname(hostname) # Check if it's localhost or within the loopback range - if (ip.startswith("127.") - or ip == "::1" - or ipaddress.ip_address(ip).is_private): - logger.warning(f"Access to this URL {url} is " - f"restricted because it is private") - return False + if ( + ip.startswith("127.") + or ip == "::1" + or ipaddress.ip_address(ip).is_private + ): + logger.warning( + f"Access to this URL {url} is " + f"restricted because it is private", + ) + return True - return True + return False def load_web( @@ -115,10 +125,11 @@ def load_web( "selected_tags_text": xxxxx } """ - if exclude_internal_ips and not sanitize_url(url): + if exclude_internal_ips and is_internal_ip_address(url): return ServiceResponse( ServiceExecStatus.ERROR, - content=f"Access to this URL {url} is restricted because it is private" + content=f"Access to this URL {url} is restricted " + f"because it is private", ) header = { diff --git a/tests/web_digest_test.py b/tests/web_digest_test.py index 2435d32d5..d08d0e489 100644 --- a/tests/web_digest_test.py +++ b/tests/web_digest_test.py @@ -106,6 +106,7 @@ def test_block_internal_ips(self) -> None: response = load_web(internal_url) self.assertEqual(ServiceExecStatus.ERROR, response.status) + # This allows the tests to be run from the command line if __name__ == "__main__": unittest.main() From 84dc581edb990ff4292dbf20a02bbe2ccc2adde6 Mon Sep 17 00:00:00 2001 From: ZiTao-Li Date: Wed, 16 Oct 2024 21:01:23 -0700 Subject: [PATCH 9/9] minor fix on embedding related content --- src/agentscope/manager/_file.py | 6 ++---- src/agentscope/rag/llama_index_knowledge.py | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/agentscope/manager/_file.py b/src/agentscope/manager/_file.py index 8fe93b171..d49f4488b 100644 --- a/src/agentscope/manager/_file.py +++ b/src/agentscope/manager/_file.py @@ -34,13 +34,12 @@ def _get_text_embedding_record_hash( if isinstance(embedding_model, dict): # Format the dict to avoid duplicate keys embedding_model = json.dumps(embedding_model, sort_keys=True) - elif isinstance(embedding_model, str): - embedding_model_hash = _hash_string(embedding_model, hash_method) - else: + elif not isinstance(embedding_model, str): raise RuntimeError( f"The embedding model must be a string or a dict, got " f"{type(embedding_model)}.", ) + embedding_model_hash = _hash_string(embedding_model, hash_method) # Calculate the embedding id by hashing the hash codes of the # original data and the embedding model @@ -48,7 +47,6 @@ def _get_text_embedding_record_hash( original_data_hash + embedding_model_hash, hash_method, ) - return record_hash diff --git a/src/agentscope/rag/llama_index_knowledge.py b/src/agentscope/rag/llama_index_knowledge.py index 142f71068..b886825ff 100644 --- a/src/agentscope/rag/llama_index_knowledge.py +++ b/src/agentscope/rag/llama_index_knowledge.py @@ -203,8 +203,9 @@ def __init__( ) if persist_root is None: - persist_root = FileManager.get_instance().run_dir or "./" + persist_root = FileManager.get_instance().cache_dir or "./" self.persist_dir = os.path.join(persist_root, knowledge_id) + logger.info(f"** persist_dir: {self.persist_dir}") 
self.emb_model = emb_model self.overwrite_index = overwrite_index self.showprogress = showprogress
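For reference, the caching scheme that the `_file.py` cleanup above streamlines boils down to: serialize a dict-style model config with sorted keys, hash the model identifier and the data hash separately, then hash their concatenation so each (data, model) pair gets exactly one record. A standalone sketch, with `hashlib.sha256` standing in for AgentScope's `_hash_string` helper (which also takes a configurable hash method):

```python
import hashlib
import json
from typing import Union


def _hash(text: str) -> str:
    # Stand-in for agentscope's _hash_string helper.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def record_hash(original_data_hash: str, embedding_model: Union[str, dict]) -> str:
    """Mirror of _get_text_embedding_record_hash after this patch."""
    if isinstance(embedding_model, dict):
        # Sorted keys make logically equal configs hash identically.
        embedding_model = json.dumps(embedding_model, sort_keys=True)
    elif not isinstance(embedding_model, str):
        raise RuntimeError(
            f"The embedding model must be a string or a dict, "
            f"got {type(embedding_model)}.",
        )
    # One cache record per (original data, embedding model) pair.
    return _hash(original_data_hash + _hash(embedding_model))


print(record_hash(_hash("some text"), {"model_name": "gte-Qwen2-7B-instruct"}))
```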