diff --git a/setup.py b/setup.py index 75c1077d3..e72e5ea6b 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ "scipy", # Leaving openai and dashscope here as default supports "openai>=1.3.0", - "dashscope==1.14.1", + "dashscope>=1.19.0", ] extra_service_requires = [ diff --git a/src/agentscope/models/dashscope_model.py b/src/agentscope/models/dashscope_model.py index a3ac23613..15f69aa1b 100644 --- a/src/agentscope/models/dashscope_model.py +++ b/src/agentscope/models/dashscope_model.py @@ -1,968 +1,976 @@ -# -*- coding: utf-8 -*- -"""Model wrapper for DashScope models""" -import os -from abc import ABC -from http import HTTPStatus -from typing import Any, Union, List, Sequence, Optional, Generator - -from loguru import logger - -from ..manager import FileManager -from ..message import Msg -from ..utils.common import _convert_to_str, _guess_type_by_extension - -try: - import dashscope - from dashscope.api_entities.dashscope_response import GenerationResponse -except ImportError: - dashscope = None - GenerationResponse = None - -from .model import ModelWrapperBase, ModelResponse - - -class DashScopeWrapperBase(ModelWrapperBase, ABC): - """The model wrapper for DashScope API.""" - - def __init__( - self, - config_name: str, - model_name: str = None, - api_key: str = None, - generate_args: dict = None, - **kwargs: Any, - ) -> None: - """Initialize the DashScope wrapper. - - Args: - config_name (`str`): - The name of the model config. - model_name (`str`, default `None`): - The name of the model to use in DashScope API. - api_key (`str`, default `None`): - The API key for DashScope API. - generate_args (`dict`, default `None`): - The extra keyword arguments used in DashScope api generation, - e.g. `temperature`, `seed`. - """ - if model_name is None: - model_name = config_name - logger.warning("model_name is not set, use config_name instead.") - - super().__init__(config_name=config_name, model_name=model_name) - - if dashscope is None: - raise ImportError( - "The package 'dashscope' is not installed. Please install it " - "by running `pip install dashscope==1.14.1`", - ) - - self.generate_args = generate_args or {} - - self.api_key = api_key - if self.api_key: - dashscope.api_key = self.api_key - self.max_length = None - - def format( - self, - *args: Union[Msg, Sequence[Msg]], - ) -> Union[List[dict], str]: - raise RuntimeError( - f"Model Wrapper [{type(self).__name__}] doesn't " - f"need to format the input. Please try to use the " - f"model wrapper directly.", - ) - - -class DashScopeChatWrapper(DashScopeWrapperBase): - """The model wrapper for DashScope's chat API, refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details - - Response: - - Refer to - https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw - - ```json - { - "status_code": 200, - "request_id": "a75a1b22-e512-957d-891b-37db858ae738", - "code": "", - "message": "", - "output": { - "text": null, - "finish_reason": null, - "choices": [ - { - "finish_reason": "stop", - "message": { - "role": "assistant", - "content": "xxx" - } - } - ] - }, - "usage": { - "input_tokens": 25, - "output_tokens": 77, - "total_tokens": 102 - } - } - ``` - """ - - model_type: str = "dashscope_chat" - - deprecated_model_type: str = "tongyi_chat" - - def __init__( - self, - config_name: str, - model_name: str = None, - api_key: str = None, - stream: bool = False, - generate_args: dict = None, - **kwargs: Any, - ) -> None: - """Initialize the DashScope wrapper. 
- - Args: - config_name (`str`): - The name of the model config. - model_name (`str`, default `None`): - The name of the model to use in DashScope API. - api_key (`str`, default `None`): - The API key for DashScope API. - stream (`bool`, default `False`): - If True, the response will be a generator in the `stream` - field of the returned `ModelResponse` object. - generate_args (`dict`, default `None`): - The extra keyword arguments used in DashScope api generation, - e.g. `temperature`, `seed`. - """ - - super().__init__( - config_name=config_name, - model_name=model_name, - api_key=api_key, - generate_args=generate_args, - **kwargs, - ) - - self.stream = stream - - def __call__( - self, - messages: list, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ModelResponse: - """Processes a list of messages to construct a payload for the - DashScope API call. It then makes a request to the DashScope API - and returns the response. This method also updates monitoring - metrics based on the API response. - - Each message in the 'messages' list can contain text content and - optionally an 'image_urls' key. If 'image_urls' is provided, - it is expected to be a list of strings representing URLs to images. - These URLs will be transformed to a suitable format for the DashScope - API, which might involve converting local file paths to data URIs. - - Args: - messages (`list`): - A list of messages to process. - stream (`Optional[bool]`, default `None`): - The stream flag to control the response format, which will - overwrite the stream flag in the constructor. - **kwargs (`Any`): - The keyword arguments to DashScope chat completions API, - e.g. `temperature`, `max_tokens`, `top_p`, etc. Please - refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details - for more detailed arguments. - - Returns: - `ModelResponse`: - A response object with the response text in text field, and - the raw response in raw field. If stream is True, the response - will be a generator in the `stream` field. - - Note: - `parse_func`, `fault_handler` and `max_retries` are reserved for - `_response_parse_decorator` to parse and check the response - generated by model wrapper. Their usages are listed as follows: - - `parse_func` is a callable function used to parse and check - the response generated by the model, which takes the response - as input. - - `max_retries` is the maximum number of retries when the - `parse_func` raise an exception. - - `fault_handler` is a callable function which is called - when the response generated by the model is invalid after - `max_retries` retries. - The rule of roles in messages for DashScope is very rigid, - for more details, please refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details - """ - - # step1: prepare keyword arguments - kwargs = {**self.generate_args, **kwargs} - - # step2: checking messages - if not isinstance(messages, list): - raise ValueError( - "Dashscope `messages` field expected type `list`, " - f"got `{type(messages)}` instead.", - ) - if not all("role" in msg and "content" in msg for msg in messages): - raise ValueError( - "Each message in the 'messages' list must contain a 'role' " - "and 'content' key for DashScope API.", - ) - - # step3: forward to generate response - if stream is None: - stream = self.stream - - kwargs.update( - { - "model": self.model_name, - "messages": messages, - # Set the result to be "message" format. 
- "result_format": "message", - "stream": stream, - }, - ) - - # Switch to the incremental_output mode - if stream: - kwargs["incremental_output"] = True - - response = dashscope.Generation.call(**kwargs) - - # step3: invoke llm api, record the invocation and update the monitor - if stream: - - def generator() -> Generator[str, None, None]: - last_chunk = None - text = "" - for chunk in response: - if chunk.status_code != HTTPStatus.OK: - error_msg = ( - f"Request id: {chunk.request_id}\n" - f"Status code: {chunk.status_code}\n" - f"Error code: {chunk.code}\n" - f"Error message: {chunk.message}" - ) - raise RuntimeError(error_msg) - - text += chunk.output["choices"][0]["message"]["content"] - yield text - last_chunk = chunk - - # Replace the last chunk with the full text - last_chunk.output["choices"][0]["message"]["content"] = text - - # Save the model invocation and update the monitor - self._save_model_invocation_and_update_monitor( - kwargs, - last_chunk, - ) - - return ModelResponse( - stream=generator(), - raw=response, - ) - - else: - if response.status_code != HTTPStatus.OK: - error_msg = ( - f"Request id: {response.request_id},\n" - f"Status code: {response.status_code},\n" - f"Error code: {response.code},\n" - f"Error message: {response.message}." - ) - - raise RuntimeError(error_msg) - - # Record the model invocation and update the monitor - self._save_model_invocation_and_update_monitor( - kwargs, - response, - ) - - return ModelResponse( - text=response.output["choices"][0]["message"]["content"], - raw=response, - ) - - def _save_model_invocation_and_update_monitor( - self, - kwargs: dict, - response: GenerationResponse, - ) -> None: - """Save the model invocation and update the monitor accordingly. - - Args: - kwargs (`dict`): - The keyword arguments to the DashScope chat API. - response (`GenerationResponse`): - The response object returned by the DashScope chat API. - """ - input_tokens = response.usage.get("input_tokens", 0) - output_tokens = response.usage.get("output_tokens", 0) - - # Update the token record accordingly - self.monitor.update_text_and_embedding_tokens( - model_name=self.model_name, - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - ) - - # Save the model invocation after the stream is exhausted - self._save_model_invocation( - arguments=kwargs, - response=response, - ) - - def format( - self, - *args: Union[Msg, Sequence[Msg]], - ) -> List[dict]: - """A common format strategy for chat models, which will format the - input messages into a user message. - - Note this strategy maybe not suitable for all scenarios, - and developers are encouraged to implement their own prompt - engineering strategies. - - The following is an example: - - .. code-block:: python - - prompt1 = model.format( - Msg("system", "You're a helpful assistant", role="system"), - Msg("Bob", "Hi, how can I help you?", role="assistant"), - Msg("user", "What's the date today?", role="user") - ) - - prompt2 = model.format( - Msg("Bob", "Hi, how can I help you?", role="assistant"), - Msg("user", "What's the date today?", role="user") - ) - - The prompt will be as follows: - - .. code-block:: python - - # prompt1 - [ - { - "role": "system", - "content": "You're a helpful assistant" - }, - { - "role": "user", - "content": ( - "## Conversation History\\n" - "Bob: Hi, how can I help you?\\n" - "user: What's the date today?" 
- ) - } - ] - - # prompt2 - [ - { - "role": "user", - "content": ( - "## Conversation History\\n" - "Bob: Hi, how can I help you?\\n" - "user: What's the date today?" - ) - } - ] - - - Args: - args (`Union[Msg, Sequence[Msg]]`): - The input arguments to be formatted, where each argument - should be a `Msg` object, or a list of `Msg` objects. - In distribution, placeholder is also allowed. - - Returns: - `List[dict]`: - The formatted messages. - """ - - return ModelWrapperBase.format_for_common_chat_models(*args) - - -class DashScopeImageSynthesisWrapper(DashScopeWrapperBase): - """The model wrapper for DashScope Image Synthesis API, refer to - https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-1 - - Response: - - Refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9?spm=a2c4g.11186623.0.0.7108fa70Op6eqF - - ```json - { - "status_code": 200, - "request_id": "b54ffeb8-6212-9dac-808c-b3771cba3788", - "code": null, - "message": "", - "output": { - "task_id": "996523eb-034d-459b-ac88-b340b95007a4", - "task_status": "SUCCEEDED", - "results": [ - { - "url": "RESULT_URL1" - }, - { - "url": "RESULT_URL2" - }, - ], - "task_metrics": { - "TOTAL": 2, - "SUCCEEDED": 2, - "FAILED": 0 - } - }, - "usage": { - "image_count": 2 - } - } - ``` - """ - - model_type: str = "dashscope_image_synthesis" - - def __call__( - self, - prompt: str, - save_local: bool = False, - **kwargs: Any, - ) -> ModelResponse: - """ - Args: - prompt (`str`): - The prompt string to generate images from. - save_local: (`bool`, default `False`): - Whether to save the generated images locally, and replace - the returned image url with the local path. - **kwargs (`Any`): - The keyword arguments to DashScope Image Synthesis API, - e.g. `n`, `size`, etc. Please refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9 - for more detailed arguments. - - Returns: - `ModelResponse`: - A list of image urls in image_urls field and the - raw response in raw field. - - Note: - `parse_func`, `fault_handler` and `max_retries` are reserved - for `_response_parse_decorator` to parse and check the - response generated by model wrapper. Their usages are listed - as follows: - - `parse_func` is a callable function used to parse and - check the response generated by the model, which takes - the response as input. - - `max_retries` is the maximum number of retries when the - `parse_func` raise an exception. - - `fault_handler` is a callable function which is called - when the response generated by the model is invalid after - `max_retries` retries. - """ - # step1: prepare keyword arguments - kwargs = {**self.generate_args, **kwargs} - - # step2: forward to generate response - response = dashscope.ImageSynthesis.call( - model=self.model_name, - prompt=prompt, - **kwargs, - ) - if response.status_code != HTTPStatus.OK: - error_msg = ( - f" Request id: {response.request_id}," - f" Status code: {response.status_code}," - f" error code: {response.code}," - f" error message: {response.message}." 
- ) - raise RuntimeError(error_msg) - - # step3: record the model api invocation if needed - self._save_model_invocation( - arguments={ - "model": self.model_name, - "prompt": prompt, - **kwargs, - }, - response=response, - ) - - # step4: update monitor accordingly - self.monitor.update_image_tokens( - model_name=self.model_name, - image_count=response.usage.image_count, - resolution=kwargs.get("size", "1024*1024"), - ) - - # step5: return response - images = response.output["results"] - # Get image urls as a list - urls = [_["url"] for _ in images] - - if save_local: - file_manager = FileManager.get_instance() - # Return local url if save_local is True - urls = [file_manager.save_image(_) for _ in urls] - return ModelResponse(image_urls=urls, raw=response) - - -class DashScopeTextEmbeddingWrapper(DashScopeWrapperBase): - """The model wrapper for DashScope Text Embedding API. - - Response: - - Refer to - https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3 - - ```json - { - "status_code": 200, // 200 indicate success otherwise failed. - "request_id": "fd564688-43f7-9595-b986", // The request id. - "code": "", // If failed, the error code. - "message": "", // If failed, the error message. - "output": { - "embeddings": [ // embeddings - { - "embedding": [ // one embedding output - -3.8450357913970947, ..., - ], - "text_index": 0 // the input index. - } - ] - }, - "usage": { - "total_tokens": 3 // the request tokens. - } - } - ``` - """ - - model_type: str = "dashscope_text_embedding" - - def __call__( - self, - texts: Union[list[str], str], - **kwargs: Any, - ) -> ModelResponse: - """Embed the messages with DashScope Text Embedding API. - - Args: - texts (`list[str]` or `str`): - The messages used to embed. - **kwargs (`Any`): - The keyword arguments to DashScope Text Embedding API, - e.g. `text_type`. Please refer to - https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15 - for more detailed arguments. - - Returns: - `ModelResponse`: - A list of embeddings in embedding field and the raw - response in raw field. - - Note: - `parse_func`, `fault_handler` and `max_retries` are reserved - for `_response_parse_decorator` to parse and check the response - generated by model wrapper. Their usages are listed as follows: - - `parse_func` is a callable function used to parse and - check the response generated by the model, which takes the - response as input. - - `max_retries` is the maximum number of retries when the - `parse_func` raise an exception. - - `fault_handler` is a callable function which is called - when the response generated by the model is invalid after - `max_retries` retries. - """ - # step1: prepare keyword arguments - kwargs = {**self.generate_args, **kwargs} - - # step2: forward to generate response - response = dashscope.TextEmbedding.call( - input=texts, - model=self.model_name, - **kwargs, - ) - - if response.status_code != HTTPStatus.OK: - error_msg = ( - f" Request id: {response.request_id}," - f" Status code: {response.status_code}," - f" error code: {response.code}," - f" error message: {response.message}." 
- ) - raise RuntimeError(error_msg) - - # step3: record the model api invocation if needed - self._save_model_invocation( - arguments={ - "model": self.model_name, - "input": texts, - **kwargs, - }, - response=response, - ) - - # step4: update monitor accordingly - self.monitor.update_text_and_embedding_tokens( - model_name=self.model_name, - prompt_tokens=response.usage.get("total_tokens"), - total_tokens=response.usage.get("total_tokens"), - ) - - # step5: return response - return ModelResponse( - embedding=[_["embedding"] for _ in response.output["embeddings"]], - raw=response, - ) - - -class DashScopeMultiModalWrapper(DashScopeWrapperBase): - """The model wrapper for DashScope Multimodal API, refer to - https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-api - - Response: - - Refer to - https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api?spm=a2c4g.11186623.0.0.7fde1f5atQSalN - - ```json - { - "status_code": 200, - "request_id": "a0dc436c-2ee7-93e0-9667-c462009dec4d", - "code": "", - "message": "", - "output": { - "text": null, - "finish_reason": null, - "choices": [ - { - "finish_reason": "stop", - "message": { - "role": "assistant", - "content": [ - { - "text": "这张图片显..." - } - ] - } - } - ] - }, - "usage": { - "input_tokens": 1277, - "output_tokens": 81, - "image_tokens": 1247 - } - } - ``` - """ - - model_type: str = "dashscope_multimodal" - - def __call__( - self, - messages: list, - **kwargs: Any, - ) -> ModelResponse: - """Model call for DashScope MultiModal API. - - Args: - messages (`list`): - A list of messages to process. - **kwargs (`Any`): - The keyword arguments to DashScope MultiModal API, - e.g. `stream`. Please refer to - https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api - for more detailed arguments. - - Returns: - `ModelResponse`: - The response text in text field, and the raw response in - raw field. - - Note: - If involving image links, then the messages should be of the - following form: - - .. code-block:: python - - messages = [ - { - "role": "system", - "content": [ - {"text": "You are a helpful assistant."}, - ], - }, - { - "role": "user", - "content": [ - {"text": "What does this picture depict?"}, - {"image": "http://example.com/image.jpg"}, - ], - }, - ] - - Therefore, you should input a list matching the content value - above. - If only involving words, just input them. - """ - # step1: prepare keyword arguments - kwargs = {**self.generate_args, **kwargs} - - # step2: forward to generate response - response = dashscope.MultiModalConversation.call( - model=self.model_name, - messages=messages, - **kwargs, - ) - # Unhandled code path here - # response could be a generator , if stream is yes - # suggest add a check here - if response.status_code != HTTPStatus.OK: - error_msg = ( - f" Request id: {response.request_id}," - f" Status code: {response.status_code}," - f" error code: {response.code}," - f" error message: {response.message}." 
- ) - raise RuntimeError(error_msg) - - # step3: record the model api invocation if needed - self._save_model_invocation( - arguments={ - "model": self.model_name, - "messages": messages, - **kwargs, - }, - response=response, - ) - - # step4: update monitor accordingly - input_tokens = response.usage.get("input_tokens", 0) - image_tokens = response.usage.get("image_tokens", 0) - output_tokens = response.usage.get("output_tokens", 0) - # TODO: update the tokens - self.monitor.update_text_and_embedding_tokens( - model_name=self.model_name, - prompt_tokens=input_tokens, - completion_tokens=output_tokens + image_tokens, - ) - - # step5: return response - content = response.output["choices"][0]["message"]["content"] - if isinstance(content, list): - content = content[0]["text"] - - return ModelResponse( - text=content, - raw=response, - ) - - def format( - self, - *args: Union[Msg, Sequence[Msg]], - ) -> List: - """Format the messages for DashScope Multimodal API. - - The multimodal API has the following requirements: - - - The roles of messages must alternate between "user" and - "assistant". - - The message with the role "system" should be the first message - in the list. - - If the system message exists, then the second message must - have the role "user". - - The last message in the list should have the role "user". - - In each message, more than one figure is allowed. - - With the above requirements, we format the messages as follows: - - - If the first message is a system message, then we will keep it as - system prompt. - - We merge all messages into a conversation history prompt in a - single message with the role "user". - - When there are multiple figures in the given messages, we will - attach it to the user message by order. Note if there are - multiple figures, this strategy may cause misunderstanding for - the model. For advanced solutions, developers are encouraged to - implement their own prompt engineering strategies. - - The following is an example: - - .. code-block:: python - - prompt = model.format( - Msg( - "system", - "You're a helpful assistant", - role="system", url="figure1" - ), - Msg( - "Bob", - "How about this picture?", - role="assistant", url="figure2" - ), - Msg( - "user", - "It's wonderful! How about mine?", - role="user", image="figure3" - ) - ) - - The prompt will be as follows: - - .. code-block:: python - - [ - { - "role": "system", - "content": [ - {"text": "You are a helpful assistant"}, - {"image": "figure1"} - ] - }, - { - "role": "user", - "content": [ - {"image": "figure2"}, - {"image": "figure3"}, - { - "text": ( - "## Conversation History\\n" - "Bob: How about this picture?\\n" - "user: It's wonderful! How about mine?" - ) - }, - ] - } - ] - - Note: - In multimodal API, the url of local files should be prefixed with - "file://", which will be attached in this format function. - - Args: - args (`Union[Msg, Sequence[Msg]]`): - The input arguments to be formatted, where each argument - should be a `Msg` object, or a list of `Msg` objects. - In distribution, placeholder is also allowed. - - Returns: - `List[dict]`: - The formatted messages. 
- """ - - # Parse all information into a list of messages - input_msgs = [] - for _ in args: - if _ is None: - continue - if isinstance(_, Msg): - input_msgs.append(_) - elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): - input_msgs.extend(_) - else: - raise TypeError( - f"The input should be a Msg object or a list " - f"of Msg objects, got {type(_)}.", - ) - - messages = [] - - # record dialog history as a list of strings - dialogue = [] - image_or_audio_dicts = [] - for i, unit in enumerate(input_msgs): - if i == 0 and unit.role == "system": - # system prompt - content = self.convert_url(unit.url) - content.append({"text": _convert_to_str(unit.content)}) - - messages.append( - { - "role": unit.role, - "content": content, - }, - ) - else: - # text message - dialogue.append( - f"{unit.name}: {_convert_to_str(unit.content)}", - ) - # image and audio - image_or_audio_dicts.extend(self.convert_url(unit.url)) - - dialogue_history = "\n".join(dialogue) - - user_content_template = "## Conversation History\n{dialogue_history}" - - messages.append( - { - "role": "user", - "content": [ - # Place the image or audio before the conversation history - *image_or_audio_dicts, - { - "text": user_content_template.format( - dialogue_history=dialogue_history, - ), - }, - ], - }, - ) - - return messages - - def convert_url(self, url: Union[str, Sequence[str], None]) -> List[dict]: - """Convert the url to the format of DashScope API. Note for local - files, a prefix "file://" will be added. - - Args: - url (`Union[str, Sequence[str], None]`): - A string of url of a list of urls to be converted. - - Returns: - `List[dict]`: - A list of dictionaries with key as the type of the url - and value as the url. Only "image" and "audio" are supported. - """ - if url is None: - return [] - - if isinstance(url, str): - url_type = _guess_type_by_extension(url) - if url_type in ["audio", "image"]: - # Add prefix for local files - if os.path.exists(url): - url = "file://" + url - return [{url_type: url}] - else: - # skip unsupported url - logger.warning( - f"Skip unsupported url ({url_type}), " - f"expect image or audio.", - ) - return [] - elif isinstance(url, list): - dicts = [] - for _ in url: - dicts.extend(self.convert_url(_)) - return dicts - else: - raise TypeError( - f"Unsupported url type {type(url)}, " f"str or list expected.", - ) +# -*- coding: utf-8 -*- +"""Model wrapper for DashScope models""" +import os +from abc import ABC +from http import HTTPStatus +from typing import Any, Union, List, Sequence, Optional, Generator + +from loguru import logger + +from ..manager import FileManager +from ..message import Msg +from ..utils.common import _convert_to_str, _guess_type_by_extension + +try: + import dashscope + + dashscope_version = dashscope.version.__version__ + if dashscope_version < "1.19.0": + logger.warning( + f"You are using 'dashscope' version {dashscope_version}, " + "which is below the recommended version 1.19.0. " + "Please consider upgrading to maintain compatibility.", + ) + from dashscope.api_entities.dashscope_response import GenerationResponse +except ImportError: + dashscope = None + GenerationResponse = None + +from .model import ModelWrapperBase, ModelResponse + + +class DashScopeWrapperBase(ModelWrapperBase, ABC): + """The model wrapper for DashScope API.""" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the DashScope wrapper. 
+ + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in DashScope API. + api_key (`str`, default `None`): + The API key for DashScope API. + generate_args (`dict`, default `None`): + The extra keyword arguments used in DashScope api generation, + e.g. `temperature`, `seed`. + """ + if model_name is None: + model_name = config_name + logger.warning("model_name is not set, use config_name instead.") + + super().__init__(config_name=config_name, model_name=model_name) + + if dashscope is None: + raise ImportError( + "The package 'dashscope' is not installed. Please install it " + "by running `pip install dashscope>=1.19.0`", + ) + + self.generate_args = generate_args or {} + + self.api_key = api_key + if self.api_key: + dashscope.api_key = self.api_key + self.max_length = None + + def format( + self, + *args: Union[Msg, Sequence[Msg]], + ) -> Union[List[dict], str]: + raise RuntimeError( + f"Model Wrapper [{type(self).__name__}] doesn't " + f"need to format the input. Please try to use the " + f"model wrapper directly.", + ) + + +class DashScopeChatWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope's chat API, refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/quick-start?spm=a2c4g.11186623.0.0.7e346eb5RvirBw + + ```json + { + "status_code": 200, + "request_id": "a75a1b22-e512-957d-891b-37db858ae738", + "code": "", + "message": "", + "output": { + "text": null, + "finish_reason": null, + "choices": [ + { + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "xxx" + } + } + ] + }, + "usage": { + "input_tokens": 25, + "output_tokens": 77, + "total_tokens": 102 + } + } + ``` + """ + + model_type: str = "dashscope_chat" + + deprecated_model_type: str = "tongyi_chat" + + def __init__( + self, + config_name: str, + model_name: str = None, + api_key: str = None, + stream: bool = False, + generate_args: dict = None, + **kwargs: Any, + ) -> None: + """Initialize the DashScope wrapper. + + Args: + config_name (`str`): + The name of the model config. + model_name (`str`, default `None`): + The name of the model to use in DashScope API. + api_key (`str`, default `None`): + The API key for DashScope API. + stream (`bool`, default `False`): + If True, the response will be a generator in the `stream` + field of the returned `ModelResponse` object. + generate_args (`dict`, default `None`): + The extra keyword arguments used in DashScope api generation, + e.g. `temperature`, `seed`. + """ + + super().__init__( + config_name=config_name, + model_name=model_name, + api_key=api_key, + generate_args=generate_args, + **kwargs, + ) + + self.stream = stream + + def __call__( + self, + messages: list, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ModelResponse: + """Processes a list of messages to construct a payload for the + DashScope API call. It then makes a request to the DashScope API + and returns the response. This method also updates monitoring + metrics based on the API response. + + Each message in the 'messages' list can contain text content and + optionally an 'image_urls' key. If 'image_urls' is provided, + it is expected to be a list of strings representing URLs to images. + These URLs will be transformed to a suitable format for the DashScope + API, which might involve converting local file paths to data URIs. 
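+
+        For example, a minimal `messages` argument may look like the
+        following (the field values here are illustrative):
+
+        .. code-block:: python
+
+            messages = [
+                {"role": "system", "content": "You're a helpful assistant"},
+                {"role": "user", "content": "What's the date today?"},
+            ]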
+
+        Args:
+            messages (`list`):
+                A list of messages to process.
+            stream (`Optional[bool]`, default `None`):
+                The stream flag to control the response format, which will
+                override the stream flag in the constructor.
+            **kwargs (`Any`):
+                The keyword arguments to DashScope chat completions API,
+                e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
+                refer to
+                https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+                for more detailed arguments.
+
+        Returns:
+            `ModelResponse`:
+                A response object with the response text in text field, and
+                the raw response in raw field. If stream is True, the response
+                will be a generator in the `stream` field.
+
+        Note:
+            `parse_func`, `fault_handler` and `max_retries` are reserved for
+            `_response_parse_decorator` to parse and check the response
+            generated by model wrapper. Their usages are listed as follows:
+                - `parse_func` is a callable function used to parse and check
+                the response generated by the model, which takes the response
+                as input.
+                - `max_retries` is the maximum number of retries when the
+                `parse_func` raises an exception.
+                - `fault_handler` is a callable function which is called
+                when the response generated by the model is invalid after
+                `max_retries` retries.
+            The rules for message roles in DashScope are very strict;
+            for more details, please refer to
+            https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+        """
+
+        # step1: prepare keyword arguments
+        kwargs = {**self.generate_args, **kwargs}
+
+        # step2: checking messages
+        if not isinstance(messages, list):
+            raise ValueError(
+                "Dashscope `messages` field expected type `list`, "
+                f"got `{type(messages)}` instead.",
+            )
+        if not all("role" in msg and "content" in msg for msg in messages):
+            raise ValueError(
+                "Each message in the 'messages' list must contain a 'role' "
+                "and 'content' key for DashScope API.",
+            )
+
+        # step3: forward to generate response
+        if stream is None:
+            stream = self.stream
+
+        kwargs.update(
+            {
+                "model": self.model_name,
+                "messages": messages,
+                # Set the result to be "message" format.
+                "result_format": "message",
+                "stream": stream,
+            },
+        )
+
+        # Switch to the incremental_output mode
+        if stream:
+            kwargs["incremental_output"] = True
+
+        response = dashscope.Generation.call(**kwargs)
+
+        # step4: record the invocation and update the monitor
+        if stream:
+
+            def generator() -> Generator[str, None, None]:
+                last_chunk = None
+                text = ""
+                for chunk in response:
+                    if chunk.status_code != HTTPStatus.OK:
+                        error_msg = (
+                            f"Request id: {chunk.request_id}\n"
+                            f"Status code: {chunk.status_code}\n"
+                            f"Error code: {chunk.code}\n"
+                            f"Error message: {chunk.message}"
+                        )
+                        raise RuntimeError(error_msg)
+
+                    text += chunk.output["choices"][0]["message"]["content"]
+                    yield text
+                    last_chunk = chunk
+
+                # Replace the last chunk with the full text
+                last_chunk.output["choices"][0]["message"]["content"] = text
+
+                # Save the model invocation and update the monitor
+                self._save_model_invocation_and_update_monitor(
+                    kwargs,
+                    last_chunk,
+                )
+
+            return ModelResponse(
+                stream=generator(),
+                raw=response,
+            )
+
+        else:
+            if response.status_code != HTTPStatus.OK:
+                error_msg = (
+                    f"Request id: {response.request_id},\n"
+                    f"Status code: {response.status_code},\n"
+                    f"Error code: {response.code},\n"
+                    f"Error message: {response.message}."
+ ) + + raise RuntimeError(error_msg) + + # Record the model invocation and update the monitor + self._save_model_invocation_and_update_monitor( + kwargs, + response, + ) + + return ModelResponse( + text=response.output["choices"][0]["message"]["content"], + raw=response, + ) + + def _save_model_invocation_and_update_monitor( + self, + kwargs: dict, + response: GenerationResponse, + ) -> None: + """Save the model invocation and update the monitor accordingly. + + Args: + kwargs (`dict`): + The keyword arguments to the DashScope chat API. + response (`GenerationResponse`): + The response object returned by the DashScope chat API. + """ + input_tokens = response.usage.get("input_tokens", 0) + output_tokens = response.usage.get("output_tokens", 0) + + # Update the token record accordingly + self.monitor.update_text_and_embedding_tokens( + model_name=self.model_name, + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + ) + + # Save the model invocation after the stream is exhausted + self._save_model_invocation( + arguments=kwargs, + response=response, + ) + + def format( + self, + *args: Union[Msg, Sequence[Msg]], + ) -> List[dict]: + """A common format strategy for chat models, which will format the + input messages into a user message. + + Note this strategy maybe not suitable for all scenarios, + and developers are encouraged to implement their own prompt + engineering strategies. + + The following is an example: + + .. code-block:: python + + prompt1 = model.format( + Msg("system", "You're a helpful assistant", role="system"), + Msg("Bob", "Hi, how can I help you?", role="assistant"), + Msg("user", "What's the date today?", role="user") + ) + + prompt2 = model.format( + Msg("Bob", "Hi, how can I help you?", role="assistant"), + Msg("user", "What's the date today?", role="user") + ) + + The prompt will be as follows: + + .. code-block:: python + + # prompt1 + [ + { + "role": "system", + "content": "You're a helpful assistant" + }, + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + # prompt2 + [ + { + "role": "user", + "content": ( + "## Conversation History\\n" + "Bob: Hi, how can I help you?\\n" + "user: What's the date today?" + ) + } + ] + + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. 
+ """ + + return ModelWrapperBase.format_for_common_chat_models(*args) + + +class DashScopeImageSynthesisWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope Image Synthesis API, refer to + https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-1 + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9?spm=a2c4g.11186623.0.0.7108fa70Op6eqF + + ```json + { + "status_code": 200, + "request_id": "b54ffeb8-6212-9dac-808c-b3771cba3788", + "code": null, + "message": "", + "output": { + "task_id": "996523eb-034d-459b-ac88-b340b95007a4", + "task_status": "SUCCEEDED", + "results": [ + { + "url": "RESULT_URL1" + }, + { + "url": "RESULT_URL2" + }, + ], + "task_metrics": { + "TOTAL": 2, + "SUCCEEDED": 2, + "FAILED": 0 + } + }, + "usage": { + "image_count": 2 + } + } + ``` + """ + + model_type: str = "dashscope_image_synthesis" + + def __call__( + self, + prompt: str, + save_local: bool = False, + **kwargs: Any, + ) -> ModelResponse: + """ + Args: + prompt (`str`): + The prompt string to generate images from. + save_local: (`bool`, default `False`): + Whether to save the generated images locally, and replace + the returned image url with the local path. + **kwargs (`Any`): + The keyword arguments to DashScope Image Synthesis API, + e.g. `n`, `size`, etc. Please refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details-9 + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of image urls in image_urls field and the + raw response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved + for `_response_parse_decorator` to parse and check the + response generated by model wrapper. Their usages are listed + as follows: + - `parse_func` is a callable function used to parse and + check the response generated by the model, which takes + the response as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: forward to generate response + response = dashscope.ImageSynthesis.call( + model=self.model_name, + prompt=prompt, + **kwargs, + ) + if response.status_code != HTTPStatus.OK: + error_msg = ( + f" Request id: {response.request_id}," + f" Status code: {response.status_code}," + f" error code: {response.code}," + f" error message: {response.message}." + ) + raise RuntimeError(error_msg) + + # step3: record the model api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model_name, + "prompt": prompt, + **kwargs, + }, + response=response, + ) + + # step4: update monitor accordingly + self.monitor.update_image_tokens( + model_name=self.model_name, + image_count=response.usage.image_count, + resolution=kwargs.get("size", "1024*1024"), + ) + + # step5: return response + images = response.output["results"] + # Get image urls as a list + urls = [_["url"] for _ in images] + + if save_local: + file_manager = FileManager.get_instance() + # Return local url if save_local is True + urls = [file_manager.save_image(_) for _ in urls] + return ModelResponse(image_urls=urls, raw=response) + + +class DashScopeTextEmbeddingWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope Text Embedding API. 
+ + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details?spm=a2c4g.11186623.0.i3 + + ```json + { + "status_code": 200, // 200 indicate success otherwise failed. + "request_id": "fd564688-43f7-9595-b986", // The request id. + "code": "", // If failed, the error code. + "message": "", // If failed, the error message. + "output": { + "embeddings": [ // embeddings + { + "embedding": [ // one embedding output + -3.8450357913970947, ..., + ], + "text_index": 0 // the input index. + } + ] + }, + "usage": { + "total_tokens": 3 // the request tokens. + } + } + ``` + """ + + model_type: str = "dashscope_text_embedding" + + def __call__( + self, + texts: Union[list[str], str], + **kwargs: Any, + ) -> ModelResponse: + """Embed the messages with DashScope Text Embedding API. + + Args: + texts (`list[str]` or `str`): + The messages used to embed. + **kwargs (`Any`): + The keyword arguments to DashScope Text Embedding API, + e.g. `text_type`. Please refer to + https://help.aliyun.com/zh/dashscope/developer-reference/api-details-15 + for more detailed arguments. + + Returns: + `ModelResponse`: + A list of embeddings in embedding field and the raw + response in raw field. + + Note: + `parse_func`, `fault_handler` and `max_retries` are reserved + for `_response_parse_decorator` to parse and check the response + generated by model wrapper. Their usages are listed as follows: + - `parse_func` is a callable function used to parse and + check the response generated by the model, which takes the + response as input. + - `max_retries` is the maximum number of retries when the + `parse_func` raise an exception. + - `fault_handler` is a callable function which is called + when the response generated by the model is invalid after + `max_retries` retries. + """ + # step1: prepare keyword arguments + kwargs = {**self.generate_args, **kwargs} + + # step2: forward to generate response + response = dashscope.TextEmbedding.call( + input=texts, + model=self.model_name, + **kwargs, + ) + + if response.status_code != HTTPStatus.OK: + error_msg = ( + f" Request id: {response.request_id}," + f" Status code: {response.status_code}," + f" error code: {response.code}," + f" error message: {response.message}." + ) + raise RuntimeError(error_msg) + + # step3: record the model api invocation if needed + self._save_model_invocation( + arguments={ + "model": self.model_name, + "input": texts, + **kwargs, + }, + response=response, + ) + + # step4: update monitor accordingly + self.monitor.update_text_and_embedding_tokens( + model_name=self.model_name, + prompt_tokens=response.usage.get("total_tokens"), + total_tokens=response.usage.get("total_tokens"), + ) + + # step5: return response + return ModelResponse( + embedding=[_["embedding"] for _ in response.output["embeddings"]], + raw=response, + ) + + +class DashScopeMultiModalWrapper(DashScopeWrapperBase): + """The model wrapper for DashScope Multimodal API, refer to + https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-api + + Response: + - Refer to + https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api?spm=a2c4g.11186623.0.0.7fde1f5atQSalN + + ```json + { + "status_code": 200, + "request_id": "a0dc436c-2ee7-93e0-9667-c462009dec4d", + "code": "", + "message": "", + "output": { + "text": null, + "finish_reason": null, + "choices": [ + { + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": [ + { + "text": "这张图片显..." 
+                            }
+                        ]
+                    }
+                }
+            ]
+        },
+        "usage": {
+            "input_tokens": 1277,
+            "output_tokens": 81,
+            "image_tokens": 1247
+        }
+    }
+    ```
+    """
+
+    model_type: str = "dashscope_multimodal"
+
+    def __call__(
+        self,
+        messages: list,
+        **kwargs: Any,
+    ) -> ModelResponse:
+        """Model call for DashScope MultiModal API.
+
+        Args:
+            messages (`list`):
+                A list of messages to process.
+            **kwargs (`Any`):
+                The keyword arguments to DashScope MultiModal API,
+                e.g. `stream`. Please refer to
+                https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api
+                for more detailed arguments.
+
+        Returns:
+            `ModelResponse`:
+                The response text in text field, and the raw response in
+                raw field.
+
+        Note:
+            If image links are involved, the messages should be of the
+            following form:
+
+            .. code-block:: python
+
+                messages = [
+                    {
+                        "role": "system",
+                        "content": [
+                            {"text": "You are a helpful assistant."},
+                        ],
+                    },
+                    {
+                        "role": "user",
+                        "content": [
+                            {"text": "What does this picture depict?"},
+                            {"image": "http://example.com/image.jpg"},
+                        ],
+                    },
+                ]
+
+            Therefore, you should input a list matching the content value
+            above.
+            If only text is involved, just input the text directly.
+        """
+        # step1: prepare keyword arguments
+        kwargs = {**self.generate_args, **kwargs}
+
+        # step2: forward to generate response
+        response = dashscope.MultiModalConversation.call(
+            model=self.model_name,
+            messages=messages,
+            **kwargs,
+        )
+        # NOTE: if `stream` is enabled, `response` is a generator rather
+        # than a single response object, so the `status_code` access below
+        # would fail; a check for the streaming case should be added here.
+        if response.status_code != HTTPStatus.OK:
+            error_msg = (
+                f" Request id: {response.request_id},"
+                f" Status code: {response.status_code},"
+                f" error code: {response.code},"
+                f" error message: {response.message}."
+            )
+            raise RuntimeError(error_msg)
+
+        # step3: record the model api invocation if needed
+        self._save_model_invocation(
+            arguments={
+                "model": self.model_name,
+                "messages": messages,
+                **kwargs,
+            },
+            response=response,
+        )
+
+        # step4: update monitor accordingly
+        input_tokens = response.usage.get("input_tokens", 0)
+        image_tokens = response.usage.get("image_tokens", 0)
+        output_tokens = response.usage.get("output_tokens", 0)
+        # TODO: update the tokens
+        self.monitor.update_text_and_embedding_tokens(
+            model_name=self.model_name,
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens + image_tokens,
+        )
+
+        # step5: return response
+        content = response.output["choices"][0]["message"]["content"]
+        if isinstance(content, list):
+            content = content[0]["text"]
+
+        return ModelResponse(
+            text=content,
+            raw=response,
+        )
+
+    def format(
+        self,
+        *args: Union[Msg, Sequence[Msg]],
+    ) -> List:
+        """Format the messages for DashScope Multimodal API.
+
+        The multimodal API has the following requirements:
+
+        - The roles of messages must alternate between "user" and
+        "assistant".
+        - The message with the role "system" should be the first message
+        in the list.
+        - If the system message exists, then the second message must
+        have the role "user".
+        - The last message in the list should have the role "user".
+        - In each message, more than one figure is allowed.
+
+        With the above requirements, we format the messages as follows:
+
+        - If the first message is a system message, then we will keep it as
+        system prompt.
+        - We merge all messages into a conversation history prompt in a
+        single message with the role "user".
+        - When there are multiple figures in the given messages, we will
+        attach them to the user message in order.
Note if there are + multiple figures, this strategy may cause misunderstanding for + the model. For advanced solutions, developers are encouraged to + implement their own prompt engineering strategies. + + The following is an example: + + .. code-block:: python + + prompt = model.format( + Msg( + "system", + "You're a helpful assistant", + role="system", url="figure1" + ), + Msg( + "Bob", + "How about this picture?", + role="assistant", url="figure2" + ), + Msg( + "user", + "It's wonderful! How about mine?", + role="user", image="figure3" + ) + ) + + The prompt will be as follows: + + .. code-block:: python + + [ + { + "role": "system", + "content": [ + {"text": "You are a helpful assistant"}, + {"image": "figure1"} + ] + }, + { + "role": "user", + "content": [ + {"image": "figure2"}, + {"image": "figure3"}, + { + "text": ( + "## Conversation History\\n" + "Bob: How about this picture?\\n" + "user: It's wonderful! How about mine?" + ) + }, + ] + } + ] + + Note: + In multimodal API, the url of local files should be prefixed with + "file://", which will be attached in this format function. + + Args: + args (`Union[Msg, Sequence[Msg]]`): + The input arguments to be formatted, where each argument + should be a `Msg` object, or a list of `Msg` objects. + In distribution, placeholder is also allowed. + + Returns: + `List[dict]`: + The formatted messages. + """ + + # Parse all information into a list of messages + input_msgs = [] + for _ in args: + if _ is None: + continue + if isinstance(_, Msg): + input_msgs.append(_) + elif isinstance(_, list) and all(isinstance(__, Msg) for __ in _): + input_msgs.extend(_) + else: + raise TypeError( + f"The input should be a Msg object or a list " + f"of Msg objects, got {type(_)}.", + ) + + messages = [] + + # record dialog history as a list of strings + dialogue = [] + image_or_audio_dicts = [] + for i, unit in enumerate(input_msgs): + if i == 0 and unit.role == "system": + # system prompt + content = self.convert_url(unit.url) + content.append({"text": _convert_to_str(unit.content)}) + + messages.append( + { + "role": unit.role, + "content": content, + }, + ) + else: + # text message + dialogue.append( + f"{unit.name}: {_convert_to_str(unit.content)}", + ) + # image and audio + image_or_audio_dicts.extend(self.convert_url(unit.url)) + + dialogue_history = "\n".join(dialogue) + + user_content_template = "## Conversation History\n{dialogue_history}" + + messages.append( + { + "role": "user", + "content": [ + # Place the image or audio before the conversation history + *image_or_audio_dicts, + { + "text": user_content_template.format( + dialogue_history=dialogue_history, + ), + }, + ], + }, + ) + + return messages + + def convert_url(self, url: Union[str, Sequence[str], None]) -> List[dict]: + """Convert the url to the format of DashScope API. Note for local + files, a prefix "file://" will be added. + + Args: + url (`Union[str, Sequence[str], None]`): + A string of url of a list of urls to be converted. + + Returns: + `List[dict]`: + A list of dictionaries with key as the type of the url + and value as the url. Only "image" and "audio" are supported. 
+ """ + if url is None: + return [] + + if isinstance(url, str): + url_type = _guess_type_by_extension(url) + if url_type in ["audio", "image"]: + # Add prefix for local files + if os.path.exists(url): + url = "file://" + url + return [{url_type: url}] + else: + # skip unsupported url + logger.warning( + f"Skip unsupported url ({url_type}), " + f"expect image or audio.", + ) + return [] + elif isinstance(url, list): + dicts = [] + for _ in url: + dicts.extend(self.convert_url(_)) + return dicts + else: + raise TypeError( + f"Unsupported url type {type(url)}, " f"str or list expected.", + ) diff --git a/src/agentscope/service/multi_modality/dashscope_services.py b/src/agentscope/service/multi_modality/dashscope_services.py index d372b597e..fb09ec240 100644 --- a/src/agentscope/service/multi_modality/dashscope_services.py +++ b/src/agentscope/service/multi_modality/dashscope_services.py @@ -1,293 +1,293 @@ -# -*- coding: utf-8 -*- -"""Use DashScope API to generate images, -convert text to audio, and convert images to text. -Please refer to the official documentation for more details: -https://dashscope.aliyun.com/ -""" - -from typing import Union, Optional, Literal, Sequence - -import os - -from ...models import ( - DashScopeImageSynthesisWrapper, - DashScopeMultiModalWrapper, -) - -from ..service_response import ( - ServiceResponse, - ServiceExecStatus, -) -from ...utils.common import _download_file - - -def dashscope_text_to_image( - prompt: str, - api_key: str, - n: int = 1, - size: Literal["1024*1024", "720*1280", "1280*720"] = "1024*1024", - model: str = "wanx-v1", - save_dir: Optional[str] = None, -) -> ServiceResponse: - """Generate image(s) based on the given prompt, and return image url(s). - - Args: - prompt (`str`): - The text prompt to generate image. - api_key (`str`): - The api key for the dashscope api. - n (`int`, defaults to `1`): - The number of images to generate. - size (`Literal["1024*1024", "720*1280", "1280*720"]`, defaults to - `"1024*1024"`): - Size of the image. - model (`str`, defaults to '"wanx-v1"'): - The model to use. - save_dir (`Optional[str]`, defaults to 'None'): - The directory to save the generated images. If not specified, - will return the web urls. - - Returns: - ServiceResponse: - A dictionary with two variables: `status` and`content`. - If `status` is ServiceExecStatus.SUCCESS, - the `content` is a dict with key 'fig_paths" and - value is a list of the paths to the generated images. - - Example: - - .. 
code-block:: python - - prompt = "A beautiful sunset in the mountains" - print(dashscope_text_to_image(prompt, "{api_key}")) - - > { - > 'status': 'SUCCESS', - > 'content': {'image_urls': ['IMAGE_URL1', 'IMAGE_URL2']} - > } - - """ - text2img = DashScopeImageSynthesisWrapper( - config_name="dashscope-text-to-image-service", # Just a placeholder - model_name=model, - api_key=api_key, - ) - try: - res = text2img( - prompt=prompt, - n=n, - size=size, - ) - urls = res.image_urls - - # save images to save_dir - if urls is not None: - if save_dir: - os.makedirs(save_dir, exist_ok=True) - urls_local = [] - # Obtain the image file names in the url - for url in urls: - image_name = url.split("/")[-1] - image_path = os.path.abspath( - os.path.join(save_dir, image_name), - ) - # Download the image - _download_file(url, image_path) - urls_local.append(image_path) - - return ServiceResponse( - ServiceExecStatus.SUCCESS, - {"image_urls": urls_local}, - ) - else: - # Return the web urls - return ServiceResponse( - ServiceExecStatus.SUCCESS, - {"image_urls": urls}, - ) - else: - return ServiceResponse( - ServiceExecStatus.ERROR, - "Error: Failed to generate images", - ) - except Exception as e: - return ServiceResponse( - ServiceExecStatus.ERROR, - str(e), - ) - - -def dashscope_image_to_text( - image_urls: Union[str, Sequence[str]], - api_key: str, - prompt: str = "Describe the image", - model: str = "qwen-vl-plus", -) -> ServiceResponse: - """Generate text based on the given images. - - Args: - image_urls (`Union[str, Sequence[str]]`): - The url of single or multiple images. - api_key (`str`): - The api key for the dashscope api. - prompt (`str`, defaults to 'Describe the image' ): - The text prompt. - model (`str`, defaults to 'qwen-vl-plus'): - The model to use in DashScope MultiModal API. - - Returns: - `ServiceResponse`: - A dictionary with two variables: `status` and`content`. - If `status` is ServiceExecStatus.SUCCESS, the `content` is the - generated text. - - Example: - - .. 
code-block:: python - - image_url = "image.jpg" - prompt = "Describe the image" - print(image_to_text(image_url, prompt)) - - > {'status': 'SUCCESS', 'content': 'A beautiful sunset in the mountains'} - - """ - - img2text = DashScopeMultiModalWrapper( - config_name="dashscope-image-to-text-service", # Just a placeholder - model_name=model, - api_key=api_key, - ) - - if isinstance(image_urls, str): - image_urls = [image_urls] - - # Check if the local url is valid - img_abs_urls = [] - for url in image_urls: - if os.path.exists(url): - if os.path.isfile(url): - img_abs_urls.append(os.path.abspath(url)) - else: - return ServiceResponse( - ServiceExecStatus.ERROR, - f'Error: The input image url "{url}" is not a file.', - ) - else: - # Maybe a web url or an invalid url, we leave it to the API - # to handle - img_abs_urls.append(url) - - # Convert image paths according to the model requirements - contents = img2text.convert_url(img_abs_urls) - contents.append({"text": prompt}) - # currently only support one round of conversation - # if multiple rounds of conversation are needed, - # it would be better to implement an Agent class - sys_message = { - "role": "system", - "content": [{"text": "You are a helpful assistant."}], - } - user_message = { - "role": "user", - "content": contents, - } - messages = [sys_message, user_message] - try: - res = img2text(messages, stream=False) - description = res.text - if description is not None: - return ServiceResponse( - ServiceExecStatus.SUCCESS, - description, - ) - else: - return ServiceResponse( - ServiceExecStatus.ERROR, - "Error: Failed to generate text", - ) - except Exception as e: - return ServiceResponse( - ServiceExecStatus.ERROR, - str(e), - ) - - -def dashscope_text_to_audio( - text: str, - api_key: str, - save_dir: str, - model: str = "sambert-zhichu-v1", - sample_rate: int = 48000, -) -> ServiceResponse: - """Convert the given text to audio. - - Args: - text (`str`): - The text to be converted into audio. - api_key (`str`): - The api key for the dashscope API. - save_dir (`str`): - The directory to save the generated audio. - model (`str`, defaults to 'sambert-zhichu-v1'): - The model to use. Full model list can be found in - https://help.aliyun.com/zh/dashscope/model-list - sample_rate (`int`, defaults to 48000): - Samplerate of the audio. - - Returns: - `ServiceResponse`: - A dictionary with two variables: `status` and`content`. If - `status` is ServiceExecStatus.SUCCESS, the `content` contains - a dictionary with key "audio_path" and value is the path to - the generated audio. - - Example: - - .. code-block:: python - - text = "How is the weather today?" - print(text_to_audio(text)) gives: - - - > {'status': 'SUCCESS', 'content': {"audio_path": "AUDIO_PATH"}} - - """ - try: - import dashscope - except ImportError as e: - raise ImportError( - "The package 'dashscope' is not installed. 
Please install it by " - "running `pip install dashscope==1.14.1`", - ) from e - - dashscope.api_key = api_key - - res = dashscope.audio.tts.SpeechSynthesizer.call( - model=model, - text=text, - sample_rate=sample_rate, - format="wav", - ) - - audio_data = res.get_audio_data() - - if audio_data is not None: - if save_dir is not None: - os.makedirs(save_dir, exist_ok=True) - - # Save locally - text = text[0:15] if len(text) > 15 else text - audio_path = os.path.join(save_dir, f"{text.strip()}.wav") - - with open(audio_path, "wb") as f: - f.write(audio_data) - return ServiceResponse( - ServiceExecStatus.SUCCESS, - {"audio_path": audio_path}, - ) - else: - return ServiceResponse( - ServiceExecStatus.ERROR, - "Error: Failed to generate audio", - ) +# -*- coding: utf-8 -*- +"""Use DashScope API to generate images, +convert text to audio, and convert images to text. +Please refer to the official documentation for more details: +https://dashscope.aliyun.com/ +""" + +from typing import Union, Optional, Literal, Sequence + +import os + +from ...models import ( + DashScopeImageSynthesisWrapper, + DashScopeMultiModalWrapper, +) + +from ..service_response import ( + ServiceResponse, + ServiceExecStatus, +) +from ...utils.common import _download_file + + +def dashscope_text_to_image( + prompt: str, + api_key: str, + n: int = 1, + size: Literal["1024*1024", "720*1280", "1280*720"] = "1024*1024", + model: str = "wanx-v1", + save_dir: Optional[str] = None, +) -> ServiceResponse: + """Generate image(s) based on the given prompt, and return image url(s). + + Args: + prompt (`str`): + The text prompt to generate image. + api_key (`str`): + The api key for the dashscope api. + n (`int`, defaults to `1`): + The number of images to generate. + size (`Literal["1024*1024", "720*1280", "1280*720"]`, defaults to + `"1024*1024"`): + Size of the image. + model (`str`, defaults to '"wanx-v1"'): + The model to use. + save_dir (`Optional[str]`, defaults to 'None'): + The directory to save the generated images. If not specified, + will return the web urls. + + Returns: + ServiceResponse: + A dictionary with two variables: `status` and`content`. + If `status` is ServiceExecStatus.SUCCESS, + the `content` is a dict with key 'fig_paths" and + value is a list of the paths to the generated images. + + Example: + + .. 
code-block:: python + + prompt = "A beautiful sunset in the mountains" + print(dashscope_text_to_image(prompt, "{api_key}")) + + > { + > 'status': 'SUCCESS', + > 'content': {'image_urls': ['IMAGE_URL1', 'IMAGE_URL2']} + > } + + """ + text2img = DashScopeImageSynthesisWrapper( + config_name="dashscope-text-to-image-service", # Just a placeholder + model_name=model, + api_key=api_key, + ) + try: + res = text2img( + prompt=prompt, + n=n, + size=size, + ) + urls = res.image_urls + + # save images to save_dir + if urls is not None: + if save_dir: + os.makedirs(save_dir, exist_ok=True) + urls_local = [] + # Obtain the image file names in the url + for url in urls: + image_name = url.split("/")[-1] + image_path = os.path.abspath( + os.path.join(save_dir, image_name), + ) + # Download the image + _download_file(url, image_path) + urls_local.append(image_path) + + return ServiceResponse( + ServiceExecStatus.SUCCESS, + {"image_urls": urls_local}, + ) + else: + # Return the web urls + return ServiceResponse( + ServiceExecStatus.SUCCESS, + {"image_urls": urls}, + ) + else: + return ServiceResponse( + ServiceExecStatus.ERROR, + "Error: Failed to generate images", + ) + except Exception as e: + return ServiceResponse( + ServiceExecStatus.ERROR, + str(e), + ) + + +def dashscope_image_to_text( + image_urls: Union[str, Sequence[str]], + api_key: str, + prompt: str = "Describe the image", + model: str = "qwen-vl-plus", +) -> ServiceResponse: + """Generate text based on the given images. + + Args: + image_urls (`Union[str, Sequence[str]]`): + The url of single or multiple images. + api_key (`str`): + The api key for the dashscope api. + prompt (`str`, defaults to 'Describe the image' ): + The text prompt. + model (`str`, defaults to 'qwen-vl-plus'): + The model to use in DashScope MultiModal API. + + Returns: + `ServiceResponse`: + A dictionary with two variables: `status` and`content`. + If `status` is ServiceExecStatus.SUCCESS, the `content` is the + generated text. + + Example: + + .. 
code-block:: python + + image_url = "image.jpg" + prompt = "Describe the image" + print(image_to_text(image_url, prompt)) + + > {'status': 'SUCCESS', 'content': 'A beautiful sunset in the mountains'} + + """ + + img2text = DashScopeMultiModalWrapper( + config_name="dashscope-image-to-text-service", # Just a placeholder + model_name=model, + api_key=api_key, + ) + + if isinstance(image_urls, str): + image_urls = [image_urls] + + # Check if the local url is valid + img_abs_urls = [] + for url in image_urls: + if os.path.exists(url): + if os.path.isfile(url): + img_abs_urls.append(os.path.abspath(url)) + else: + return ServiceResponse( + ServiceExecStatus.ERROR, + f'Error: The input image url "{url}" is not a file.', + ) + else: + # Maybe a web url or an invalid url, we leave it to the API + # to handle + img_abs_urls.append(url) + + # Convert image paths according to the model requirements + contents = img2text.convert_url(img_abs_urls) + contents.append({"text": prompt}) + # currently only support one round of conversation + # if multiple rounds of conversation are needed, + # it would be better to implement an Agent class + sys_message = { + "role": "system", + "content": [{"text": "You are a helpful assistant."}], + } + user_message = { + "role": "user", + "content": contents, + } + messages = [sys_message, user_message] + try: + res = img2text(messages, stream=False) + description = res.text + if description is not None: + return ServiceResponse( + ServiceExecStatus.SUCCESS, + description, + ) + else: + return ServiceResponse( + ServiceExecStatus.ERROR, + "Error: Failed to generate text", + ) + except Exception as e: + return ServiceResponse( + ServiceExecStatus.ERROR, + str(e), + ) + + +def dashscope_text_to_audio( + text: str, + api_key: str, + save_dir: str, + model: str = "sambert-zhichu-v1", + sample_rate: int = 48000, +) -> ServiceResponse: + """Convert the given text to audio. + + Args: + text (`str`): + The text to be converted into audio. + api_key (`str`): + The api key for the dashscope API. + save_dir (`str`): + The directory to save the generated audio. + model (`str`, defaults to 'sambert-zhichu-v1'): + The model to use. Full model list can be found in + https://help.aliyun.com/zh/dashscope/model-list + sample_rate (`int`, defaults to 48000): + Samplerate of the audio. + + Returns: + `ServiceResponse`: + A dictionary with two variables: `status` and`content`. If + `status` is ServiceExecStatus.SUCCESS, the `content` contains + a dictionary with key "audio_path" and value is the path to + the generated audio. + + Example: + + .. code-block:: python + + text = "How is the weather today?" + print(text_to_audio(text)) gives: + + + > {'status': 'SUCCESS', 'content': {"audio_path": "AUDIO_PATH"}} + + """ + try: + import dashscope + except ImportError as e: + raise ImportError( + "The package 'dashscope' is not installed. 
Please install it by " + "running `pip install dashscope>=1.19.0`", + ) from e + + dashscope.api_key = api_key + + res = dashscope.audio.tts.SpeechSynthesizer.call( + model=model, + text=text, + sample_rate=sample_rate, + format="wav", + ) + + audio_data = res.get_audio_data() + + if audio_data is not None: + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + + # Save locally + text = text[0:15] if len(text) > 15 else text + audio_path = os.path.join(save_dir, f"{text.strip()}.wav") + + with open(audio_path, "wb") as f: + f.write(audio_data) + return ServiceResponse( + ServiceExecStatus.SUCCESS, + {"audio_path": audio_path}, + ) + else: + return ServiceResponse( + ServiceExecStatus.ERROR, + "Error: Failed to generate audio", + ) diff --git a/src/agentscope/web/gradio/utils.py b/src/agentscope/web/gradio/utils.py index 4b7285e65..321766ba2 100644 --- a/src/agentscope/web/gradio/utils.py +++ b/src/agentscope/web/gradio/utils.py @@ -1,222 +1,222 @@ -# -*- coding: utf-8 -*- -"""web ui utils""" -import os -import threading -from typing import Optional -import hashlib -from multiprocessing import Queue -from queue import Empty -from collections import defaultdict - -from PIL import Image - -SYS_MSG_PREFIX = "【SYSTEM】" - -thread_local_data = threading.local() - - -def init_uid_queues() -> dict: - """Initializes and returns a dictionary of user-specific queues.""" - return { - "glb_queue_chat_msg": Queue(), - "glb_queue_user_input": Queue(), - "glb_queue_reset_msg": Queue(), - } - - -glb_uid_dict = defaultdict(init_uid_queues) - - -def send_msg( - msg: str, - is_player: bool = False, - role: Optional[str] = None, - uid: Optional[str] = None, - flushing: bool = False, - avatar: Optional[str] = None, - msg_id: Optional[str] = None, -) -> None: - """Sends a message to the web UI.""" - global glb_uid_dict - glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] - if is_player: - glb_queue_chat_msg.put( - [ - { - "text": msg, - "name": role, - "flushing": flushing, - "avatar": avatar, - }, - None, - ], - ) - else: - glb_queue_chat_msg.put( - [ - None, - { - "text": msg, - "name": role, - "flushing": flushing, - "avatar": avatar, - "id": msg_id, - }, - ], - ) - - -def get_chat_msg(uid: Optional[str] = None) -> list: - """Retrieves the next chat message from the queue, if available.""" - global glb_uid_dict - glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] - if not glb_queue_chat_msg.empty(): - line = glb_queue_chat_msg.get(block=False) - if line is not None: - return line - return [] - - -def send_player_input(msg: str, uid: Optional[str] = None) -> None: - """Sends player input to the web UI.""" - global glb_uid_dict - glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] - glb_queue_user_input.put([None, msg]) - - -def get_player_input( - timeout: Optional[int] = None, - uid: Optional[str] = None, -) -> str: - """Gets player input from the web UI or command line.""" - global glb_uid_dict - glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] - - if timeout: - try: - content = glb_queue_user_input.get(block=True, timeout=timeout)[1] - except Empty as exc: - raise TimeoutError("timed out") from exc - else: - content = glb_queue_user_input.get(block=True)[1] - if content == "**Reset**": - glb_uid_dict[uid] = init_uid_queues() - raise ResetException - return content - - -def send_reset_msg(uid: Optional[str] = None) -> None: - """Sends a reset message to the web UI.""" - uid = check_uuid(uid) - global glb_uid_dict - glb_queue_reset_msg = 
glb_uid_dict[uid]["glb_queue_reset_msg"] - glb_queue_reset_msg.put([None, "**Reset**"]) - send_player_input("**Reset**", uid) - - -def get_reset_msg(uid: Optional[str] = None) -> None: - """Retrieves a reset message from the queue, if available.""" - global glb_uid_dict - glb_queue_reset_msg = glb_uid_dict[uid]["glb_queue_reset_msg"] - if not glb_queue_reset_msg.empty(): - content = glb_queue_reset_msg.get(block=True)[1] - if content == "**Reset**": - glb_uid_dict[uid] = init_uid_queues() - raise ResetException - - -class ResetException(Exception): - """Custom exception to signal a reset action in the application.""" - - -def check_uuid(uid: Optional[str]) -> str: - """Checks whether a UUID is provided or generates a default one.""" - if not uid or uid == "": - if os.getenv("MODELSCOPE_ENVIRONMENT") == "studio": - import gradio as gr - - raise gr.Error("Please login first") - uid = "local_user" - return uid - - -def generate_image_from_name(name: str) -> str: - """Generates an image based on the hash of the given name.""" - from agentscope.manager import FileManager - - file_manager = FileManager.get_instance() - - # Using hashlib to generate a hash of the name - hash_func = hashlib.md5() - hash_func.update(name.encode("utf-8")) - hash_value = hash_func.hexdigest() - - # Extract the first 6 characters of the hash value as the hexadecimal - # representation of the color - # generate a color value between #000000 and #ffffff - color_hex = "#" + hash_value[:6] - color_rgb = Image.new("RGB", (1, 1), color_hex).getpixel((0, 0)) - - # If the image does not exist, generate and save it - width, height = 200, 200 - image = Image.new("RGB", (width, height), color_rgb) - - image_filepath = file_manager.save_image(image, f"{name}_image.png") - - return image_filepath - - -def audio2text(audio_path: str) -> str: - """Converts audio file at the given path to text using ASR.""" - - try: - from dashscope.audio.asr import RecognitionCallback, Recognition - except ImportError as e: - raise ImportError( - "The package dashscope is not found. Please install it by " - "running `pip install dashscope==1.14.1`", - ) from e - - callback = RecognitionCallback() - rec = Recognition( - model="paraformer-realtime-v1", - format="wav", - sample_rate=16000, - callback=callback, - ) - - result = rec.call(audio_path) - return " ".join([s["text"] for s in result["output"]["sentence"]]) - - -def cycle_dots(text: str, num_dots: int = 3) -> str: - """display thinking dots before agent reply""" - current_dots = len(text) - len(text.rstrip(".")) - next_dots = (current_dots + 1) % (num_dots + 1) - if next_dots == 0: - next_dots = 1 - return text.rstrip(".") + "." 
* next_dots - - -def user_input( - prefix: str = "User input: ", - timeout: Optional[int] = None, -) -> str: - """get user input""" - if hasattr(thread_local_data, "uid"): - get_reset_msg(uid=thread_local_data.uid) - content = get_player_input( - timeout=timeout, - uid=thread_local_data.uid, - ) - else: - if timeout: - from inputimeout import inputimeout, TimeoutOccurred - - try: - content = inputimeout(prefix, timeout=timeout) - except TimeoutOccurred as exc: - raise TimeoutError("timed out") from exc - else: - content = input(prefix) - return content +# -*- coding: utf-8 -*- +"""web ui utils""" +import os +import threading +from typing import Optional +import hashlib +from multiprocessing import Queue +from queue import Empty +from collections import defaultdict + +from PIL import Image + +SYS_MSG_PREFIX = "【SYSTEM】" + +thread_local_data = threading.local() + + +def init_uid_queues() -> dict: + """Initializes and returns a dictionary of user-specific queues.""" + return { + "glb_queue_chat_msg": Queue(), + "glb_queue_user_input": Queue(), + "glb_queue_reset_msg": Queue(), + } + + +glb_uid_dict = defaultdict(init_uid_queues) + + +def send_msg( + msg: str, + is_player: bool = False, + role: Optional[str] = None, + uid: Optional[str] = None, + flushing: bool = False, + avatar: Optional[str] = None, + msg_id: Optional[str] = None, +) -> None: + """Sends a message to the web UI.""" + global glb_uid_dict + glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] + if is_player: + glb_queue_chat_msg.put( + [ + { + "text": msg, + "name": role, + "flushing": flushing, + "avatar": avatar, + }, + None, + ], + ) + else: + glb_queue_chat_msg.put( + [ + None, + { + "text": msg, + "name": role, + "flushing": flushing, + "avatar": avatar, + "id": msg_id, + }, + ], + ) + + +def get_chat_msg(uid: Optional[str] = None) -> list: + """Retrieves the next chat message from the queue, if available.""" + global glb_uid_dict + glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] + if not glb_queue_chat_msg.empty(): + line = glb_queue_chat_msg.get(block=False) + if line is not None: + return line + return [] + + +def send_player_input(msg: str, uid: Optional[str] = None) -> None: + """Sends player input to the web UI.""" + global glb_uid_dict + glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] + glb_queue_user_input.put([None, msg]) + + +def get_player_input( + timeout: Optional[int] = None, + uid: Optional[str] = None, +) -> str: + """Gets player input from the web UI or command line.""" + global glb_uid_dict + glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] + + if timeout: + try: + content = glb_queue_user_input.get(block=True, timeout=timeout)[1] + except Empty as exc: + raise TimeoutError("timed out") from exc + else: + content = glb_queue_user_input.get(block=True)[1] + if content == "**Reset**": + glb_uid_dict[uid] = init_uid_queues() + raise ResetException + return content + + +def send_reset_msg(uid: Optional[str] = None) -> None: + """Sends a reset message to the web UI.""" + uid = check_uuid(uid) + global glb_uid_dict + glb_queue_reset_msg = glb_uid_dict[uid]["glb_queue_reset_msg"] + glb_queue_reset_msg.put([None, "**Reset**"]) + send_player_input("**Reset**", uid) + + +def get_reset_msg(uid: Optional[str] = None) -> None: + """Retrieves a reset message from the queue, if available.""" + global glb_uid_dict + glb_queue_reset_msg = glb_uid_dict[uid]["glb_queue_reset_msg"] + if not glb_queue_reset_msg.empty(): + content = glb_queue_reset_msg.get(block=True)[1] + 
if content == "**Reset**": + glb_uid_dict[uid] = init_uid_queues() + raise ResetException + + +class ResetException(Exception): + """Custom exception to signal a reset action in the application.""" + + +def check_uuid(uid: Optional[str]) -> str: + """Checks whether a UUID is provided or generates a default one.""" + if not uid or uid == "": + if os.getenv("MODELSCOPE_ENVIRONMENT") == "studio": + import gradio as gr + + raise gr.Error("Please login first") + uid = "local_user" + return uid + + +def generate_image_from_name(name: str) -> str: + """Generates an image based on the hash of the given name.""" + from agentscope.manager import FileManager + + file_manager = FileManager.get_instance() + + # Using hashlib to generate a hash of the name + hash_func = hashlib.md5() + hash_func.update(name.encode("utf-8")) + hash_value = hash_func.hexdigest() + + # Extract the first 6 characters of the hash value as the hexadecimal + # representation of the color + # generate a color value between #000000 and #ffffff + color_hex = "#" + hash_value[:6] + color_rgb = Image.new("RGB", (1, 1), color_hex).getpixel((0, 0)) + + # If the image does not exist, generate and save it + width, height = 200, 200 + image = Image.new("RGB", (width, height), color_rgb) + + image_filepath = file_manager.save_image(image, f"{name}_image.png") + + return image_filepath + + +def audio2text(audio_path: str) -> str: + """Converts audio file at the given path to text using ASR.""" + + try: + from dashscope.audio.asr import RecognitionCallback, Recognition + except ImportError as e: + raise ImportError( + "The package dashscope is not found. Please install it by " + "running `pip install dashscope>=1.19.0`", + ) from e + + callback = RecognitionCallback() + rec = Recognition( + model="paraformer-realtime-v1", + format="wav", + sample_rate=16000, + callback=callback, + ) + + result = rec.call(audio_path) + return " ".join([s["text"] for s in result["output"]["sentence"]]) + + +def cycle_dots(text: str, num_dots: int = 3) -> str: + """display thinking dots before agent reply""" + current_dots = len(text) - len(text.rstrip(".")) + next_dots = (current_dots + 1) % (num_dots + 1) + if next_dots == 0: + next_dots = 1 + return text.rstrip(".") + "." * next_dots + + +def user_input( + prefix: str = "User input: ", + timeout: Optional[int] = None, +) -> str: + """get user input""" + if hasattr(thread_local_data, "uid"): + get_reset_msg(uid=thread_local_data.uid) + content = get_player_input( + timeout=timeout, + uid=thread_local_data.uid, + ) + else: + if timeout: + from inputimeout import inputimeout, TimeoutOccurred + + try: + content = inputimeout(prefix, timeout=timeout) + except TimeoutOccurred as exc: + raise TimeoutError("timed out") from exc + else: + content = input(prefix) + return content