Skip to content

Commit

Permalink
Merge branch 'refs/heads/master' into feature/streamlit
Browse files Browse the repository at this point in the history
# Conflicts:
#	apps/datascience_assistant/README.md
#	apps/datascience_assistant/app.py
#	resources/data_science_assistant_streamlit_1.png
  • Loading branch information
dahaipeng committed Aug 21, 2024
2 parents dd351d3 + 124a4ea commit e87b6ac
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 31 deletions.
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ Modelscope-Agent DOCUMENTATION
deployment/local_deploy.md


.. toctree::
:maxdepth: 2
:caption: Agents

agents/data_science_assistant.md




Indices and tables
==================
Expand Down
5 changes: 5 additions & 0 deletions docs/source_en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ Modelscope-Agent DOCUMENTATION
use_cases/openAPI_for_agent.md
deployment/local_deploy.md

.. toctree::
:maxdepth: 2
:caption: Agents

agents/data_science_assistant.md


Indices and tables
Expand Down
1 change: 1 addition & 0 deletions modelscope_agent/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def get_chat_model(model: str, model_server: str, **kwargs) -> BaseChatModel:
"""
model_type = re.split(r'[-/_]', model)[0] # parser qwen / gpt / ...
registered_model_id = f'{model_server}_{model_type}'

if registered_model_id in LLM_REGISTRY: # specific model from specific source
return LLM_REGISTRY[registered_model_id](model, model_server, **kwargs)
elif model_server in LLM_REGISTRY: # specific source
Expand Down
95 changes: 64 additions & 31 deletions modelscope_agent/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,55 @@
from modelscope_agent.llm.base import BaseChatModel, register_llm
from modelscope_agent.utils.logger import agent_logger as logger
from modelscope_agent.utils.retry import retry
from openai import OpenAI
from openai import AzureOpenAI, OpenAI


@register_llm('openai')
@register_llm('azure_openai')
class OpenAi(BaseChatModel):

def __init__(self,
model: str,
model_server: str,
is_chat: bool = True,
is_function_call: Optional[bool] = None,
support_stream: Optional[bool] = None,
**kwargs):
def __init__(
self,
model: str,
model_server: str,
is_chat: bool = True,
is_function_call: Optional[bool] = None,
support_stream: Optional[bool] = None,
**kwargs,
):
super().__init__(model, model_server, is_function_call)
default_api_base = os.getenv('OPENAI_API_BASE',
'https://api.openai.com/v1')
api_base = kwargs.get('api_base', default_api_base).strip()
api_key = kwargs.get('api_key',
os.getenv('OPENAI_API_KEY',
default='EMPTY')).strip()
logger.info(f'client url {api_base}, client key: {api_key}')
self.client = OpenAI(api_key=api_key, base_url=api_base)

self.is_azure = model_server.lower().startswith('azure')
if self.is_azure:
default_azure_endpoint = os.getenv(
'AZURE_OPENAI_ENDPOINT',
'https://docs-test-001.openai.azure.com/')
azure_endpoint = kwargs.get('azure_endpoint',
default_azure_endpoint).strip()
api_key = kwargs.get(
'api_key', os.getenv('AZURE_OPENAI_API_KEY',
default='EMPTY')).strip()
api_version = kwargs.get('api_version', '2024-06-01').strip()
logger.info(
f'client url {azure_endpoint}, client key: {api_key}, client version: {api_version}'
)

self.client = AzureOpenAI(
azure_endpoint=azure_endpoint,
api_key=api_key,
api_version=api_version,
)
else:
default_api_base = os.getenv('OPENAI_API_BASE',
'https://api.openai.com/v1')
api_base = kwargs.get('api_base', default_api_base).strip()
api_key = kwargs.get('api_key',
os.getenv('OPENAI_API_KEY',
default='EMPTY')).strip()
logger.info(f'client url {api_base}, client key: {api_key}')

self.client = OpenAI(api_key=api_key, base_url=api_base)

self.is_function_call = is_function_call
self.is_chat = is_chat
self.support_stream = support_stream
Expand All @@ -38,21 +65,24 @@ def _chat_stream(self,
logger.info(
f'call openai api, model: {self.model}, messages: {str(messages)}, '
f'stop: {str(stop)}, stream: True, args: {str(kwargs)}')
stream_options = {'include_usage': True}

if not self.is_azure:
kwargs['stream_options'] = {'include_usage': True}

response = self.client.chat.completions.create(
model=self.model,
messages=messages,
stop=stop,
stream=True,
stream_options=stream_options,
**kwargs)

response = self.stat_last_call_token_info_stream(response)
# TODO: error handling
for chunk in response:
# sometimes delta.content is None by vllm, we should not yield None
if len(chunk.choices) > 0 and hasattr(
chunk.choices[0].delta,
'content') and chunk.choices[0].delta.content:
if (len(chunk.choices) > 0
and hasattr(chunk.choices[0].delta, 'content')
and chunk.choices[0].delta.content):
logger.info(
f'call openai api success, output: {chunk.choices[0].delta.content}'
)
Expand Down Expand Up @@ -93,12 +123,14 @@ def support_raw_prompt(self) -> bool:
return not self.is_chat

@retry(max_retries=3, delay_seconds=0.5)
def chat(self,
prompt: Optional[str] = None,
messages: Optional[List[Dict]] = None,
stop: Optional[List[str]] = None,
stream: bool = False,
**kwargs) -> Union[str, Iterator[str]]:
def chat(
self,
prompt: Optional[str] = None,
messages: Optional[List[Dict]] = None,
stop: Optional[List[str]] = None,
stream: bool = False,
**kwargs,
) -> Union[str, Iterator[str]]:

if 'uuid_str' in kwargs:
kwargs.pop('uuid_str')
Expand Down Expand Up @@ -150,7 +182,8 @@ def chat_with_functions(self,
messages=messages,
tools=functions,
tool_choice='auto',
**kwargs)
**kwargs,
)
else:
response = self.client.chat.completions.create(
model=self.model, messages=messages, **kwargs)
Expand All @@ -175,9 +208,9 @@ def _chat_stream(self,
# TODO: error handling
for chunk in response:
# sometimes delta.content is None by vllm, we should not yield None
if len(chunk.choices) > 0 and hasattr(
chunk.choices[0].delta,
'content') and chunk.choices[0].delta.content:
if (len(chunk.choices) > 0
and hasattr(chunk.choices[0].delta, 'content')
and chunk.choices[0].delta.content):
logger.info(
f'call openai api success, output: {chunk.choices[0].delta.content}'
)
Expand Down

0 comments on commit e87b6ac

Please sign in to comment.