diff --git a/setup.py b/setup.py index cd577d5b8..e29fdd8f9 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,7 @@ extra_litellm_requires = ["litellm"] extra_zhipuai_requires = ["zhipuai"] extra_ollama_requires = ["ollama>=0.1.7"] +extra_sd_webuiapi_requires = ["webuiapi"] # Full requires extra_full_requires = ( @@ -102,6 +103,7 @@ + extra_litellm_requires + extra_zhipuai_requires + extra_ollama_requires + + extra_sd_webuiapi_requires ) # For online workstation @@ -140,6 +142,7 @@ "litellm": extra_litellm_requires, "zhipuai": extra_zhipuai_requires, "gemini": extra_gemini_requires, + "stablediffusion": extra_sd_webuiapi_requires, # For service functions "service": extra_service_requires, # For distribution mode diff --git a/src/agentscope/models/stablediffusion_model.py b/src/agentscope/models/stablediffusion_model.py index c8f2d9548..1d5d20f54 100644 --- a/src/agentscope/models/stablediffusion_model.py +++ b/src/agentscope/models/stablediffusion_model.py @@ -1,17 +1,14 @@ # -*- coding: utf-8 -*- """Model wrapper for stable diffusion models.""" from abc import ABC -import base64 -import json -import time -from typing import Any, Optional, Union, Sequence +from typing import Any, Union, Sequence -import requests -from loguru import logger +try: + import webuiapi +except ImportError: + webuiapi = None from . import ModelWrapperBase, ModelResponse -from ..constants import _DEFAULT_MAX_RETRIES -from ..constants import _DEFAULT_RETRY_INTERVAL from ..message import Msg from ..manager import FileManager from ..utils.common import _convert_to_str @@ -23,9 +20,10 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC): To use SD-webui API, please 1. First download stable-diffusion-webui from https://github.com/AUTOMATIC1111/stable-diffusion-webui and - install it with 'webui-user.bat' + install it 2. Move your checkpoint to 'models/Stable-diffusion' folder - 3. Start launch.py with the '--api' parameter to start the server + 3. Start launch.py with the '--api --port=7862' parameter + 4. Install the 'webuiapi' package by 'pip install webuiapi' After that, you can use the SD-webui API and query the available parameters on the http://localhost:7862/docs page """ @@ -35,15 +33,10 @@ class StableDiffusionWrapperBase(ModelWrapperBase, ABC): def __init__( self, config_name: str, - host: str = "127.0.0.1:7862", - base_url: Optional[Union[str, None]] = None, - use_https: bool = False, generate_args: dict = None, - headers: dict = None, options: dict = None, - timeout: int = 30, - max_retries: int = _DEFAULT_MAX_RETRIES, - retry_interval: int = _DEFAULT_RETRY_INTERVAL, + host: str = "127.0.0.1", + port: int = 7862, **kwargs: Any, ) -> None: """ @@ -52,46 +45,29 @@ def __init__( Args: config_name (`str`): The name of the model config. - host (`str`, default `"127.0.0.1:7862"`): - The host port of the stable-diffusion webui server. - base_url (`str`, default `None`): - Base URL for the stable-diffusion webui services. - Generated from host and use_https if not provided. - use_https (`bool`, default `False`): - Whether to generate the base URL with HTTPS protocol or HTTP. generate_args (`dict`, default `None`): The extra keyword arguments used in SD api generation, e.g. `{"steps": 50}`. - headers (`dict`, default `None`): - HTTP request headers. options (`dict`, default `None`): - The keyword arguments to change the webui settings + The keyword arguments to change the sd-webui settings such as model or CLIP skip, this changes will persist. e.g. `{"sd_model_checkpoint": "Anything-V3.0-pruned"}`. + host (`str`, default `"127.0.0.1"`): + The host of the stable-diffusion webui server. + port (`int`, default `7862`): + The port of the stable-diffusion webui server. """ - # Construct base_url based on HTTPS usage if not provided - if base_url is None: - if use_https: - base_url = f"https://{host}" - else: - base_url = f"http://{host}" - - self.base_url = base_url - self.options_url = f"{base_url}/sdapi/v1/options" + # Initialize the SD-webui API + self.api = webuiapi.WebUIApi(host=host, port=port, **kwargs) self.generate_args = generate_args or {} - # Initialize the HTTP session and update the request headers - self.session = requests.Session() - if headers: - self.session.headers.update(headers) - # Set options if provided if options: - self._set_options(options) + self.api.set_options(options) # Get the default model name from the web-options model_name = ( - self._get_options()["sd_model_checkpoint"].split("[")[0].strip() + self.api.get_options()["sd_model_checkpoint"].split("[")[0].strip() ) # Update the model name if self.generate_args.get("override_settings"): @@ -102,116 +78,29 @@ def __init__( super().__init__(config_name=config_name, model_name=model_name) - self.timeout = timeout - self.max_retries = max_retries - self.retry_interval = retry_interval - - @property - def url(self) -> str: - """SD-webui API endpoint URL""" - raise NotImplementedError() - - def _get_options(self) -> dict: - response = self.session.get(url=self.options_url) - if response.status_code != 200: - logger.error(f"Failed to get options with {response.json()}") - raise RuntimeError(f"Failed to get options with {response.json()}") - return response.json() - - def _set_options(self, options: dict) -> None: - response = self.session.post(url=self.options_url, json=options) - if response.status_code != 200: - logger.error(json.dumps(options, indent=4)) - raise RuntimeError(f"Failed to set options with {response.json()}") - logger.info("Optionsset successfully") - - def _invoke_model(self, payload: dict) -> dict: - """Invoke SD webui API and record the invocation if needed""" - # step1: prepare post requests - for i in range(1, self.max_retries + 1): - response = self.session.post(url=self.url, json=payload) - - if response.status_code == requests.codes.ok: - break - - if i < self.max_retries: - logger.warning( - f"Failed to call the model with " - f"requests.codes == {response.status_code}, retry " - f"{i + 1}/{self.max_retries} times", - ) - time.sleep(i * self.retry_interval) - - # step2: record model invocation - # record the model api invocation, which will be skipped if - # `FileManager.save_api_invocation` is `False` - self._save_model_invocation( - arguments=payload, - response=response.json(), - ) - - # step3: return the response json - if response.status_code == requests.codes.ok: - return response.json() - else: - logger.error( - json.dumps({"url": self.url, "json": payload}, indent=4), - ) - raise RuntimeError( - f"Failed to call the model with {response.json()}", - ) - - def _parse_response(self, response: dict) -> ModelResponse: - """Parse the response json data into ModelResponse""" - return ModelResponse(raw=response) - class StableDiffusionImageSynthesisWrapper(StableDiffusionWrapperBase): """Stable Diffusion Text-to-Image (txt2img) API Wrapper""" model_type: str = "sd_txt2img" - @property - def url(self) -> str: - return f"{self.base_url}/sdapi/v1/txt2img" - - def _parse_response(self, response: dict) -> ModelResponse: - session_parameters = response["parameters"] - size = f"{session_parameters['width']}*{session_parameters['height']}" - image_count = ( - session_parameters["batch_size"] * session_parameters["n_iter"] - ) - - self.monitor.update_image_tokens( - model_name=self.model_name, - image_count=image_count, - resolution=size, - ) - - # Get image base64code as a list - images = response["images"] - b64_images = [base64.b64decode(image) for image in images] - - file_manager = FileManager.get_instance() - # Return local url - image_urls = [file_manager.save_image(_) for _ in b64_images] - text = "Image saved to " + "\n".join(image_urls) - return ModelResponse(text=text, image_urls=image_urls, raw=response) - def __call__( self, prompt: str, + save_local: bool = True, **kwargs: Any, ) -> ModelResponse: """ Args: prompt (`str`): The prompt string to generate images from. + save_local (`bool`, default `True`): + Whether to save the generated images locally. **kwargs (`Any`): The keyword arguments to SD-webui txt2img API, e.g. `n_iter`, `steps`, `seed`, `width`, etc. Please refer to https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API - or http://localhost:7860/docs + or http://localhost:7862/docs for more detailed arguments. Returns: `ModelResponse`: @@ -226,10 +115,61 @@ def __call__( } # step2: forward to generate response - response = self._invoke_model(payload) + response = self.api.txt2img(**payload) + + # step3: save model invocation and update monitor + self._save_model_invocation_and_update_monitor( + payload=payload, + response=response.json, + ) + + # step4: parse the response + PIL_images = response.images + + file_manager = FileManager.get_instance() + if save_local: + # Save images + image_urls = [file_manager.save_image(_) for _ in PIL_images] + text = "Image saved to " + "\n".join(image_urls) + else: + image_urls = PIL_images + text = None + + return ModelResponse( + text=text, + image_urls=image_urls, + raw=response.json, + ) + + def _save_model_invocation_and_update_monitor( + self, + payload: dict, + response: dict, + ) -> None: + """Save the model invocation and update the monitor accordingly. + + Args: + kwargs (`dict`): + The keyword arguments to the DashScope chat API. + response (`dict`): + The response object returned by the DashScope chat API. + """ + self._save_model_invocation( + arguments=payload, + response=response, + ) + + session_parameters = response["parameters"] + size = f"{session_parameters['width']}*{session_parameters['height']}" + image_count = ( + session_parameters["batch_size"] * session_parameters["n_iter"] + ) - # step3: parse the response - return self._parse_response(response) + self.monitor.update_image_tokens( + model_name=self.model_name, + image_count=image_count, + resolution=size, + ) def format(self, *args: Union[Msg, Sequence[Msg]]) -> str: # This is a temporary implementation to focus on the prompt diff --git a/src/agentscope/service/__init__.py b/src/agentscope/service/__init__.py index 7d33e6501..20c1af051 100644 --- a/src/agentscope/service/__init__.py +++ b/src/agentscope/service/__init__.py @@ -45,6 +45,7 @@ openai_edit_image, openai_create_image_variation, ) +from .multi_modality.stablediffusion_services import sd_text_to_image from .service_response import ServiceResponse from .service_toolkit import ServiceToolkit @@ -117,6 +118,7 @@ def get_help() -> None: "openai_image_to_text", "openai_edit_image", "openai_create_image_variation", + "sd_text_to_image", "tripadvisor_search", "tripadvisor_search_location_photos", "tripadvisor_search_location_details", diff --git a/src/agentscope/service/multi_modality/stablediffusion_services.py b/src/agentscope/service/multi_modality/stablediffusion_services.py new file mode 100644 index 000000000..4547aa115 --- /dev/null +++ b/src/agentscope/service/multi_modality/stablediffusion_services.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- +"""Use StableDiffusion-webui API to generate images +""" +import os +from typing import Optional + +from ...models import StableDiffusionImageSynthesisWrapper + +from ...manager import FileManager +from ..service_response import ( + ServiceResponse, + ServiceExecStatus, +) +from ...utils.common import ( + _get_timestamp, + _generate_random_code, +) +from ...constants import _DEFAULT_IMAGE_NAME + + +def sd_text_to_image( + prompt: str, + n_iter: int = 1, + width: int = 1024, + height: int = 1024, + options: dict = None, + baseurl: str = None, + save_dir: Optional[str] = None, +) -> ServiceResponse: + """Generate image(s) based on the given prompt, and return image url(s). + + Args: + prompt (`str`): + The text prompt to generate image. + n (`int`, defaults to `1`): + The number of images to generate. + width (`int`, defaults to `1024`): + Width of the image. + height (`int`, defaults to `1024`): + Height of the image. + options (`dict`, defaults to `None`): + The options to override the sd-webui default settings. + If not specified, will use the default settings. + baseurl (`str`, defaults to `None`): + The base url of the sd-webui. + save_dir (`Optional[str]`, defaults to 'None'): + The directory to save the generated images. If not specified, + will return the web urls. + + Returns: + ServiceResponse: + A dictionary with two variables: `status` and`content`. + If `status` is ServiceExecStatus.SUCCESS, + the `content` is a dict with key 'fig_paths" and + value is a list of the paths to the generated images. + + Example: + + .. code-block:: python + + prompt = "A beautiful sunset in the mountains" + print(sd_text_to_image(prompt, 2)) + + > { + > 'status': 'SUCCESS', + > 'content': {'image_urls': ['IMAGE_URL1', 'IMAGE_URL2']} + > } + + """ + text2img = StableDiffusionImageSynthesisWrapper( + config_name="sd-text-to-image-service", # Just a placeholder + baseurl=baseurl, + ) + try: + kwargs = {"n_iter": n_iter, "width": width, "height": height} + if options: + kwargs["override_settings"] = options + + res = text2img(prompt=prompt, save_local=False, **kwargs) + images = res.image_urls + + # save images to save_dir + if images is not None: + if save_dir: + os.makedirs(save_dir, exist_ok=True) + urls_local = [] + # Obtain the image file names in the url + for image in images: + image_name = _DEFAULT_IMAGE_NAME.format( + _get_timestamp( + "%Y%m%d-%H%M%S", + ), + _generate_random_code(), + ) + image_path = os.path.abspath( + os.path.join(save_dir, image_name), + ) + # Download the image + image.save(image_path) + urls_local.append(image_path) + return ServiceResponse( + ServiceExecStatus.SUCCESS, + {"image_urls": urls_local}, + ) + else: + # Return the default urls + file_manager = FileManager.get_instance() + urls = [file_manager.save_image(_) for _ in images] + return ServiceResponse( + ServiceExecStatus.SUCCESS, + {"image_urls": urls}, + ) + else: + return ServiceResponse( + ServiceExecStatus.ERROR, + "Error: Failed to generate images", + ) + except Exception as e: + return ServiceResponse( + ServiceExecStatus.ERROR, + str(e), + )