Skip to content

Commit

Permalink
update langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
qingzhong1 committed Jan 23, 2024
1 parent 60d35e3 commit 7f4d1fd
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def build_index_llama(index_name, embeddings, path=None, url_path=None, abstract
)
index.storage_context.persist(persist_dir=index_name)
return index
elif origin_data:
nodes = [TextNode(text=item.page_content, metadata=item.metadata) for item in origin_data]
text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embeddings, text_splitter=text_splitter)
index = VectorStoreIndex(
nodes,
storage_context=storage_context,
show_progress=True,
service_context=service_context,
)
index.storage_context.persist(persist_dir=index_name)
return index


def get_retriver_by_type(frame_type):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Dict, List, Optional

from pydantic import Field

from erniebot_agent.tools.schema import ToolParameterView

from .base import Tool


# Input schema for the LangChain retrieval tool: the arguments a caller
# supplies when invoking the tool.
class LangChainRetrievalToolInputView(ToolParameterView):
    # Query text to search for (the description literal reads "query statement").
    query: str = Field(description="查询语句")
    # Number of results to return. No default is declared on this field.
    # NOTE(review): the tool's __call__ in this file defaults top_k to 3 --
    # confirm whether the schema should advertise the same default.
    top_k: int = Field(description="返回结果数量")


# Schema describing a single retrieved document in the tool output.
# NOTE(review): LangChainRetrievalTool.__call__ in this file builds result
# dicts keyed "content" / "score" / "meta", not "title" / "document" --
# confirm which shape is the intended contract.
class SearchResponseDocument(ToolParameterView):
    # Title of the retrieval result.
    title: str = Field(description="检索结果的标题")
    # Content of the retrieval result.
    document: str = Field(description="检索结果的内容")


# Output schema for the LangChain retrieval tool: a list of retrieved documents.
class LangChainRetrievalToolOutputView(ToolParameterView):
    # Retrieval results: passages whose content relates to the user's query.
    documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落")


# Retrieval tool that delegates to a store object exposing
# ``similarity_search_with_relevance_scores(query, k)`` and returns the hits
# whose relevance score is strictly above a configurable threshold.
class LangChainRetrievalTool(Tool):
    # Description literal reads: "retrieve passages from the knowledge base
    # that are relevant to the user's input query".
    description: str = "在知识库中检索与用户输入query相关的段落"

    def __init__(
        self,
        db,
        threshold: float = 0.0,
        input_type=None,
        output_type=None,
        return_meta_data: bool = True,
    ) -> None:
        """Store the search backend and the filtering/formatting options.

        Args:
            db: Object providing ``similarity_search_with_relevance_scores``.
            threshold: Hits must score strictly greater than this to be kept.
            input_type: Optional override for the tool's input schema.
            output_type: Optional override for the tool's output schema.
            return_meta_data: When true, each hit also carries its metadata.
        """
        super().__init__()
        self.db = db
        self.threshold = threshold
        self.return_meta_data = return_meta_data
        # Overrides are applied only when given, so any value already present
        # on the instance or class is otherwise left untouched.
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            # NOTE(review): the attribute is spelled ``ouptut_type``; it
            # presumably mirrors the name used by the base ``Tool`` class --
            # confirm against the base class before renaming.
            self.ouptut_type = output_type

    async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
        """Search ``self.db`` for ``query`` and return the hits above the threshold.

        Args:
            query: Text to search for.
            top_k: Number of candidates requested from the store.
            filters: Accepted for interface compatibility; not used here.

        Returns:
            ``{"documents": [...]}`` where each entry has ``"content"`` and
            ``"score"``, plus ``"meta"`` when ``return_meta_data`` is true.
        """
        scored_hits = self.db.similarity_search_with_relevance_scores(query, top_k)

        kept = []
        for hit, relevance in scored_hits:
            # Strict comparison: a score equal to the threshold is dropped.
            if not relevance > self.threshold:
                continue
            entry = {"content": hit.page_content, "score": relevance}
            if self.return_meta_data:
                entry["meta"] = hit.metadata
            kept.append(entry)

        return {"documents": kept}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Any, Dict, List, Optional

from pydantic import Field

from erniebot_agent.tools.schema import ToolParameterView

from .base import Tool


# Input schema for the LlamaIndex retrieval tool: the arguments a caller
# supplies when invoking the tool.
class LlamaIndexRetrievalToolInputView(ToolParameterView):
    # Query text to search for (the description literal reads "query statement").
    query: str = Field(description="查询语句")
    # Number of results to return. No default is declared on this field.
    # NOTE(review): the tool's __call__ in this file defaults top_k to 3 --
    # confirm whether the schema should advertise the same default.
    top_k: int = Field(description="返回结果数量")


# Schema describing a single retrieved document in the tool output.
# NOTE(review): LlamaIndexRetrievalTool.__call__ in this file builds result
# dicts keyed "content" / "score" / "meta", not "title" / "document" --
# confirm which shape is the intended contract.
class SearchResponseDocument(ToolParameterView):
    # Title of the retrieval result.
    title: str = Field(description="检索结果的标题")
    # Content of the retrieval result.
    document: str = Field(description="检索结果的内容")


# Output schema for the LlamaIndex retrieval tool: a list of retrieved documents.
class LlamaIndexRetrievalToolOutputView(ToolParameterView):
    # Retrieval results: passages whose content relates to the user's query.
    documents: List[SearchResponseDocument] = Field(description="检索结果,内容和用户输入query相关的段落")


# Retrieval tool that builds a retriever from an index-like object via
# ``as_retriever(similarity_top_k=...)`` and returns the retrieved nodes whose
# score is strictly above a configurable threshold.
class LlamaIndexRetrievalTool(Tool):
    # Description literal reads: "retrieve passages from the knowledge base
    # that are relevant to the user's input query".
    description: str = "在知识库中检索与用户输入query相关的段落"

    def __init__(
        self,
        db,
        embed_model: Optional[str] = None,
        threshold: float = 0.0,
        input_type=None,
        output_type=None,
        return_meta_data: bool = True,
    ) -> None:
        """Store the index and the filtering/formatting options.

        Args:
            db: Object providing ``as_retriever(similarity_top_k=...)``.
            embed_model: Stored on the instance; not read elsewhere in this class.
            threshold: Hits must score strictly greater than this to be kept.
            input_type: Optional override for the tool's input schema.
            output_type: Optional override for the tool's output schema.
            return_meta_data: When true, each hit also carries its metadata.
        """
        super().__init__()
        self.db = db
        self.embed_model = embed_model
        self.threshold = threshold
        self.return_meta_data = return_meta_data
        # Overrides are applied only when given, so any value already present
        # on the instance or class is otherwise left untouched.
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            # NOTE(review): the attribute is spelled ``ouptut_type``; it
            # presumably mirrors the name used by the base ``Tool`` class --
            # confirm against the base class before renaming.
            self.ouptut_type = output_type

    async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
        """Retrieve nodes for ``query`` and return those above the threshold.

        Args:
            query: Text to search for.
            top_k: Passed to the retriever as ``similarity_top_k``.
            filters: Accepted for interface compatibility; not used here.

        Returns:
            ``{"documents": [...]}`` where each entry has ``"content"`` and
            ``"score"``, plus ``"meta"`` when ``return_meta_data`` is true.
        """
        hits = self.db.as_retriever(similarity_top_k=top_k).retrieve(query)

        kept = []
        for hit in hits:
            # Strict comparison: a score equal to the threshold is dropped.
            if not hit.score > self.threshold:
                continue
            entry = {"content": hit.node.text, "score": hit.score}
            if self.return_meta_data:
                entry["meta"] = hit.metadata
            kept.append(entry)

        return {"documents": kept}

0 comments on commit 7f4d1fd

Please sign in to comment.