feat: refine and upgrade gradio (KwaiVGI#6)
zzzweakman authored Jul 5, 2024
1 parent 6473a3b commit 293cb9e
Showing 11 changed files with 167 additions and 172 deletions.
72 changes: 45 additions & 27 deletions app.py
@@ -4,10 +4,9 @@
The entrance of the Gradio demo
"""

import os
import os.path as osp
import gradio as gr
import tyro
import gradio as gr
import os.path as osp
from src.utils.helper import load_description
from src.gradio_pipeline import GradioPipeline
from src.config.crop_config import CropConfig
@@ -43,18 +42,24 @@ def partial_fields(target_class, kwargs):
[osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True],
]
#################### interface logic ####################

# Define components first
eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio")
lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio")
output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy")
retargeting_input_image = gr.Image(type="numpy")
output_image = gr.Image(type="numpy")
output_image_paste_back = gr.Image(type="numpy")
output_video = gr.Video()
output_video_concat = gr.Video()

with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML(load_description(title_md))
gr.Markdown(load_description("assets/gradio_description_upload.md"))
with gr.Row():
with gr.Accordion(open=True, label="Reference Portrait"):
image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath")
image_input = gr.Image(type="filepath")
with gr.Accordion(open=True, label="Driving Video"):
video_input = gr.Video(label="Please upload the driving video here.")
video_input = gr.Video()
gr.Markdown(load_description("assets/gradio_description_animation.md"))
with gr.Row():
with gr.Accordion(open=True, label="Animation Options"):
@@ -63,16 +68,17 @@ def partial_fields(target_class, kwargs):
flag_remap_input = gr.Checkbox(value=True, label="paste-back")
flag_do_crop_input = gr.Checkbox(value=True, label="do crop")
with gr.Row():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_animation = gr.Button("🚀 Animate", variant="primary")
with gr.Column():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
with gr.Row():
with gr.Column():
with gr.Accordion(open=True, label="The animated video in the original image space"):
output_video = gr.Video(label="The animated video after pasted back.")
output_video.render()
with gr.Column():
with gr.Accordion(open=True, label="The animated video"):
output_video_concat = gr.Video(label="The animated video and driving video.")
with gr.Row():
process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="🧹 Clear")
output_video_concat.render()
with gr.Row():
# Examples
gr.Markdown("## You could choose the examples below ⬇️")
@@ -89,28 +95,36 @@ def partial_fields(target_class, kwargs):
examples_per_page=5
)
gr.Markdown(load_description("assets/gradio_description_retargeting.md"))
with gr.Row():
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Row():
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton(
[
eye_retargeting_slider,
lip_retargeting_slider,
retargeting_input_image,
output_image,
output_image_paste_back
],
value="🧹 Clear"
)
with gr.Row():
with gr.Column():
process_button_close_ratio = gr.Button("🤖 Calculate the eye-close and lip-close ratio")
process_button_retargeting = gr.Button("🚗 Retargeting", variant="primary")
process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="🧹 Clear")
# with gr.Column():
eye_retargeting_slider.render()
lip_retargeting_slider.render()
with gr.Accordion(open=True, label="Retargeting Input"):
retargeting_input_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Eye and lip Retargeting Result"):
with gr.Accordion(open=True, label="Retargeting Result"):
output_image.render()
with gr.Column():
with gr.Accordion(open=True, label="Paste-back Result"):
output_image_paste_back.render()
# binding functions for buttons
process_button_close_ratio.click(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider],
show_progress=True
)
process_button_retargeting.click(
fn=gradio_pipeline.execute_image,
inputs=[eye_retargeting_slider, lip_retargeting_slider],
outputs=output_image,
outputs=[output_image, output_image_paste_back],
show_progress=True
)
process_button_animation.click(
@@ -125,8 +139,12 @@ def partial_fields(target_class, kwargs):
outputs=[output_video, output_video_concat],
show_progress=True
)
process_button_reset.click()
process_button_reset_retargeting
image_input.change(
fn=gradio_pipeline.prepare_retargeting,
inputs=image_input,
outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image]
)

##########################################################

demo.launch(
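Editor's note: the restructured app.py relies on Gradio's render pattern, in which components are constructed before the `gr.Blocks` context (so event bindings can reference them regardless of layout order) and are placed into the layout later via `.render()`. A minimal, self-contained sketch of that pattern; the component names here are illustrative, not from the commit:

```python
import gradio as gr

# Build components up front; nothing is displayed yet.
ratio_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="ratio")
result_image = gr.Image(type="numpy")

with gr.Blocks() as demo:
    with gr.Row():
        ratio_slider.render()   # placed into the layout here, not re-created
    with gr.Row():
        result_image.render()
    # Pre-built and inline components can be mixed freely in event bindings.
    run_button = gr.Button("Run")
    run_button.click(fn=lambda r: None, inputs=ratio_slider, outputs=result_image)

demo.launch()
```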
8 changes: 1 addition & 7 deletions assets/gradio_description_retargeting.md
@@ -1,7 +1 @@
<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please:</span>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">1. Please <strong>first</strong> press the <strong>🤖 Calculate the eye-close and lip-close ratio</strong> button, and wait for the result shown in the sliders.</span>
</div>
<div style="margin-left: 20px;">
<span style="font-size: 1.2em;">2. Please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. Then the result would be shown in the middle block. You can try running it multiple times!</span>
</div>
<span style="font-size: 1.2em;">🔥 To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the <strong>🚗 Retargeting</strong> button. The result would be shown in the middle block. You can try running it multiple times. <strong>😊 Set both ratios to 0.8 to see what's going on!</strong> </span>
2 changes: 1 addition & 1 deletion assets/gradio_title.md
@@ -2,7 +2,7 @@
<div>
<h1>LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control</h1>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;>
<a href=""><img src="https://img.shields.io/badge/arXiv-XXXX.XXXX-red"></a>
<a href="https://arxiv.org/pdf/2407.03168"><img src="https://img.shields.io/badge/arXiv-2407.03168-red"></a>
<a href="https://liveportrait.github.io"><img src="https://img.shields.io/badge/Project_Page-LivePortrait-green" alt="Project Page"></a>
<a href="https://github.com/KwaiVGI/LivePortrait"><img src="https://img.shields.io/badge/Github-Code-blue"></a>
</div>
144 changes: 90 additions & 54 deletions src/gradio_pipeline.py
@@ -3,13 +3,14 @@
"""
Pipeline for gradio
"""

import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .utils.rprint import rlog as log

def update_args(args, user_args):
"""update the args according to user inputs
@@ -26,10 +27,15 @@ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
# self.live_portrait_wrapper = self.live_portrait_wrapper
self.args = args
# for single image retargeting
self.start_prepare = False
self.f_s_user = None
self.x_c_s_info_user = None
self.x_s_user = None
self.source_lmk_user = None
self.mask_ori = None
self.img_rgb = None
self.crop_M_c2o = None


def execute_video(
self,
@@ -41,64 +47,94 @@ def execute_video(
):
""" for video driven potrait animation
"""
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
return video_path, video_path_concat
if input_image_path is not None and input_video_path is not None:
args_user = {
'source_image': input_image_path,
'driving_info': input_video_path,
'flag_relative': flag_relative_input,
'flag_do_crop': flag_do_crop_input,
'flag_pasteback': flag_remap_input
}
# update config from user input
self.args = update_args(self.args, args_user)
self.live_portrait_wrapper.update_config(self.args.__dict__)
self.cropper.update_config(self.args.__dict__)
# video driven animation
video_path, video_path_concat = self.execute(self.args)
gr.Info("Run successfully!", duration=2)
return video_path, video_path_concat
else:
raise gr.Error("The input reference portrait or driving video hasn't been prepared yet 💥!", duration=5)

def execute_image(self, input_eye_ratio: float, input_lip_ratio: float):
""" for single image retargeting
"""
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
num_kp = self.x_s_user.shape[1]
# default: use x_s
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, x′_d))
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
return out
if input_eye_ratio is None or input_lip_ratio is None:
raise gr.Error("Invalid ratio input 💥!", duration=5)
elif self.f_s_user is None:
if self.start_prepare:
raise gr.Error(
"The reference portrait is under processing 💥! Please wait for a second.",
duration=5
)
else:
raise gr.Error(
"The reference portrait hasn't been prepared yet 💥! Please scroll to the top of the page to upload.",
duration=5
)
else:
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user)
eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user)
lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor)
num_kp = self.x_s_user.shape[1]
# default: use x_s
x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
# D(W(f_s; x_s, x′_d))
out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new)
out = self.live_portrait_wrapper.parse_output(out['out'])[0]
out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori)
gr.Info("Run successfully!", duration=2)
return out, out_to_ori_blend


def prepare_retargeting(self, input_image_path, flag_do_crop=True):
""" for single image retargeting
"""
inference_cfg = self.live_portrait_wrapper.cfg
######## process reference portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################

# record global info for next time use
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
if input_image_path is not None:
gr.Info("Upload successfully!", duration=2)
self.start_prepare = True
inference_cfg = self.live_portrait_wrapper.cfg
######## process reference portrait ########
img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16)
log(f"Load source image from {input_image_path}.")
crop_info = self.cropper.crop_single_image(img_rgb)
if flag_do_crop:
I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256'])
else:
I_s = self.live_portrait_wrapper.prepare_source(img_rgb)
x_s_info = self.live_portrait_wrapper.get_kp_info(I_s)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
############################################

# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())

return eye_close_ratio, lip_close_ratio
# record global info for next time use
self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s)
self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info)
self.x_s_info_user = x_s_info
self.source_lmk_user = crop_info['lmk_crop']
self.img_rgb = img_rgb
self.crop_M_c2o = crop_info['M_c2o']
self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
# update slider
eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None])
eye_close_ratio = float(eye_close_ratio.squeeze(0).mean())
lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None])
lip_close_ratio = float(lip_close_ratio.squeeze(0).mean())
# for vis
self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0]
return eye_close_ratio, lip_close_ratio, self.I_s_vis
else:
# reached when the Clear button is pressed
return 0.8, 0.8, self.I_s_vis
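Editor's note on execute_image above: the retargeting networks return flat per-keypoint offsets, and the new driving keypoints are simply the source keypoints plus those offsets. A shape-level sketch with numpy stand-ins; the deltas below are random placeholders for the outputs of retarget_eye / retarget_lip:

```python
import numpy as np

x_s = np.zeros((1, 21, 3))  # source keypoints: (batch, num_kp, xyz); 21 is illustrative
num_kp = x_s.shape[1]

# Placeholders for the (batch, num_kp * 3) tensors from the retargeting nets.
eyes_delta = 0.01 * np.random.randn(1, num_kp * 3)
lip_delta = 0.01 * np.random.randn(1, num_kp * 3)

# Same composition as in GradioPipeline.execute_image:
x_d_new = x_s + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
assert x_d_new.shape == (1, num_kp, 3)  # then fed to warp_decode(f_s, x_s, x_d_new)
```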
15 changes: 5 additions & 10 deletions src/live_portrait_pipeline.py
@@ -20,10 +20,10 @@
from .utils.cropper import Cropper
from .utils.camera import get_rotation_matrix
from .utils.video import images2video, concat_frames
from .utils.crop import _transform_img
from .utils.crop import _transform_img, prepare_paste_back, paste_back
from .utils.retargeting_utils import calc_lip_close_ratio
from .utils.io import load_image_rgb, load_driving_info
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit
from .utils.io import load_image_rgb, load_driving_info, resize_to_limit
from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template
from .utils.rprint import rlog as log
from .live_portrait_wrapper import LivePortraitWrapper

@@ -90,10 +90,7 @@ def execute(self, args: ArgumentConfig):

######## prepare for pasteback ########
if inference_cfg.flag_pasteback:
if inference_cfg.mask_crop is None:
inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR)
mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
mask_ori = mask_ori.astype(np.float32) / 255.
mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_paste_lst = []
#########################################

@@ -172,9 +169,7 @@ def execute(self, args: ArgumentConfig):
I_p_lst.append(I_p_i)

if inference_cfg.flag_pasteback:
I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8)
out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend])
I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori)
I_p_paste_lst.append(I_p_i_to_ori_blend)

mkdir(args.output_dir)
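Editor's note: this refactor folds the inline mask-warp-and-blend code into the shared prepare_paste_back / paste_back helpers that the Gradio pipeline now also uses. A rough sketch of what such helpers amount to, assuming an OpenCV affine warp; the signatures mirror the call sites above, but the bodies are an approximation rather than the repo's exact code:

```python
import cv2
import numpy as np

def prepare_paste_back(mask_crop, M_c2o, dsize):
    """Warp the crop-space blend mask into original-image space, scaled to [0, 1]."""
    mask_ori = cv2.warpAffine(mask_crop, M_c2o[:2, :], dsize, flags=cv2.INTER_LINEAR)
    return mask_ori.astype(np.float32) / 255.0

def paste_back(img_crop, M_c2o, img_ori, mask_ori):
    """Warp the generated crop back via M_c2o and alpha-blend it over the original."""
    dsize = (img_ori.shape[1], img_ori.shape[0])
    img_back = cv2.warpAffine(img_crop, M_c2o[:2, :], dsize, flags=cv2.INTER_LINEAR)
    blended = mask_ori * img_back + (1.0 - mask_ori) * img_ori
    return np.clip(blended, 0, 255).astype(np.uint8)
```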
