diff --git a/modelscope_agent/rag/emb/__init__.py b/modelscope_agent/rag/emb/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/modelscope_agent/rag/emb/dashscope.py b/modelscope_agent/rag/emb/dashscope.py deleted file mode 100644 index 0bf36dd5..00000000 --- a/modelscope_agent/rag/emb/dashscope.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -from enum import Enum -from http import HTTPStatus -from typing import Any, List, Optional - -import dashscope -from llama_index.core.base.embeddings.base import (DEFAULT_EMBED_BATCH_SIZE, - BaseEmbedding) -from llama_index.core.callbacks import CallbackManager - -# Enums for validation and type safety -DashscopeModelName = [ - 'text-embedding-v1', - 'text-embedding-v2', -] - - -# Assuming BaseEmbedding is a Pydantic model and handles its own initializations -class DashscopeEmbedding(BaseEmbedding): - """DashscopeEmbedding uses the dashscope API to generate embeddings for text.""" - - def __init__( - self, - model_name: str = 'text-embedding-v2', - embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, - callback_manager: Optional[CallbackManager] = None, - ): - """ - A class representation for generating embeddings using the dashscope API. - - Args: - model_name (str): The name of the model to be used for generating embeddings. The class ensures that - this model is supported and that the input type provided is compatible with the model. - """ - - assert os.environ.get( - 'DASHSCOPE_API_KEY', - None), 'DASHSCOPE_API_KEY should be set in environ.' - - # Validate model_name and input_type - if model_name not in DashscopeModelName: - raise ValueError(f'model {model_name} is not supported.') - - super().__init__( - model_name=model_name, - embed_batch_size=embed_batch_size, - callback_manager=callback_manager, - ) - - @classmethod - def class_name(cls) -> str: - return 'DashscopeEmbedding' - - def _embed(self, - texts: List[str], - text_type='document') -> List[List[float]]: - """Embed sentences using dashscope.""" - resp = dashscope.TextEmbedding.call( - input=texts, - model=self.model_name, - text_type=text_type, - ) - if resp.status_code == HTTPStatus.OK: - res = resp.output['embeddings'] - else: - raise ValueError(f'call dashscope api failed: {resp}') - - return [list(map(float, e['embedding'])) for e in res] - - def _get_query_embedding(self, query: str) -> List[float]: - """Get query embedding.""" - return self._embed([query], text_type='query')[0] - - async def _aget_query_embedding(self, query: str) -> List[float]: - """Get query embedding async.""" - return self._get_query_embedding(query) - - def _get_text_embedding(self, text: str) -> List[float]: - """Get text embedding.""" - return self._embed([text], text_type='document')[0] - - async def _aget_text_embedding(self, text: str) -> List[float]: - """Get text embedding async.""" - return self._get_text_embedding(text) - - def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: - """Get text embeddings.""" - return self._embed(texts, text_type='document')