diff --git a/.gitignore b/.gitignore
index 90b9508..2722475 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,8 +1,10 @@
+__pycache__
model/__pycache__
model/DensePose/__pycache__
model/SCHP/__pycache__
+model/SCHP/*/__pycache__
resource/demo/output
resource/demo/example/.DS_Store
-model/SCHP/*/__pycache__
densepose_
-.vscode
\ No newline at end of file
+.vscode
+playground.py
diff --git a/app.py b/app.py
index 3700a75..760a72a 100644
--- a/app.py
+++ b/app.py
@@ -126,6 +126,7 @@ def image_grid(imgs, rows, cols):
device='cuda',
)
+
def submit_function(
person_image,
cloth_image,
@@ -170,19 +171,19 @@ def submit_function(
mask = mask_processor.blur(mask, blur_factor=9)
# Inference
- try:
- result_image = pipeline(
- image=person_image,
- condition_image=cloth_image,
- mask=mask,
- num_inference_steps=num_inference_steps,
- guidance_scale=guidance_scale,
- generator=generator
- )[0]
- except Exception as e:
- raise gr.Error(
- "An error occurred. Please try again later: {}".format(e)
- )
+ # try:
+ result_image = pipeline(
+ image=person_image,
+ condition_image=cloth_image,
+ mask=mask,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ generator=generator
+ )[0]
+ # except Exception as e:
+ # raise gr.Error(
+ # "An error occurred. Please try again later: {}".format(e)
+ # )
# Post-process
masked_person = vis_mask(person_image, mask)
@@ -263,6 +264,11 @@ def app_gradio():
)
+ submit = gr.Button("Submit")
+ gr.Markdown(
+ '
!!! Click only Once, Wait for Delay !!!'
+ )
+
gr.Markdown(
'Advanced options can adjust details:
1. `Inference Step` may enhance details;
2. `CFG` is highly correlated with saturation;
3. `Random seed` may improve pseudo-shadow.'
)
@@ -283,7 +289,7 @@ def app_gradio():
choices=["result only", "input & result", "input & mask & result"],
value="input & mask & result",
)
- submit = gr.Button("Submit")
+
with gr.Column(scale=2, min_width=500):
result_image = gr.Image(interactive=False, label="Result")
with gr.Row():
diff --git a/model/pipeline.py b/model/pipeline.py
index 9bbeaa1..b717dcd 100644
--- a/model/pipeline.py
+++ b/model/pipeline.py
@@ -1,26 +1,23 @@
import inspect
import os
from typing import Union
-import PIL
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
-from diffusers.utils.torch_utils import randn_tensor
+import PIL
+import numpy as np
import torch
import tqdm
+from accelerate import load_checkpoint_in_model
+from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+ StableDiffusionSafetyChecker
+from diffusers.utils.torch_utils import randn_tensor
+from huggingface_hub import snapshot_download
+from transformers import CLIPImageProcessor
from model.attn_processor import SkipAttnProcessor
from model.utils import get_trainable_module, init_adapter
-
-from accelerate import load_checkpoint_in_model
-from huggingface_hub import hf_hub_download, snapshot_download
-from utils import (
- compute_vae_encodings,
- numpy_to_pil,
- prepare_image,
- prepare_mask_image,
- resize_and_crop,
- resize_and_padding,
-)
+from utils import (compute_vae_encodings, numpy_to_pil, prepare_image,
+ prepare_mask_image, resize_and_crop, resize_and_padding)
class CatVTONPipeline:
@@ -39,6 +36,8 @@ def __init__(
self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler")
self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype)
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor")
+ self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder="safety_checker").to(device, dtype=weight_dtype)
self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(device, dtype=weight_dtype)
init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor) # Skip Cross-Attention
self.attn_modules = get_trainable_module(self.unet, "attention")
@@ -66,7 +65,16 @@ def auto_attn_ckpt_load(self, attn_ckpt, version):
print(f"Downloaded {attn_ckpt} to {repo_path}")
load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention'))
-
+ def run_safety_checker(self, image):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(self.device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)
+ )
+ return image, has_nsfw_concept
+
def check_inputs(self, image, condition_image, mask, width, height):
if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(mask, torch.Tensor):
return image, condition_image, mask
@@ -190,4 +198,14 @@ def __call__(
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
image = numpy_to_pil(image)
+
+ # Safety Check
+ current_script_directory = os.path.dirname(os.path.realpath(__file__))
+ nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg')
+ nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)
+ image_np = np.array(image)
+ _, has_nsfw_concept = self.run_safety_checker(image=image_np)
+ for i, not_safe in enumerate(has_nsfw_concept):
+ if not_safe:
+ image[i] = nsfw_image
return image
diff --git a/resource/demo/example/person/women/049713_0.jpg b/resource/demo/example/person/women/049713_0.jpg
new file mode 100644
index 0000000..71732c4
Binary files /dev/null and b/resource/demo/example/person/women/049713_0.jpg differ
diff --git a/resource/demo/example/person/women/Eva_0.png b/resource/demo/example/person/women/Eva_0.png
deleted file mode 100644
index 8cc93d7..0000000
Binary files a/resource/demo/example/person/women/Eva_0.png and /dev/null differ
diff --git a/resource/demo/example/person/women/Yaqi_0.png b/resource/demo/example/person/women/Yaqi_0.png
deleted file mode 100644
index 66b7e23..0000000
Binary files a/resource/demo/example/person/women/Yaqi_0.png and /dev/null differ
diff --git a/resource/demo/example/person/women/model_8.png b/resource/demo/example/person/women/model_8.png
new file mode 100644
index 0000000..a11a417
Binary files /dev/null and b/resource/demo/example/person/women/model_8.png differ
diff --git a/resource/img/NSFW.jpg b/resource/img/NSFW.jpg
new file mode 100644
index 0000000..48d52c2
Binary files /dev/null and b/resource/img/NSFW.jpg differ