diff --git a/docs/source/index.rst b/docs/source/index.rst
index baff8638..d458e6c7 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -52,6 +52,14 @@ Modelscope-Agent DOCUMENTATION

    deployment/local_deploy.md

+.. toctree::
+   :maxdepth: 2
+   :caption: Agents
+
+   agents/data_science_assistant.md
+
+
+
 Indices and tables
 ==================
diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst
index 849089e6..36d2677f 100644
--- a/docs/source_en/index.rst
+++ b/docs/source_en/index.rst
@@ -50,6 +50,11 @@ Modelscope-Agent DOCUMENTATION
    use_cases/openAPI_for_agent.md
    deployment/local_deploy.md

+.. toctree::
+   :maxdepth: 2
+   :caption: Agents
+
+   agents/data_science_assistant.md

 Indices and tables
diff --git a/modelscope_agent/llm/__init__.py b/modelscope_agent/llm/__init__.py
index c06ed0af..2e036570 100644
--- a/modelscope_agent/llm/__init__.py
+++ b/modelscope_agent/llm/__init__.py
@@ -17,6 +17,7 @@ def get_chat_model(model: str, model_server: str, **kwargs) -> BaseChatModel:
     """
     model_type = re.split(r'[-/_]', model)[0]  # parser qwen / gpt / ...
     registered_model_id = f'{model_server}_{model_type}'
+
     if registered_model_id in LLM_REGISTRY:  # specific model from specific source
         return LLM_REGISTRY[registered_model_id](model, model_server, **kwargs)
     elif model_server in LLM_REGISTRY:  # specific source
diff --git a/modelscope_agent/llm/openai.py b/modelscope_agent/llm/openai.py
index 89f757e5..cfbacdff 100644
--- a/modelscope_agent/llm/openai.py
+++ b/modelscope_agent/llm/openai.py
@@ -4,28 +4,55 @@
 from modelscope_agent.llm.base import BaseChatModel, register_llm
 from modelscope_agent.utils.logger import agent_logger as logger
 from modelscope_agent.utils.retry import retry
-from openai import OpenAI
+from openai import AzureOpenAI, OpenAI


 @register_llm('openai')
+@register_llm('azure_openai')
 class OpenAi(BaseChatModel):

-    def __init__(self,
-                 model: str,
-                 model_server: str,
-                 is_chat: bool = True,
-                 is_function_call: Optional[bool] = None,
-                 support_stream: Optional[bool] = None,
-                 **kwargs):
+    def __init__(
+        self,
+        model: str,
+        model_server: str,
+        is_chat: bool = True,
+        is_function_call: Optional[bool] = None,
+        support_stream: Optional[bool] = None,
+        **kwargs,
+    ):
         super().__init__(model, model_server, is_function_call)
-        default_api_base = os.getenv('OPENAI_API_BASE',
-                                     'https://api.openai.com/v1')
-        api_base = kwargs.get('api_base', default_api_base).strip()
-        api_key = kwargs.get('api_key',
-                             os.getenv('OPENAI_API_KEY',
-                                       default='EMPTY')).strip()
-        logger.info(f'client url {api_base}, client key: {api_key}')
-        self.client = OpenAI(api_key=api_key, base_url=api_base)
+
+        self.is_azure = model_server.lower().startswith('azure')
+        if self.is_azure:
+            default_azure_endpoint = os.getenv(
+                'AZURE_OPENAI_ENDPOINT',
+                'https://docs-test-001.openai.azure.com/')
+            azure_endpoint = kwargs.get('azure_endpoint',
+                                        default_azure_endpoint).strip()
+            api_key = kwargs.get(
+                'api_key', os.getenv('AZURE_OPENAI_API_KEY',
+                                     default='EMPTY')).strip()
+            api_version = kwargs.get('api_version', '2024-06-01').strip()
+            logger.info(
+                f'client url {azure_endpoint}, client key: {api_key}, client version: {api_version}'
+            )
+
+            self.client = AzureOpenAI(
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+                api_version=api_version,
+            )
+        else:
+            default_api_base = os.getenv('OPENAI_API_BASE',
+                                         'https://api.openai.com/v1')
+            api_base = kwargs.get('api_base', default_api_base).strip()
+            api_key = kwargs.get('api_key',
+                                 os.getenv('OPENAI_API_KEY',
+                                           default='EMPTY')).strip()
+            logger.info(f'client url {api_base}, client key: {api_key}')
+
+            self.client = OpenAI(api_key=api_key, base_url=api_base)
+
         self.is_function_call = is_function_call
         self.is_chat = is_chat
         self.support_stream = support_stream
@@ -38,21 +65,24 @@ def _chat_stream(self,
         logger.info(
             f'call openai api, model: {self.model}, messages: {str(messages)}, '
             f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
-        stream_options = {'include_usage': True}
+
+        if not self.is_azure:
+            kwargs['stream_options'] = {'include_usage': True}
+
         response = self.client.chat.completions.create(
             model=self.model,
             messages=messages,
             stop=stop,
             stream=True,
-            stream_options=stream_options,
             **kwargs)
+
         response = self.stat_last_call_token_info_stream(response)
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
-            if len(chunk.choices) > 0 and hasattr(
-                    chunk.choices[0].delta,
-                    'content') and chunk.choices[0].delta.content:
+            if (len(chunk.choices) > 0
+                    and hasattr(chunk.choices[0].delta, 'content')
+                    and chunk.choices[0].delta.content):
                 logger.info(
                     f'call openai api success, output: {chunk.choices[0].delta.content}'
                 )
@@ -93,12 +123,14 @@
     def support_raw_prompt(self) -> bool:
         return not self.is_chat

     @retry(max_retries=3, delay_seconds=0.5)
-    def chat(self,
-             prompt: Optional[str] = None,
-             messages: Optional[List[Dict]] = None,
-             stop: Optional[List[str]] = None,
-             stream: bool = False,
-             **kwargs) -> Union[str, Iterator[str]]:
+    def chat(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[List[Dict]] = None,
+        stop: Optional[List[str]] = None,
+        stream: bool = False,
+        **kwargs,
+    ) -> Union[str, Iterator[str]]:
         if 'uuid_str' in kwargs:
             kwargs.pop('uuid_str')
@@ -150,7 +182,8 @@ def chat_with_functions(self,
                 messages=messages,
                 tools=functions,
                 tool_choice='auto',
-                **kwargs)
+                **kwargs,
+            )
         else:
             response = self.client.chat.completions.create(
                 model=self.model, messages=messages, **kwargs)
@@ -175,9 +208,9 @@ def _chat_stream(self,
         # TODO: error handling
         for chunk in response:
             # sometimes delta.content is None by vllm, we should not yield None
-            if len(chunk.choices) > 0 and hasattr(
-                    chunk.choices[0].delta,
-                    'content') and chunk.choices[0].delta.content:
+            if (len(chunk.choices) > 0
+                    and hasattr(chunk.choices[0].delta, 'content')
+                    and chunk.choices[0].delta.content):
                 logger.info(
                     f'call openai api success, output: {chunk.choices[0].delta.content}'
                 )
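
For reference, a minimal sketch of how the new registration is reached through get_chat_model. The deployment name, endpoint, and key below are hypothetical placeholders, not values from this patch:

    from modelscope_agent.llm import get_chat_model

    # re.split(r'[-/_]', 'gpt-4')[0] yields 'gpt', so the registry id is
    # 'azure_openai_gpt'; assuming no such model-specific entry exists, the
    # lookup falls through to the 'azure_openai' source entry added by the
    # new @register_llm('azure_openai') decorator.
    llm = get_chat_model(
        model='gpt-4',  # hypothetical Azure deployment name
        model_server='azure_openai',
        azure_endpoint='https://my-resource.openai.azure.com/',  # placeholder
        api_key='EMPTY',  # placeholder; avoid hardcoding real keys
        api_version='2024-06-01',
    )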
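The constructor also reads its defaults from the environment, so the same client can be configured without kwargs. A sketch under the same placeholder assumptions:

    import os

    os.environ['AZURE_OPENAI_ENDPOINT'] = 'https://my-resource.openai.azure.com/'  # placeholder
    os.environ['AZURE_OPENAI_API_KEY'] = 'EMPTY'  # placeholder

    from modelscope_agent.llm import get_chat_model

    # api_version falls back to '2024-06-01' when not passed explicitly.
    llm = get_chat_model('gpt-4', 'azure_openai')
    print(llm.chat(messages=[{'role': 'user', 'content': 'hello'}]))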
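Note that the streaming path now requests usage chunks (stream_options={'include_usage': True}) only from the non-Azure client; the is_azure guard assumes Azure deployments on this api_version reject that parameter. Consuming the stream looks the same for both backends, assuming chat(stream=True) returns the iterator produced by _chat_stream:

    # _chat_stream yields only non-empty delta.content strings, so the
    # chunks can be concatenated (or printed) directly.
    for chunk in llm.chat(
            messages=[{'role': 'user', 'content': 'hello'}], stream=True):
        print(chunk, end='', flush=True)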