-
Notifications
You must be signed in to change notification settings - Fork 322
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: skyline2006 <[email protected]> Co-authored-by: Zhicheng Zhang <[email protected]>
- Loading branch information
1 parent
2f50c6b
commit 6bab5fe
Showing
40 changed files
with
1,741 additions
and
457 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
from typing import Any, Dict, List, Union | ||
|
||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex | ||
from llama_index.core.llama_pack.base import BaseLlamaPack | ||
from llama_index.core.readers.base import BaseReader | ||
|
||
|
||
class Knowledge(BaseLlamaPack): | ||
""" rag pipeline. | ||
从不同的源加载知识,支持:文件夹路径(str),文件路径列表(list),将不同源配置到不同的召回方式(dict). | ||
Automatically select the best file reader given file extensions. | ||
Args: | ||
knowledge_source: Path to the directory,或文件路径列表,或指定召回方式的文件路径。 | ||
cache_dir: 缓存indexing后的信息。 | ||
""" | ||
|
||
def __init__(self, | ||
knowledge_source: Union[List, str, Dict], | ||
cache_dir: str = './run', | ||
**kwargs) -> None: | ||
|
||
# extra_readers = self.get_extra_readers() | ||
self.documents = [] | ||
if isinstance(knowledge_source, str): | ||
if os.path.exists(knowledge_source): | ||
self.documents.append( | ||
SimpleDirectoryReader( | ||
input_dir=knowledge_source, | ||
recursive=True).load_data()) | ||
|
||
self.documents = SimpleDirectoryReader( | ||
input_files=knowledge_source).load_data() | ||
|
||
def get_extra_readers(self) -> Dict[str, BaseReader]: | ||
return {} | ||
|
||
def get_modules(self) -> Dict[str, Any]: | ||
"""Get modules for rewrite.""" | ||
return { | ||
'node_parser': self.node_parser, | ||
'recursive_retriever': self.recursive_retriever, | ||
'query_engines': self.query_engines, | ||
'reader': self.path_reader, | ||
} | ||
|
||
def run(self, query: str, **kwargs) -> str: | ||
return self.query_engine.query(query, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
from enum import Enum | ||
from http import HTTPStatus | ||
from typing import Any, List, Optional | ||
|
||
import dashscope | ||
from llama_index.legacy.bridge.pydantic import Field | ||
from llama_index.legacy.callbacks import CallbackManager | ||
from llama_index.legacy.core.embeddings.base import (DEFAULT_EMBED_BATCH_SIZE, | ||
BaseEmbedding) | ||
|
||
# Enums for validation and type safety | ||
DashscopeModelName = [ | ||
'text-embedding-v1', | ||
'text-embedding-v2', | ||
] | ||
|
||
|
||
# Assuming BaseEmbedding is a Pydantic model and handles its own initializations | ||
class DashscopeEmbedding(BaseEmbedding): | ||
"""DashscopeEmbedding uses the dashscope API to generate embeddings for text.""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str = 'text-embedding-v2', | ||
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, | ||
callback_manager: Optional[CallbackManager] = None, | ||
): | ||
""" | ||
A class representation for generating embeddings using the dashscope API. | ||
Args: | ||
model_name (str): The name of the model to be used for generating embeddings. The class ensures that | ||
this model is supported and that the input type provided is compatible with the model. | ||
""" | ||
|
||
assert os.environ.get( | ||
'DASHSCOPE_API_KEY', | ||
None), 'DASHSCOPE_API_KEY should be set in environ.' | ||
|
||
# Validate model_name and input_type | ||
if model_name not in DashscopeModelName: | ||
raise ValueError(f'model {model_name} is not supported.') | ||
|
||
super().__init__( | ||
model_name=model_name, | ||
embed_batch_size=embed_batch_size, | ||
callback_manager=callback_manager, | ||
) | ||
|
||
@classmethod | ||
def class_name(cls) -> str: | ||
return 'DashscopeEmbedding' | ||
|
||
def _embed(self, | ||
texts: List[str], | ||
text_type='document') -> List[List[float]]: | ||
"""Embed sentences using dashscope.""" | ||
resp = dashscope.TextEmbedding.call( | ||
input=texts, | ||
model=self.model_name, | ||
text_type=text_type, | ||
) | ||
if resp.status_code == HTTPStatus.OK: | ||
res = resp.output['embeddings'] | ||
else: | ||
raise ValueError(f'call dashscope api failed: {resp}') | ||
|
||
return [list(map(float, e['embedding'])) for e in res] | ||
|
||
def _get_query_embedding(self, query: str) -> List[float]: | ||
"""Get query embedding.""" | ||
return self._embed([query], text_type='query')[0] | ||
|
||
async def _aget_query_embedding(self, query: str) -> List[float]: | ||
"""Get query embedding async.""" | ||
return self._get_query_embedding(query) | ||
|
||
def _get_text_embedding(self, text: str) -> List[float]: | ||
"""Get text embedding.""" | ||
return self._embed([text], text_type='document')[0] | ||
|
||
async def _aget_text_embedding(self, text: str) -> List[float]: | ||
"""Get text embedding async.""" | ||
return self._get_text_embedding(text) | ||
|
||
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: | ||
"""Get text embeddings.""" | ||
return self._embed(texts, text_type='document') |
Oops, something went wrong.