Skip to content

Commit

Permalink
feat: add diffusion inpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
fboulnois committed Nov 11, 2022
1 parent 1a8a2d4 commit ec4308e
Showing 1 changed file with 31 additions and 6 deletions.
37 changes: 31 additions & 6 deletions docker-entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import torch
from PIL import Image
from torch import autocast
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
)


def cuda_device():
Expand All @@ -24,16 +28,21 @@ def skip_safety_checker(images, *args, **kwargs):
return images, False


def stable_diffusion_pipeline(model, image, half, skip, do_slice, token):
def stable_diffusion_pipeline(model, image, mask, half, skip, do_slice, token):
if token is None:
with open("token.txt") as f:
token = f.read().replace("\n", "")

diffuser = StableDiffusionPipeline

if image is not None:
diffuser = StableDiffusionImg2ImgPipeline
image = load_image(image)

if mask is not None:
diffuser = StableDiffusionInpaintPipeline
mask = load_image(mask)

dtype, rev = (torch.float16, "fp16") if half else (torch.float32, "main")

print("load pipeline start:", iso_date_time())
Expand All @@ -50,14 +59,15 @@ def stable_diffusion_pipeline(model, image, half, skip, do_slice, token):

print("loaded models after:", iso_date_time())

return pipeline, image
return pipeline, image, mask


def stable_diffusion_inference(
pipeline,
prompt,
neg_prompt,
image,
mask,
samples,
iters,
height,
Expand All @@ -79,6 +89,8 @@ def stable_diffusion_inference(
prompt,
negative_prompt=neg_prompt,
init_image=image,
image=image,
mask_image=mask,
height=height,
width=width,
num_images_per_prompt=samples,
Expand Down Expand Up @@ -162,7 +174,13 @@ def main():
"--image",
type=str,
nargs="?",
help="The input filename to use for image-to-image diffusion",
help="The input image to use for image-to-image diffusion",
)
parser.add_argument(
"--mask",
type=str,
nargs="?",
help="The input mask to use for diffusion inpainting",
)
parser.add_argument(
"--model",
Expand Down Expand Up @@ -200,15 +218,22 @@ def main():
if args.prompt0 is not None:
args.prompt = args.prompt0

pipeline, image = stable_diffusion_pipeline(
args.model, args.image, args.half, args.skip, args.attention_slicing, args.token
pipeline, image, mask = stable_diffusion_pipeline(
args.model,
args.image,
args.mask,
args.half,
args.skip,
args.attention_slicing,
args.token,
)

stable_diffusion_inference(
pipeline,
args.prompt,
args.negative_prompt,
image,
mask,
args.n_samples,
args.n_iter,
args.H,
Expand Down

0 comments on commit ec4308e

Please sign in to comment.