diff --git a/griptape/black_forest/drivers/black_forest_image_generation_driver.py b/griptape/black_forest/drivers/black_forest_image_generation_driver.py index 567f04b..444de92 100644 --- a/griptape/black_forest/drivers/black_forest_image_generation_driver.py +++ b/griptape/black_forest/drivers/black_forest_image_generation_driver.py @@ -12,6 +12,40 @@ from griptape.drivers import BaseImageGenerationDriver +def steps_validator(instance, attribute, value): + if value and (value < 1 or value > 50): + raise ValueError("steps must be between 1 and 50") + + +def size_validator(instance, attribute, value): + if value and value % 32 != 0: + raise ValueError(f"{attribute} must be a multiple of 32") + if value and value < 256 or value > 1440: + raise ValueError(f"{attribute} must be between 256 and 1440") + + +def safety_validator(instance, attribute, value): + if value and (value < 0 or value > 6): + raise ValueError("safety_tolerance must be between 0 and 6") + + +def aspect_ratio_validator(instance, attribute, value): + if value: + width, height = value.split(":") + if width < 9 or width > 21 or height < 9 or height > 21: + raise ValueError("aspect_ratio must be between 9:21 and 21:9") + + +def guidance_validator(instance, attribute, value): + if value and (value < 1.5 or value > 5): + raise ValueError("guidance must be between 1.5 and 5") + + +def interval_validator(instance, attribute, value): + if value and (value < 1 or value > 4): + raise ValueError("interval must be between 1 and 4") + + @define class BlackForestImageGenerationDriver(BaseImageGenerationDriver): """Driver for the Black Forest Labs image generation API. @@ -24,9 +58,10 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver): prompt_upsampling: Optional flag to perform upsampling on the prompt. Valid for `flux-pro-1.1', 'flux-pro', 'flux-dev' models only. If active, automatically modifies the prompt for more creative generation. safety_tolerance: Optional tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. seed: Optional seed for reproducing results. Default is None. - steps: Optional number of steps for the image generation process. Valid for 'flux-dev' model only. Default is None. - - + steps: Optional number of steps for the image generation process. Valid for 'flux-dev' and `flux-pro` models only. Can be a value between 1 and 50. Default is None. + guidance: Optional guidance scale for image generation. High guidance scales improve prompt adherence at the cost of reduced realism. Min: 1.5, max: 5. Valid for 'flux-dev' and 'flux-pro' models only. + interval: Optional interval parameter for guidance control. Valid for 'flux-pro' model only. Value is an integer between 1 and 4. Default is None. + raw: Optional flag to generate less processed, more natural-looking images. Valid for 'flux-pro-1.1-ultra' model only. Default is False. """ base_url: str = field( @@ -39,30 +74,31 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver): kw_only=True, metadata={"serializable": False}, ) - width: int = field(default=1024, kw_only=True) - height: int = field(default=768, kw_only=True) + width: int = field(default=1024, kw_only=True, validator=size_validator) + height: int = field(default=768, kw_only=True, validator=size_validator) sleep_interval: float = field(default=0.5, kw_only=True) - safety_tolerance: int | None = field(default=None, kw_only=True) - aspect_ratio: str = field(default=None, kw_only=True) + safety_tolerance: int | None = field( + default=None, kw_only=True, validator=safety_validator + ) + aspect_ratio: str = field( + default=None, kw_only=True, validator=aspect_ratio_validator + ) seed: int | None = field(default=None, kw_only=True) prompt_upsampling: bool = field(default=False, kw_only=True) - steps: int | None = field(default=None, kw_only=True) + steps: int | None = field(default=None, kw_only=True, validator=steps_validator) + guidance: float | None = field( + default=None, kw_only=True, validator=guidance_validator + ) + interval: int | None = field( + default=None, kw_only=True, validator=interval_validator + ) + raw: bool = field(default=False, kw_only=True) def try_text_to_image( self, prompts: list[str], negative_prompts: list[str] | None = None ) -> ImageArtifact: prompt = " ".join(prompts) - if self.width % 32 != 0 or self.height % 32 != 0: - msg = "width and height must be multiples of 32" - raise ValueError(msg) - if self.width < 256 or self.width > 1440: - raise ValueError("width must be between 256 and 1440") - if self.safety_tolerance and ( - self.safety_tolerance < 0 or self.safety_tolerance > 6 - ): - raise ValueError("safety_tolerance must be between 0 and 6") - data: dict[str, Any] = { "prompt": prompt, } @@ -72,9 +108,21 @@ def try_text_to_image( if self.safety_tolerance: data["safety_tolerance"] = self.safety_tolerance - if self.model == "flux-pro-1.1-ultra" and self.aspect_ratio: + if self.aspect_ratio and self.model == "flux-pro-1.1-ultra": data["aspect_ratio"] = self.aspect_ratio + if self.raw and self.model == "flux-pro-1.1-ultra": + data["raw"] = self.raw + + if self.guidance and self.model in ["flux-dev", "flux-pro"]: + data["guidance"] = float(self.guidance) + + if self.steps and self.model in ["flux-dev", "flux-pro"]: + data["steps"] = int(self.steps) + + if self.interval and self.model == "flux-pro": + data["interval"] = int(self.interval) + if self.model in ["flux-pro-1.1", "flux-pro", "flux-dev"]: data["width"] = self.width data["height"] = self.height