Commit bdc7f39: add SDXL colab; update readme; move to diffusers, clean and add Colab link. Co-authored-by: Thomas Capelle <[email protected]>
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/wandb/examples/blob/master/colabs/diffusers/sdxl-text-to-image.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stable-Diffusion XL 1.0 using 🤗 Diffusers\n", | ||
"\n", | ||
"This notebook demonstrates the following:\n", | ||
"- Performing text-conditional image-generations using [🤗 Diffusers](https://huggingface.co/docs/diffusers).\n", | ||
"- Using the Stable Diffusion XL Refiner pipeline to further refine the outputs of the base model.\n", | ||
"- Manage image generation experiments using [Weights & Biases](http://wandb.ai/geekyrakshit).\n", | ||
"- Log the prompts and generated images to [Weigts & Biases](http://wandb.ai/geekyrakshit) for visalization." | ||
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installing the Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -qq diffusers[\"torch\"] transformers wandb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import wandb\n",
"from diffusers import DiffusionPipeline, EulerDiscreteScheduler"
]
},
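{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells below place the generators and pipelines on `\"cuda\"`, so it is worth sanity-checking that a GPU runtime is actually available before proceeding. A minimal check:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify a CUDA GPU is visible, since the generators and pipelines\n",
"# below are placed on \"cuda\"\n",
"assert torch.cuda.is_available(), \"No CUDA GPU found; select a GPU runtime\"\n",
"print(torch.cuda.get_device_name(0))"
]
},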
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiment Management using Weights & Biases\n",
"\n",
"Managing our image generation experiments is crucial for the sake of reproducibility. Hence we sync all the configs of our experiments with our Weights & Biases run. This stores all the configs of the experiments, right from the prompts to the refinement technque and the configuration of the scheduler." | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# initialize a wandb run\n",
"wandb.init(project=\"stable-diffusion-xl\", entity=\"geekyrakshit\", job_type=\"text-to-image\", save_code=True)\n",
"\n",
"# define experiment configs\n",
"config = wandb.config\n",
"config.stable_diffusion_checkpoint = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"config.refiner_checkpoint = \"stabilityai/stable-diffusion-xl-refiner-1.0\"\n",
"config.offload_to_cpu = False\n",
"config.compile_model = False\n",
"config.prompt_1 = \"a photograph of an evil and vile looking demon in Bengali attire eating fish. The demon has large and bloody teeth. The demon is sitting on the branches of a giant Banyan tree, dimly lit, bluish and dark color palette, realistic, 8k\"\n",
"config.prompt_2 = \"\" # Leave blank if you want both text encoders to use the same prompt\n",
"config.negative_prompt_1 = \"static, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\"\n",
"config.negative_prompt_2 = \"static, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\"\n",
"config.seed = None\n",
"config.use_ensemble_of_experts = True\n",
"config.num_inference_steps = 100\n",
"config.num_refinement_steps = 150\n",
"config.high_noise_fraction = 0.8 # Set explicitly only if config.use_ensemble_of_experts is True\n",
"config.scheduler_kwargs = {\n",
"    \"beta_end\": 0.012,\n",
"    \"beta_schedule\": \"scaled_linear\", # one of [\"linear\", \"scaled_linear\"]\n",
"    \"beta_start\": 0.00085,\n",
"    \"interpolation_type\": \"linear\", # one of [\"linear\", \"log_linear\"]\n",
"    \"num_train_timesteps\": 1000,\n",
"    \"prediction_type\": \"epsilon\", # one of [\"epsilon\", \"sample\", \"v_prediction\"]\n",
"    \"steps_offset\": 1,\n",
"    \"timestep_spacing\": \"leading\", # one of [\"linspace\", \"leading\"]\n",
"    \"trained_betas\": None,\n",
"    \"use_karras_sigmas\": False,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can make the experiment deterministic based on the seed specified in the experiment configs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if config.seed is not None:\n",
"    generator = [torch.Generator(device=\"cuda\").manual_seed(config.seed)]\n",
"else:\n",
"    generator = [torch.Generator(device=\"cuda\")]"
]
},
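{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that when `config.seed` is `None`, the run cannot be reproduced from the logged configs alone. A minimal sketch for closing that gap is to seed the generator from system entropy with `torch.Generator.seed()`, which returns the seed it used, and write that seed back to `wandb.config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If no seed was specified, draw one from system entropy and record it\n",
"# in the run config so the generation stays reproducible\n",
"if config.seed is None:\n",
"    config.update({\"seed\": generator[0].seed()}, allow_val_change=True)"
]
},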
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating the Diffusion Pipelines\n",
"\n",
"For performing text-conditional image generation, we use the `diffusers` library to define the diffusion pipelines corresponding to the base SDXL model and the SDXL refinement model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define base model\n",
"pipe = DiffusionPipeline.from_pretrained(\n",
"    config.stable_diffusion_checkpoint,\n",
"    torch_dtype=torch.float16,\n",
"    variant=\"fp16\",\n",
"    use_safetensors=True,\n",
"    scheduler=EulerDiscreteScheduler(**config.scheduler_kwargs),\n",
")\n",
"\n",
"# Offload to CPU in case of OOM\n",
"if config.offload_to_cpu:\n",
"    pipe.enable_model_cpu_offload()\n",
"else:\n",
"    pipe.to(\"cuda\")\n",
"\n",
"# Compile model using `torch.compile`, this might give a significant speedup\n", | ||
"if config.compile_model:\n", | ||
" pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define base model\n", | ||
"refiner = DiffusionPipeline.from_pretrained(\n", | ||
" config.refiner_checkpoint,\n", | ||
" text_encoder_2=pipe.text_encoder_2,\n", | ||
" vae=pipe.vae,\n", | ||
" torch_dtype=torch.float16,\n", | ||
" use_safetensors=True,\n", | ||
" variant=\"fp16\",\n", | ||
" scheduler=EulerDiscreteScheduler(**config.scheduler_kwargs),\n", | ||
")\n", | ||
"\n", | ||
"# Offload to CPU in case of OOM\n", | ||
"if config.offload_to_cpu:\n", | ||
" refiner.enable_model_cpu_offload()\n", | ||
"else:\n", | ||
" refiner.to(\"cuda\")\n", | ||
"\n", | ||
"# Compile model using `torch.compile`, this might give a significant speedup\n", | ||
"if config.compile_model:\n", | ||
" refiner.unet = torch.compile(refiner.unet, mode=\"reduce-overhead\", fullgraph=True)" | ||
] | ||
}, | ||
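{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you run out of GPU memory even with `config.offload_to_cpu` enabled, `diffusers` also offers attention and VAE slicing, which compute in smaller chunks to lower peak memory at some cost in speed. A minimal sketch (the `enable_memory_savers` flag is illustrative and not part of the configs above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative flag (not part of the experiment configs): flip to True\n",
"# to trade some speed for a lower peak GPU memory footprint\n",
"enable_memory_savers = False\n",
"\n",
"if enable_memory_savers:\n",
"    # Compute attention and VAE decoding in slices instead of all at once\n",
"    pipe.enable_attention_slicing()\n",
"    pipe.enable_vae_slicing()\n",
"    refiner.enable_attention_slicing()\n",
"    refiner.enable_vae_slicing()"
]
},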
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a utility function to postprocess the latents obtained from the base model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def postprocess_latent(latent):\n", | ||
" vae_output = pipe.vae.decode(\n", | ||
" latent.images / pipe.vae.config.scaling_factor, return_dict=False\n", | ||
" )[0].detach()\n", | ||
" return pipe.image_processor.postprocess(vae_output, output_type=\"pil\")[0]" | ||
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text-to-Image Generation\n",
"\n",
"Now, we pass the prompts and the negative prompts to the base model and then pass the output to the refiner for firther refinement. In order to know more about the different refinement techniques that can be used with SDXL, you can check [`diffusers` docs](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)." | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if config.use_ensemble_of_experts:\n",
"    latent = pipe(\n",
"        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n",
"        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n",
"        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n",
"        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n",
"        output_type=\"latent\",\n",
"        num_inference_steps=config.num_inference_steps,\n",
"        denoising_end=config.high_noise_fraction,\n",
"        generator=generator,\n",
"    )\n",
"else:\n",
"    latent = pipe(\n",
"        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n",
"        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n",
"        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n",
"        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n",
"        output_type=\"latent\",\n",
"        num_inference_steps=config.num_inference_steps,\n",
"        generator=generator,\n",
"    )\n",
"unrefined_image = postprocess_latent(latent)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if config.use_ensemble_of_experts:\n",
"    refined_image = refiner(\n",
"        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n",
"        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n",
"        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n",
"        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n",
"        image=latent.images,\n",
"        num_inference_steps=config.num_refinement_steps,\n",
"        denoising_start=config.high_noise_fraction,\n",
"        generator=generator,\n",
"    ).images[0]\n",
"else:\n",
"    refined_image = refiner(\n",
"        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n",
"        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n",
"        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n",
"        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n",
"        image=latent.images[0][None, :],\n",
"        generator=generator,\n",
"    ).images[0]"
]
},
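{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point both `unrefined_image` and `refined_image` are PIL images, so we can preview them inline before logging them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"\n",
"# PIL images render directly in notebooks\n",
"display(unrefined_image)\n",
"display(refined_image)"
]
},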
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logging the Images to Weights & Biases\n",
"\n",
"Now, we log the images to Weights & Biases. This enables us to:\n",
"\n",
"- Visualize our generations\n",
"- Examine the generated images across different images\n", | ||
"- Ensure reproducibility of the experiments" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a [wandb table](https://docs.wandb.ai/guides/tables)\n", | ||
"table = wandb.Table(columns=[\n", | ||
" \"Prompt-1\",\n", | ||
" \"Prompt-2\",\n", | ||
" \"Negative-Prompt-1\",\n", | ||
" \"Negative-Prompt-2\",\n", | ||
" \"Unrefined-Image\",\n", | ||
" \"Refined-Image\",\n", | ||
" \"Use-Ensemble-of-Experts\",\n", | ||
"])\n", | ||
"\n", | ||
"unrefined_image = wandb.Image(unrefined_image)\n", | ||
"refined_image = wandb.Image(refined_image)\n", | ||
"\n", | ||
"# Add the images to the table\n", | ||
"table.add_data(\n", | ||
" config.prompt_1,\n", | ||
" config.prompt_2,\n", | ||
" config.negative_prompt_1,\n", | ||
" config.negative_prompt_2,\n", | ||
" unrefined_image,\n", | ||
" refined_image,\n", | ||
" config.use_ensemble_of_experts,\n", | ||
")\n", | ||
"\n", | ||
"# Log the images and table to wandb\n", | ||
"wandb.log({\n", | ||
" \"Unrefined-Image\": unrefined_image,\n", | ||
" \"Refined-Image\": refined_image,\n", | ||
" \"Text-to-Image\": table\n", | ||
"})\n", | ||
"\n", | ||
"# finish the experiment\n", | ||
"wandb.finish()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Here's how you can examine your generations across multiple experiments 👇\n", | ||
"\n", | ||
"![](https://i.imgur.com/zNynGye.png)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Here's how you can manage your prompts and your generations across experiments 👇\n", | ||
"\n", | ||
"![](https://i.imgur.com/JVEXkx0.png)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"include_colab_link": true, | ||
"provenance": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |