diff --git a/Desktop/.DS_Store b/Desktop/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/Desktop/.DS_Store and /dev/null differ diff --git a/README.md b/README.md index 1f5aff3..6538213 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,12 @@ agent = Agent( ) agent.run( - "Save a picture of a watercolor painting of a dog riding a skateboard to the desktop." + "Save a picture of a watercolor painting of a dog riding a skateboard to the assets directory." ) ``` +Output image: + +![Watercolor painting of a dog riding a skateboard](assets/dog_skateboard_watercolor.jpeg) ## Installation diff --git a/assets/dog_skateboard_watercolor.jpeg b/assets/dog_skateboard_watercolor.jpeg new file mode 100644 index 0000000..de79c75 Binary files /dev/null and b/assets/dog_skateboard_watercolor.jpeg differ diff --git a/examples/drivers/example_agent.py b/examples/drivers/example_agent.py index 76ede7f..9234342 100644 --- a/examples/drivers/example_agent.py +++ b/examples/drivers/example_agent.py @@ -20,5 +20,5 @@ ) agent.run( - "Save a picture of a watercolor painting of a dog riding a skateboard to the desktop." + "Save a picture of a watercolor painting of a dog riding a skateboard to the assets directory." ) diff --git a/griptape/black_forest/drivers/black_forest_image_generation_driver.py b/griptape/black_forest/drivers/black_forest_image_generation_driver.py index 0b18b60..9571a17 100644 --- a/griptape/black_forest/drivers/black_forest_image_generation_driver.py +++ b/griptape/black_forest/drivers/black_forest_image_generation_driver.py @@ -1,18 +1,34 @@ from __future__ import annotations + import os import time - +from typing import Any from urllib.parse import urljoin -from griptape.artifacts import ImageArtifact -import requests -from attrs import define, field, Factory +import requests +from attrs import Factory, define, field +from griptape.artifacts import ImageArtifact from griptape.drivers import BaseImageGenerationDriver @define class BlackForestImageGenerationDriver(BaseImageGenerationDriver): + """Driver for the Black Forest Labs image generation API. + + Attributes: + model: Black Forest model, for example 'flux-pro-1.1', 'flux-pro', 'flux-dev', 'flux-pro-1.1-ultra'. + width: Width of the generated image. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Integer range from 256 to 1440. Must be a multiple of 32. Default is 1024. + height: Height of the generated image. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Integer range from 256 to 1440. Must be a multiple of 32. Default is 1024. + aspect_ratio: Aspect ratio of the generated image between 21:9 and 9:21. Valid for 'flux-pro-1.1-ultra' model only. Default is 16:9. + prompt_upsampling: Optional flag to perform upsampling on the prompt. Valid for `flux-pro-1.1', 'flux-pro', 'flux-dev' models only. If active, automatically modifies the prompt for more creative generation. + safety_tolerance: Optional tolerance level for input and output moderation. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Between 0 and 6, 0 being most strict, 6 being least strict. + seed: Optional seed for reproducing results. Default is None. + steps: Optional number of steps for the image generation process. Valid for 'flux-dev' model only. Default is None. + + + """ + base_url: str = field( default="https://api.bfl.ml", kw_only=True, @@ -26,12 +42,45 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver): width: int = field(default=1024, kw_only=True) height: int = field(default=768, kw_only=True) sleep_interval: float = field(default=0.5, kw_only=True) + safety_tolerance: int | None = field(default=None, kw_only=True) + aspect_ratio: str = field(default=None, kw_only=True) + seed: int | None = field(default=None, kw_only=True) + prompt_upsampling: bool = field(default=False, kw_only=True) + steps: int | None = field(default=None, kw_only=True) def try_text_to_image( self, prompts: list[str], negative_prompts: list[str] | None = None ) -> ImageArtifact: prompt = " ".join(prompts) + if self.width % 32 != 0 or self.height % 32 != 0: + msg = "width and height must be multiples of 32" + raise ValueError(msg) + if self.width < 256 or self.width > 1440: + raise ValueError("width must be between 256 and 1440") + if self.safety_tolerance and ( + self.safety_tolerance < 0 or self.safety_tolerance > 6 + ): + raise ValueError("safety_tolerance must be between 0 and 6") + + data: dict[str, Any] = { + "prompt": prompt, + } + + if self.seed: + data["seed"] = self.seed + + if self.model == "flux-pro-1.1-ultra" and self.aspect_ratio: + data["aspect_ratio"] = self.aspect_ratio + + if self.model in ["flux-pro-1.1", "flux-pro", "flux-dev"]: + data["width"] = self.width + data["height"] = self.height + if self.safety_tolerance: + data["safety_tolerance"] = self.safety_tolerance + if self.prompt_upsampling: + data["prompt_upsampling"] = self.prompt_upsampling + request = requests.post( urljoin(self.base_url, f"v1/{self.model}"), headers={ @@ -39,11 +88,7 @@ def try_text_to_image( "x-key": self.api_key, "Content-Type": "application/json", }, - json={ - "prompt": prompt, - "width": self.width, - "height": self.height, - }, + json=data, ).json() request_id = request["id"]