Skip to content

Commit

Permalink
chore: Add SafetyChecker.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengChong committed Jul 31, 2024
1 parent 87f1d98 commit a24746c
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 31 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
__pycache__
model/__pycache__
model/DensePose/__pycache__
model/SCHP/__pycache__
model/SCHP/*/__pycache__
resource/demo/output
resource/demo/example/.DS_Store
model/SCHP/*/__pycache__
densepose_
.vscode
.vscode
playground.py
34 changes: 20 additions & 14 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def image_grid(imgs, rows, cols):
device='cuda',
)


def submit_function(
person_image,
cloth_image,
Expand Down Expand Up @@ -170,19 +171,19 @@ def submit_function(
mask = mask_processor.blur(mask, blur_factor=9)

# Inference
try:
result_image = pipeline(
image=person_image,
condition_image=cloth_image,
mask=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator
)[0]
except Exception as e:
raise gr.Error(
"An error occurred. Please try again later: {}".format(e)
)
# try:
result_image = pipeline(
image=person_image,
condition_image=cloth_image,
mask=mask,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=generator
)[0]
# except Exception as e:
# raise gr.Error(
# "An error occurred. Please try again later: {}".format(e)
# )

# Post-process
masked_person = vis_mask(person_image, mask)
Expand Down Expand Up @@ -263,6 +264,11 @@ def app_gradio():
)


submit = gr.Button("Submit")
gr.Markdown(
'<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
)

gr.Markdown(
'<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
)
Expand All @@ -283,7 +289,7 @@ def app_gradio():
choices=["result only", "input & result", "input & mask & result"],
value="input & mask & result",
)
submit = gr.Button("Submit")

with gr.Column(scale=2, min_width=500):
result_image = gr.Image(interactive=False, label="Result")
with gr.Row():
Expand Down
48 changes: 33 additions & 15 deletions model/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
import inspect
import os
from typing import Union
import PIL
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor

import PIL
import numpy as np
import torch
import tqdm
from accelerate import load_checkpoint_in_model
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import snapshot_download
from transformers import CLIPImageProcessor

from model.attn_processor import SkipAttnProcessor
from model.utils import get_trainable_module, init_adapter

from accelerate import load_checkpoint_in_model
from huggingface_hub import hf_hub_download, snapshot_download
from utils import (
compute_vae_encodings,
numpy_to_pil,
prepare_image,
prepare_mask_image,
resize_and_crop,
resize_and_padding,
)
from utils import (compute_vae_encodings, numpy_to_pil, prepare_image,
prepare_mask_image, resize_and_crop, resize_and_padding)


class CatVTONPipeline:
Expand All @@ -39,6 +36,8 @@ def __init__(

self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype)
self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder="safety_checker").to(device, dtype=weight_dtype)
self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(device, dtype=weight_dtype)
init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor) # Skip Cross-Attention
self.attn_modules = get_trainable_module(self.unet, "attention")
Expand Down Expand Up @@ -66,7 +65,16 @@ def auto_attn_ckpt_load(self, attn_ckpt, version):
print(f"Downloaded {attn_ckpt} to {repo_path}")
load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention'))


def run_safety_checker(self, image):
if self.safety_checker is None:
has_nsfw_concept = None
else:
safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)
)
return image, has_nsfw_concept

def check_inputs(self, image, condition_image, mask, width, height):
if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(mask, torch.Tensor):
return image, condition_image, mask
Expand Down Expand Up @@ -190,4 +198,14 @@ def __call__(
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = numpy_to_pil(image)

# Safety Check
current_script_directory = os.path.dirname(os.path.realpath(__file__))
nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg')
nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)
image_np = np.array(image)
_, has_nsfw_concept = self.run_safety_checker(image=image_np)
for i, not_safe in enumerate(has_nsfw_concept):
if not_safe:
image[i] = nsfw_image
return image
Binary file added resource/demo/example/person/women/049713_0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed resource/demo/example/person/women/Eva_0.png
Binary file not shown.
Binary file removed resource/demo/example/person/women/Yaqi_0.png
Binary file not shown.
Binary file added resource/demo/example/person/women/model_8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added resource/img/NSFW.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a24746c

Please sign in to comment.