Merge pull request #5 from mkshing/v0.2.0
v0.2.0
mkshing authored Apr 12, 2023
2 parents 7f47d31 + c645374 commit 9199552
Showing 20 changed files with 3,727 additions and 386 deletions.
102 changes: 90 additions & 12 deletions README.md
@@ -12,10 +12,17 @@ My summary tweet is found [here](https://twitter.com/mk1stats/status/16428655051
left: LoRA, right: SVDiff


- Compared with LoRA, the number of trainable parameters is 0.6 M less parameters and the file size is only <1MB (LoRA: 3.1MB)!!
+ Compared with LoRA, there are 0.5M fewer trainable parameters and the file size is only 1.2MB (LoRA: 3.1MB)!

![kumamon](assets/kumamon.png)

## Updates
### 2023.4.11
- Released v0.2.0 (see [here](https://github.com/mkshing/svdiff-pytorch/releases/tag/v0.2.0) for details)
- Added [Single Image Editing](#single-image-editing)
![chair-result](assets/chair-result.png)
<br>"photo of a ~~pink~~ blue chair with black legs"

## Installation
```
$ pip install svdiff-pytorch
@@ -26,9 +33,10 @@ $ git clone https://github.com/mkshing/svdiff-pytorch
$ pip install -r requirements.txt
```

- ## Training
- The following example script is for "Single-Subject Generation", which is a domain-tuning on a single object or concept (using 3-5 images). (See Section 4.1)
+ ## Single-Subject Generation
+ "Single-Subject Generation" is domain tuning on a single object or concept, using 3-5 images (see Section 4.1 of the paper).

### Training
According to the paper, the learning rate for SVDiff's spectral shifts needs to be about 1,000 times larger than the rate used for full fine-tuning: hence `--learning_rate=1e-3` below versus a typical 1e-6 (the separate `--learning_rate_1d` stays at 1e-6).

```bash
@@ -48,29 +56,32 @@ accelerate launch train_svdiff.py \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
- --learning_rate=5e-3 \
+ --learning_rate=1e-3 \
--learning_rate_1d=1e-6 \
--train_text_encoder \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
- --max_train_steps=800
+ --max_train_steps=500
```
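
After training, `$OUTPUT_DIR` should contain the learned shifts as `spectral_shifts.safetensors` (plus `spectral_shifts_te.safetensors` for the text encoder when `--train_text_encoder` is used); this directory is what the inference snippets below refer to as `ckpt-dir-path`.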


- ## Inference
+ ### Inference

```python
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch

- from svdiff_pytorch import load_unet_for_svdiff
+ from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff

pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
- spectral_shifts_ckpt = "spectral_shifts.safetensors-path"
- unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt, subfolder="unet")
+ spectral_shifts_ckpt_dir = "ckpt-dir-path"
+ unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
+ text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
# load pipe
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    unet=unet,
    text_encoder=text_encoder,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
@@ -82,14 +93,14 @@ You can use the following CLI too! Once it's done, you will see `grid.png` for t
```bash
python inference.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
- --spectral_shifts_ckpt="spectral_shifts.safetensors-path" \
+ --spectral_shifts_ckpt="ckpt-dir-path" \
--prompt="A picture of a sks dog in a bucket" \
--scheduler_type="dpm_solver++" \
--num_inference_steps=25 \
--num_images_per_prompt=2
```

- ## Gradio
+ ### Gradio
You can also try SVDiff-pytorch in a UI with [gradio](https://gradio.app/). This demo supports both training and inference!

[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
@@ -103,7 +114,73 @@ $ export HF_TOKEN="YOUR_HF_TOKEN_HERE"
$ python app.py
```

## Single Image Editing
### Training
In Single Image Editing, your instance prompt should be just the description of your input image **without the identifier**.

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="dir-path-to-input-image"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_svdiff.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="photo of a pink chair with black legs" \
--class_prompt="photo of a chair" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-3 \
--learning_rate_1d=1e-6 \
--train_text_encoder \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=500
```
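
Apart from the prompts, the arguments match single-subject training; the resulting `$OUTPUT_DIR` is again the `ckpt-dir-path` that the inference snippet below expects.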

### Inference

```python
import torch
from PIL import Image
from diffusers import DDIMScheduler
from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, StableDiffusionPipelineWithDDIMInversion

pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
spectral_shifts_ckpt_dir = "ckpt-dir-path"
image = "path-to-image"
source_prompt = "prompt-for-image"
target_prompt = "prompt-you-want-to-generate"

unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
# load pipe
pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
    pretrained_model_name_or_path,
    unet=unet,
    text_encoder=text_encoder,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# (optional) DDIM inversion
# if you skip this step, pass latents=None below
image = Image.open(image).convert("RGB").resize((512, 512))
# SVDiff uses guidance_scale=1.0 for DDIM inversion
inv_latents = pipe.invert(source_prompt, image=image, guidance_scale=1.0).latents

image = pipe(target_prompt, latents=inv_latents).images[0]
```
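
DDIM inversion runs the sampler in reverse to recover noise latents that reconstruct the input image, so sampling from `inv_latents` with the target prompt preserves the layout of the original photo while applying the edit (here, pink to blue); passing `latents=None` instead gives an ordinary text-to-image generation from the target prompt.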


## Additional Features

### Spectral Shift Scaling

![scale](assets/scale.png)
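
A minimal sketch of how this can be applied in Python, mirroring the `set_scale` hook used in the `inference.py` diff further down (`pipe` is a loaded SVDiff pipeline as in the examples above, and `0.8` is just an illustrative value):

```python
# rescale the learned spectral shifts; scale=1.0 reproduces the trained model
for module in pipe.unet.modules():
    if hasattr(module, "set_scale"):
        module.set_scale(scale=0.8)  # <1.0 weakens the learned concept, >1.0 strengthens it
```

The CLI exposes the same knob through `--spectral_shifts_scale`.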
@@ -165,6 +242,7 @@ And, add `--enable_tome_merging` to your training arguments!
- [x] Training
- [x] Inference
- [x] Scaling spectral shifts
- [x] Support Single Image Editing
- [ ] Support multiple spectral shifts (Section 3.2)
- [ ] Cut-Mix-Unmix (Section 3.3)
- [ ] SVDiff + LoRA
Binary file added assets/chair-result.png
47 changes: 45 additions & 2 deletions inference.py
@@ -1,10 +1,13 @@
import argparse
import os
from tqdm import tqdm
import random
import torch
import huggingface_hub
from transformers import CLIPTextModel
from diffusers import StableDiffusionPipeline
from diffusers.utils import is_xformers_available
- from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
+ from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid


def parse_args():
@@ -14,7 +17,7 @@ def parse_args():
    # diffusers config
    parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="number of sampling steps")
-   parser.add_argument("--guidance_scale", type=float, default=1.0, help="unconditional guidance scale")
+   parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
    parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
@@ -27,6 +30,33 @@
    return args


def load_text_encoder(pretrained_model_name_or_path, spectral_shifts_ckpt, device, fp16=False):
    if os.path.isdir(spectral_shifts_ckpt):
        spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
    elif not os.path.exists(spectral_shifts_ckpt):
        # download from hub
        try:
            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors")
        except huggingface_hub.utils.EntryNotFoundError:
            # the repo has no text-encoder shifts: fall back to the plain text encoder
            return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
    if not os.path.exists(spectral_shifts_ckpt):
        return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
    text_encoder = load_text_encoder_for_svdiff(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        spectral_shifts_ckpt=spectral_shifts_ckpt,
        subfolder="text_encoder",
    )
    # first perform svd and cache
    for module in text_encoder.modules():
        if hasattr(module, "perform_svd"):
            module.perform_svd()
    if fp16:
        text_encoder = text_encoder.to(device, dtype=torch.float16)
    return text_encoder



def main():
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -40,10 +70,18 @@ def main():
            module.perform_svd()
    if args.fp16:
        unet = unet.to(device, dtype=torch.float16)
    text_encoder = load_text_encoder(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        spectral_shifts_ckpt=args.spectral_shifts_ckpt,
        fp16=args.fp16,
        device=device
    )

    # load pipe
    pipe = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        text_encoder=text_encoder,
        requires_safety_checker=False,
        safety_checker=None,
        feature_extractor=None,
@@ -67,6 +105,11 @@
    for module in pipe.unet.modules():
        if hasattr(module, "set_scale"):
            module.set_scale(scale=args.spectral_shifts_scale)
    if not isinstance(pipe.text_encoder, CLIPTextModel):
        for module in pipe.text_encoder.modules():
            if hasattr(module, "set_scale"):
                module.set_scale(scale=args.spectral_shifts_scale)

    print(f"Set spectral_shifts_scale to {args.spectral_shifts_scale}!")

    if args.seed == "random_seed":
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ diffusers==0.14.0
accelerate
torchvision
safetensors
- transformers>=4.25.1
+ transformers>=4.25.1, <=4.27.3
ftfy
tensorboard
Jinja2
