diff --git a/.gitignore b/.gitignore index 90b9508..2722475 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ +__pycache__ model/__pycache__ model/DensePose/__pycache__ model/SCHP/__pycache__ +model/SCHP/*/__pycache__ resource/demo/output resource/demo/example/.DS_Store -model/SCHP/*/__pycache__ densepose_ -.vscode \ No newline at end of file +.vscode +playground.py diff --git a/app.py b/app.py index 3700a75..760a72a 100644 --- a/app.py +++ b/app.py @@ -126,6 +126,7 @@ def image_grid(imgs, rows, cols): device='cuda', ) + def submit_function( person_image, cloth_image, @@ -170,19 +171,19 @@ def submit_function( mask = mask_processor.blur(mask, blur_factor=9) # Inference - try: - result_image = pipeline( - image=person_image, - condition_image=cloth_image, - mask=mask, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - generator=generator - )[0] - except Exception as e: - raise gr.Error( - "An error occurred. Please try again later: {}".format(e) - ) + # try: + result_image = pipeline( + image=person_image, + condition_image=cloth_image, + mask=mask, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator + )[0] + # except Exception as e: + # raise gr.Error( + # "An error occurred. Please try again later: {}".format(e) + # ) # Post-process masked_person = vis_mask(person_image, mask) @@ -263,6 +264,11 @@ def app_gradio(): ) + submit = gr.Button("Submit") + gr.Markdown( + '
!!! Click only Once, Wait for Delay !!!
' + ) + gr.Markdown( 'Advanced options can adjust details:
1. `Inference Step` may enhance details;
2. `CFG` is highly correlated with saturation;
3. `Random seed` may improve pseudo-shadow.
' ) @@ -283,7 +289,7 @@ def app_gradio(): choices=["result only", "input & result", "input & mask & result"], value="input & mask & result", ) - submit = gr.Button("Submit") + with gr.Column(scale=2, min_width=500): result_image = gr.Image(interactive=False, label="Result") with gr.Row(): diff --git a/model/pipeline.py b/model/pipeline.py index 9bbeaa1..b717dcd 100644 --- a/model/pipeline.py +++ b/model/pipeline.py @@ -1,26 +1,23 @@ import inspect import os from typing import Union -import PIL -from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler -from diffusers.utils.torch_utils import randn_tensor +import PIL +import numpy as np import torch import tqdm +from accelerate import load_checkpoint_in_model +from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.utils.torch_utils import randn_tensor +from huggingface_hub import snapshot_download +from transformers import CLIPImageProcessor from model.attn_processor import SkipAttnProcessor from model.utils import get_trainable_module, init_adapter - -from accelerate import load_checkpoint_in_model -from huggingface_hub import hf_hub_download, snapshot_download -from utils import ( - compute_vae_encodings, - numpy_to_pil, - prepare_image, - prepare_mask_image, - resize_and_crop, - resize_and_padding, -) +from utils import (compute_vae_encodings, numpy_to_pil, prepare_image, + prepare_mask_image, resize_and_crop, resize_and_padding) class CatVTONPipeline: @@ -39,6 +36,8 @@ def __init__( self.noise_scheduler = DDIMScheduler.from_pretrained(base_ckpt, subfolder="scheduler") self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device, dtype=weight_dtype) + self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder="feature_extractor") + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder="safety_checker").to(device, dtype=weight_dtype) self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder="unet").to(device, dtype=weight_dtype) init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor) # Skip Cross-Attention self.attn_modules = get_trainable_module(self.unet, "attention") @@ -66,7 +65,16 @@ def auto_attn_ckpt_load(self, attn_ckpt, version): print(f"Downloaded {attn_ckpt} to {repo_path}") load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention')) - + def run_safety_checker(self, image): + if self.safety_checker is None: + has_nsfw_concept = None + else: + safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype) + ) + return image, has_nsfw_concept + def check_inputs(self, image, condition_image, mask, width, height): if isinstance(image, torch.Tensor) and isinstance(condition_image, torch.Tensor) and isinstance(mask, torch.Tensor): return image, condition_image, mask @@ -190,4 +198,14 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = numpy_to_pil(image) + + # Safety Check + current_script_directory = os.path.dirname(os.path.realpath(__file__)) + nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg') + nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size) + image_np = np.array(image) + _, has_nsfw_concept = self.run_safety_checker(image=image_np) + for i, not_safe in enumerate(has_nsfw_concept): + if not_safe: + image[i] = nsfw_image return image diff --git a/resource/demo/example/person/women/049713_0.jpg b/resource/demo/example/person/women/049713_0.jpg new file mode 100644 index 0000000..71732c4 Binary files /dev/null and b/resource/demo/example/person/women/049713_0.jpg differ diff --git a/resource/demo/example/person/women/Eva_0.png b/resource/demo/example/person/women/Eva_0.png deleted file mode 100644 index 8cc93d7..0000000 Binary files a/resource/demo/example/person/women/Eva_0.png and /dev/null differ diff --git a/resource/demo/example/person/women/Yaqi_0.png b/resource/demo/example/person/women/Yaqi_0.png deleted file mode 100644 index 66b7e23..0000000 Binary files a/resource/demo/example/person/women/Yaqi_0.png and /dev/null differ diff --git a/resource/demo/example/person/women/model_8.png b/resource/demo/example/person/women/model_8.png new file mode 100644 index 0000000..a11a417 Binary files /dev/null and b/resource/demo/example/person/women/model_8.png differ diff --git a/resource/img/NSFW.jpg b/resource/img/NSFW.jpg new file mode 100644 index 0000000..48d52c2 Binary files /dev/null and b/resource/img/NSFW.jpg differ