Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added parameters #4

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed Desktop/.DS_Store
Binary file not shown.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ agent = Agent(
)

agent.run(
"Save a picture of a watercolor painting of a dog riding a skateboard to the desktop."
"Save a picture of a watercolor painting of a dog riding a skateboard to the assets directory."
)
```
Output image:

![Watercolor painting of a dog riding a skateboard](assets/dog_skateboard_watercolor.jpeg)

## Installation

Expand Down
Binary file added assets/dog_skateboard_watercolor.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/drivers/example_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
)

agent.run(
"Save a picture of a watercolor painting of a dog riding a skateboard to the desktop."
"Save a picture of a watercolor painting of a dog riding a skateboard to the assets directory."
)
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
from __future__ import annotations

import os
import time

from typing import Any
from urllib.parse import urljoin
from griptape.artifacts import ImageArtifact
import requests

from attrs import define, field, Factory
import requests
from attrs import Factory, define, field

from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseImageGenerationDriver


@define
class BlackForestImageGenerationDriver(BaseImageGenerationDriver):
"""Driver for the Black Forest Labs image generation API.

Attributes:
model: Black Forest model, for example 'flux-pro-1.1', 'flux-pro', 'flux-dev', 'flux-pro-1.1-ultra'.
width: Width of the generated image. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Integer range from 256 to 1440. Must be a multiple of 32. Default is 1024.
height: Height of the generated image. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Integer range from 256 to 1440. Must be a multiple of 32. Default is 1024.
aspect_ratio: Aspect ratio of the generated image between 21:9 and 9:21. Valid for 'flux-pro-1.1-ultra' model only. Default is 16:9.
prompt_upsampling: Optional flag to perform upsampling on the prompt. Valid for `flux-pro-1.1', 'flux-pro', 'flux-dev' models only. If active, automatically modifies the prompt for more creative generation.
safety_tolerance: Optional tolerance level for input and output moderation. Valid for 'flux-pro-1.1', 'flux-pro', 'flux-dev' models only. Between 0 and 6, 0 being most strict, 6 being least strict.
seed: Optional seed for reproducing results. Default is None.
steps: Optional number of steps for the image generation process. Valid for 'flux-dev' model only. Default is None.


"""

base_url: str = field(
default="https://api.bfl.ml",
kw_only=True,
Expand All @@ -26,24 +42,53 @@ class BlackForestImageGenerationDriver(BaseImageGenerationDriver):
width: int = field(default=1024, kw_only=True)
height: int = field(default=768, kw_only=True)
sleep_interval: float = field(default=0.5, kw_only=True)
safety_tolerance: int | None = field(default=None, kw_only=True)
aspect_ratio: str = field(default=None, kw_only=True)
seed: int | None = field(default=None, kw_only=True)
prompt_upsampling: bool = field(default=False, kw_only=True)
steps: int | None = field(default=None, kw_only=True)

def try_text_to_image(
self, prompts: list[str], negative_prompts: list[str] | None = None
) -> ImageArtifact:
prompt = " ".join(prompts)

if self.width % 32 != 0 or self.height % 32 != 0:
msg = "width and height must be multiples of 32"
raise ValueError(msg)
if self.width < 256 or self.width > 1440:
raise ValueError("width must be between 256 and 1440")
if self.safety_tolerance and (
self.safety_tolerance < 0 or self.safety_tolerance > 6
):
raise ValueError("safety_tolerance must be between 0 and 6")

data: dict[str, Any] = {
"prompt": prompt,
}

if self.seed:
data["seed"] = self.seed

if self.model == "flux-pro-1.1-ultra" and self.aspect_ratio:
data["aspect_ratio"] = self.aspect_ratio

if self.model in ["flux-pro-1.1", "flux-pro", "flux-dev"]:
data["width"] = self.width
data["height"] = self.height
if self.safety_tolerance:
data["safety_tolerance"] = self.safety_tolerance
if self.prompt_upsampling:
data["prompt_upsampling"] = self.prompt_upsampling

request = requests.post(
urljoin(self.base_url, f"v1/{self.model}"),
headers={
"accept": "application/json",
"x-key": self.api_key,
"Content-Type": "application/json",
},
json={
"prompt": prompt,
"width": self.width,
"height": self.height,
},
json=data,
).json()

request_id = request["id"]
Expand Down