diff --git a/app.py b/app.py index 1998082c..44b0740c 100644 --- a/app.py +++ b/app.py @@ -4,10 +4,9 @@ The entrance of the gradio """ -import os -import os.path as osp -import gradio as gr import tyro +import gradio as gr +import os.path as osp from src.utils.helper import load_description from src.gradio_pipeline import GradioPipeline from src.config.crop_config import CropConfig @@ -43,18 +42,24 @@ def partial_fields(target_class, kwargs): [osp.join(example_portrait_dir, "s7.jpg"), osp.join(example_video_dir, "d7.mp4"), True, True, True], ] #################### interface logic #################### + # Define components first eye_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target eye-close ratio") lip_retargeting_slider = gr.Slider(minimum=0, maximum=0.8, step=0.01, label="target lip-close ratio") -output_image = gr.Image(label="The animated image with the given eye-close and lip-close ratio.", type="numpy") +retargeting_input_image = gr.Image(type="numpy") +output_image = gr.Image( type="numpy") +output_image_paste_back = gr.Image(type="numpy") +output_video = gr.Video() +output_video_concat = gr.Video() + with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.HTML(load_description(title_md)) gr.Markdown(load_description("assets/gradio_description_upload.md")) with gr.Row(): with gr.Accordion(open=True, label="Reference Portrait"): - image_input = gr.Image(label="Please upload the reference portrait here.", type="filepath") + image_input = gr.Image(type="filepath") with gr.Accordion(open=True, label="Driving Video"): - video_input = gr.Video(label="Please upload the driving video here.") + video_input = gr.Video() gr.Markdown(load_description("assets/gradio_description_animation.md")) with gr.Row(): with gr.Accordion(open=True, label="Animation Options"): @@ -63,16 +68,17 @@ def partial_fields(target_class, kwargs): flag_remap_input = gr.Checkbox(value=True, label="paste-back") flag_do_crop_input = gr.Checkbox(value=True, label="do crop") with gr.Row(): - process_button_animation = gr.Button("๐Ÿš€ Animate", variant="primary") + with gr.Column(): + process_button_animation = gr.Button("๐Ÿš€ Animate", variant="primary") + with gr.Column(): + process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="๐Ÿงน Clear") with gr.Row(): with gr.Column(): with gr.Accordion(open=True, label="The animated video in the original image space"): - output_video = gr.Video(label="The animated video after pasted back.") + output_video.render() with gr.Column(): with gr.Accordion(open=True, label="The animated video"): - output_video_concat = gr.Video(label="The animated video and driving video.") - with gr.Row(): - process_button_reset = gr.ClearButton([image_input, video_input, output_video, output_video_concat], value="๐Ÿงน Clear") + output_video_concat.render() with gr.Row(): # Examples gr.Markdown("## You could choose the examples below โฌ‡๏ธ") @@ -89,28 +95,36 @@ def partial_fields(target_class, kwargs): examples_per_page=5 ) gr.Markdown(load_description("assets/gradio_description_retargeting.md")) + with gr.Row(): + eye_retargeting_slider.render() + lip_retargeting_slider.render() + with gr.Row(): + process_button_retargeting = gr.Button("๐Ÿš— Retargeting", variant="primary") + process_button_reset_retargeting = gr.ClearButton( + [ + eye_retargeting_slider, + lip_retargeting_slider, + retargeting_input_image, + output_image, + output_image_paste_back + ], + value="๐Ÿงน Clear" + ) with gr.Row(): with gr.Column(): - process_button_close_ratio = gr.Button("๐Ÿค– Calculate the eye-close and lip-close ratio") - process_button_retargeting = gr.Button("๐Ÿš— Retargeting", variant="primary") - process_button_reset_retargeting = gr.ClearButton([output_image, eye_retargeting_slider, lip_retargeting_slider], value="๐Ÿงน Clear") - # with gr.Column(): - eye_retargeting_slider.render() - lip_retargeting_slider.render() + with gr.Accordion(open=True, label="Retargeting Input"): + retargeting_input_image.render() with gr.Column(): - with gr.Accordion(open=True, label="Eye and lip Retargeting Result"): + with gr.Accordion(open=True, label="Retargeting Result"): output_image.render() + with gr.Column(): + with gr.Accordion(open=True, label="Paste-back Result"): + output_image_paste_back.render() # binding functions for buttons - process_button_close_ratio.click( - fn=gradio_pipeline.prepare_retargeting, - inputs=image_input, - outputs=[eye_retargeting_slider, lip_retargeting_slider], - show_progress=True - ) process_button_retargeting.click( fn=gradio_pipeline.execute_image, inputs=[eye_retargeting_slider, lip_retargeting_slider], - outputs=output_image, + outputs=[output_image, output_image_paste_back], show_progress=True ) process_button_animation.click( @@ -125,8 +139,12 @@ def partial_fields(target_class, kwargs): outputs=[output_video, output_video_concat], show_progress=True ) - process_button_reset.click() - process_button_reset_retargeting + image_input.change( + fn=gradio_pipeline.prepare_retargeting, + inputs=image_input, + outputs=[eye_retargeting_slider, lip_retargeting_slider, retargeting_input_image] + ) + ########################################################## demo.launch( diff --git a/assets/gradio_description_retargeting.md b/assets/gradio_description_retargeting.md index 0a5dcbaf..5fe6ebf8 100644 --- a/assets/gradio_description_retargeting.md +++ b/assets/gradio_description_retargeting.md @@ -1,7 +1 @@ -๐Ÿ”ฅ To change the target eye-close and lip-close ratio of the reference portrait, please: -
- 1. Please first press the ๐Ÿค– Calculate the eye-close and lip-close ratio button, and wait for the result shown in the sliders. -
-
- 2. Please drag the sliders and then click the ๐Ÿš— Retargeting button. Then the result would be shown in the middle block. You can try running it multiple times! -
+๐Ÿ”ฅ To change the target eye-close and lip-close ratio of the reference portrait, please drag the sliders and then click the ๐Ÿš— Retargeting button. The result would be shown in the middle block. You can try running it multiple times. ๐Ÿ˜Š Set both ratios to 0.8 to see what's going on! diff --git a/assets/gradio_title.md b/assets/gradio_title.md index bf4bf2bb..e2b765e1 100644 --- a/assets/gradio_title.md +++ b/assets/gradio_title.md @@ -2,7 +2,7 @@

LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control

+ Project Page
diff --git a/src/gradio_pipeline.py b/src/gradio_pipeline.py index 45d661bf..d176cfc5 100644 --- a/src/gradio_pipeline.py +++ b/src/gradio_pipeline.py @@ -3,13 +3,14 @@ """ Pipeline for gradio """ - +import gradio as gr from .config.argument_config import ArgumentConfig from .live_portrait_pipeline import LivePortraitPipeline from .utils.io import load_img_online +from .utils.rprint import rlog as log +from .utils.crop import prepare_paste_back, paste_back from .utils.camera import get_rotation_matrix from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio -from .utils.rprint import rlog as log def update_args(args, user_args): """update the args according to user inputs @@ -26,10 +27,15 @@ def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig): # self.live_portrait_wrapper = self.live_portrait_wrapper self.args = args # for single image retargeting + self.start_prepare = False self.f_s_user = None self.x_c_s_info_user = None self.x_s_user = None self.source_lmk_user = None + self.mask_ori = None + self.img_rgb = None + self.crop_M_c2o = None + def execute_video( self, @@ -41,64 +47,94 @@ def execute_video( ): """ for video driven potrait animation """ - args_user = { - 'source_image': input_image_path, - 'driving_info': input_video_path, - 'flag_relative': flag_relative_input, - 'flag_do_crop': flag_do_crop_input, - 'flag_pasteback': flag_remap_input - } - # update config from user input - self.args = update_args(self.args, args_user) - self.live_portrait_wrapper.update_config(self.args.__dict__) - self.cropper.update_config(self.args.__dict__) - # video driven animation - video_path, video_path_concat = self.execute(self.args) - return video_path, video_path_concat + if input_image_path is not None and input_video_path is not None: + args_user = { + 'source_image': input_image_path, + 'driving_info': input_video_path, + 'flag_relative': flag_relative_input, + 'flag_do_crop': flag_do_crop_input, + 'flag_pasteback': flag_remap_input + } + # update config from user input + self.args = update_args(self.args, args_user) + self.live_portrait_wrapper.update_config(self.args.__dict__) + self.cropper.update_config(self.args.__dict__) + # video driven animation + video_path, video_path_concat = self.execute(self.args) + gr.Info("Run successfully!", duration=2) + return video_path, video_path_concat, + else: + raise gr.Error("The input reference portrait or driving video hasn't been prepared yet ๐Ÿ’ฅ!", duration=5) def execute_image(self, input_eye_ratio: float, input_lip_ratio: float): """ for single image retargeting """ - # โˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) - combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user) - eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor) - # โˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) - combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user) - lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor) - num_kp = self.x_s_user.shape[1] - # default: use x_s - x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) - # D(W(f_s; x_s, xโ€ฒ_d)) - out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new) - out = self.live_portrait_wrapper.parse_output(out['out'])[0] - return out + if input_eye_ratio is None or input_eye_ratio is None: + raise gr.Error("Invalid ratio input ๐Ÿ’ฅ!", duration=5) + elif self.f_s_user is None: + if self.start_prepare: + raise gr.Error( + "The reference portrait is under processing ๐Ÿ’ฅ! Please wait for a second.", + duration=5 + ) + else: + raise gr.Error( + "The reference portrait hasn't been prepared yet ๐Ÿ’ฅ! Please scroll to the top of the page to upload.", + duration=5 + ) + else: + # โˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) + combined_eye_ratio_tensor = self.live_portrait_wrapper.calc_combined_eye_ratio([[input_eye_ratio]], self.source_lmk_user) + eyes_delta = self.live_portrait_wrapper.retarget_eye(self.x_s_user, combined_eye_ratio_tensor) + # โˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) + combined_lip_ratio_tensor = self.live_portrait_wrapper.calc_combined_lip_ratio([[input_lip_ratio]], self.source_lmk_user) + lip_delta = self.live_portrait_wrapper.retarget_lip(self.x_s_user, combined_lip_ratio_tensor) + num_kp = self.x_s_user.shape[1] + # default: use x_s + x_d_new = self.x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3) + # D(W(f_s; x_s, xโ€ฒ_d)) + out = self.live_portrait_wrapper.warp_decode(self.f_s_user, self.x_s_user, x_d_new) + out = self.live_portrait_wrapper.parse_output(out['out'])[0] + out_to_ori_blend = paste_back(out, self.crop_M_c2o, self.img_rgb, self.mask_ori) + gr.Info("Run successfully!", duration=2) + return out, out_to_ori_blend + def prepare_retargeting(self, input_image_path, flag_do_crop = True): """ for single image retargeting """ - inference_cfg = self.live_portrait_wrapper.cfg - ######## process reference portrait ######## - img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16) - log(f"Load source image from {input_image_path}.") - crop_info = self.cropper.crop_single_image(img_rgb) - if flag_do_crop: - I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) - else: - I_s = self.live_portrait_wrapper.prepare_source(img_rgb) - x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) - R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) - ############################################ - - # record global info for next time use - self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) - self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) - self.x_s_info_user = x_s_info - self.source_lmk_user = crop_info['lmk_crop'] + if input_image_path is not None: + gr.Info("Upload successfully!", duration=2) + self.start_prepare = True + inference_cfg = self.live_portrait_wrapper.cfg + ######## process reference portrait ######## + img_rgb = load_img_online(input_image_path, mode='rgb', max_dim=1280, n=16) + log(f"Load source image from {input_image_path}.") + crop_info = self.cropper.crop_single_image(img_rgb) + if flag_do_crop: + I_s = self.live_portrait_wrapper.prepare_source(crop_info['img_crop_256x256']) + else: + I_s = self.live_portrait_wrapper.prepare_source(img_rgb) + x_s_info = self.live_portrait_wrapper.get_kp_info(I_s) + R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) + ############################################ - # update slider - eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None]) - eye_close_ratio = float(eye_close_ratio.squeeze(0).mean()) - lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None]) - lip_close_ratio = float(lip_close_ratio.squeeze(0).mean()) - - return eye_close_ratio, lip_close_ratio + # record global info for next time use + self.f_s_user = self.live_portrait_wrapper.extract_feature_3d(I_s) + self.x_s_user = self.live_portrait_wrapper.transform_keypoint(x_s_info) + self.x_s_info_user = x_s_info + self.source_lmk_user = crop_info['lmk_crop'] + self.img_rgb = img_rgb + self.crop_M_c2o = crop_info['M_c2o'] + self.mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) + # update slider + eye_close_ratio = calc_eye_close_ratio(self.source_lmk_user[None]) + eye_close_ratio = float(eye_close_ratio.squeeze(0).mean()) + lip_close_ratio = calc_lip_close_ratio(self.source_lmk_user[None]) + lip_close_ratio = float(lip_close_ratio.squeeze(0).mean()) + # for vis + self.I_s_vis = self.live_portrait_wrapper.parse_output(I_s)[0] + return eye_close_ratio, lip_close_ratio, self.I_s_vis + else: + # when press the clear button, go here + return 0.8, 0.8, self.I_s_vis diff --git a/src/live_portrait_pipeline.py b/src/live_portrait_pipeline.py index 668e3a38..933c9118 100644 --- a/src/live_portrait_pipeline.py +++ b/src/live_portrait_pipeline.py @@ -20,10 +20,10 @@ from .utils.cropper import Cropper from .utils.camera import get_rotation_matrix from .utils.video import images2video, concat_frames -from .utils.crop import _transform_img +from .utils.crop import _transform_img, prepare_paste_back, paste_back from .utils.retargeting_utils import calc_lip_close_ratio -from .utils.io import load_image_rgb, load_driving_info -from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template, resize_to_limit +from .utils.io import load_image_rgb, load_driving_info, resize_to_limit +from .utils.helper import mkdir, basename, dct2cuda, is_video, is_template from .utils.rprint import rlog as log from .live_portrait_wrapper import LivePortraitWrapper @@ -90,10 +90,7 @@ def execute(self, args: ArgumentConfig): ######## prepare for pasteback ######## if inference_cfg.flag_pasteback: - if inference_cfg.mask_crop is None: - inference_cfg.mask_crop = cv2.imread(make_abs_path('./utils/resources/mask_template.png'), cv2.IMREAD_COLOR) - mask_ori = _transform_img(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) - mask_ori = mask_ori.astype(np.float32) / 255. + mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) I_p_paste_lst = [] ######################################### @@ -172,9 +169,7 @@ def execute(self, args: ArgumentConfig): I_p_lst.append(I_p_i) if inference_cfg.flag_pasteback: - I_p_i_to_ori = _transform_img(I_p_i, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0])) - I_p_i_to_ori_blend = np.clip(mask_ori * I_p_i_to_ori + (1 - mask_ori) * img_rgb, 0, 255).astype(np.uint8) - out = np.hstack([I_p_i_to_ori, I_p_i_to_ori_blend]) + I_p_i_to_ori_blend = paste_back(I_p_i, crop_info['M_c2o'], img_rgb, mask_ori) I_p_paste_lst.append(I_p_i_to_ori_blend) mkdir(args.output_dir) diff --git a/src/live_portrait_wrapper.py b/src/live_portrait_wrapper.py index 2cb2eab6..ac3c63a1 100644 --- a/src/live_portrait_wrapper.py +++ b/src/live_portrait_wrapper.py @@ -12,7 +12,6 @@ from src.utils.timer import Timer from src.utils.helper import load_model, concat_feat -from src.utils.retargeting_utils import compute_eye_delta, compute_lip_delta from src.utils.camera import headpose_pred_to_degree, get_rotation_matrix from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio from src.config.inference_config import InferenceConfig @@ -211,33 +210,6 @@ def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) - return delta - def retarget_keypoints(self, frame_idx, num_keypoints, input_eye_ratios, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source, driving_transformed_kp): - # TODO: GPT style, refactor it... - if self.cfg.flag_eye_retargeting: - # โˆ†_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) - eye_delta = compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source) - else: - # ฮฑ_eyes = 0 - eye_delta = None - - if self.cfg.flag_lip_retargeting: - # โˆ†_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) - lip_delta = compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source) - else: - # ฮฑ_lip = 0 - lip_delta = None - - if self.cfg.flag_relative: # use x_s - new_driving_kp = kp_source + \ - (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \ - (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0) - else: # use x_d,i - new_driving_kp = driving_transformed_kp + \ - (eye_delta.reshape(-1, num_keypoints, 3) if eye_delta is not None else 0) + \ - (lip_delta.reshape(-1, num_keypoints, 3) if lip_delta is not None else 0) - - return new_driving_kp - def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor: """ kp_source: BxNx3 diff --git a/src/utils/crop.py b/src/utils/crop.py index c061ef46..8f233639 100644 --- a/src/utils/crop.py +++ b/src/utils/crop.py @@ -4,14 +4,17 @@ cropping function and the related preprocess functions for cropping """ -import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread import numpy as np -from .rprint import rprint as print +import os.path as osp from math import sin, cos, acos, degrees +import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread +from .rprint import rprint as print DTYPE = np.float32 CV2_INTERP = cv2.INTER_LINEAR +def make_abs_path(fn): + return osp.join(osp.dirname(osp.realpath(__file__)), fn) def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None): """ conduct similarity or affine transformation to the image, do not do border operation! @@ -391,3 +394,19 @@ def average_bbox_lst(bbox_lst): bbox_arr = np.array(bbox_lst) return np.mean(bbox_arr, axis=0).tolist() +def prepare_paste_back(mask_crop, crop_M_c2o, dsize): + """prepare mask for later image paste back + """ + if mask_crop is None: + mask_crop = cv2.imread(make_abs_path('./resources/mask_template.png'), cv2.IMREAD_COLOR) + mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize) + mask_ori = mask_ori.astype(np.float32) / 255. + return mask_ori + +def paste_back(image_to_processed, crop_M_c2o, rgb_ori, mask_ori): + """paste back the image + """ + dsize = (rgb_ori.shape[1], rgb_ori.shape[0]) + result = _transform_img(image_to_processed, crop_M_c2o, dsize=dsize) + result = np.clip(mask_ori * result + (1 - mask_ori) * rgb_ori, 0, 255).astype(np.uint8) + return result \ No newline at end of file diff --git a/src/utils/cropper.py b/src/utils/cropper.py index e8ee1943..d5d511c9 100644 --- a/src/utils/cropper.py +++ b/src/utils/cropper.py @@ -1,5 +1,6 @@ # coding: utf-8 +import gradio as gr import numpy as np import os.path as osp from typing import List, Union, Tuple @@ -72,6 +73,7 @@ def crop_single_image(self, obj, **kwargs): if len(src_face) == 0: log('No face detected in the source image.') + raise gr.Error("No face detected in the source image ๐Ÿ’ฅ!", duration=5) raise Exception("No face detected in the source image!") elif len(src_face) > 1: log(f'More than one face detected in the image, only pick one face by rule {direction}.') diff --git a/src/utils/helper.py b/src/utils/helper.py index 267f97f3..05c991ec 100644 --- a/src/utils/helper.py +++ b/src/utils/helper.py @@ -154,22 +154,3 @@ def load_description(fp): content = f.read() return content - -def resize_to_limit(img, max_dim=1280, n=2): - h, w = img.shape[:2] - if max_dim > 0 and max(h, w) > max_dim: - if h > w: - new_h = max_dim - new_w = int(w * (max_dim / h)) - else: - new_w = max_dim - new_h = int(h * (max_dim / w)) - img = cv2.resize(img, (new_w, new_h)) - n = max(n, 1) - new_h = img.shape[0] - (img.shape[0] % n) - new_w = img.shape[1] - (img.shape[1] % n) - if new_h == 0 or new_w == 0: - return img - if new_h != img.shape[0] or new_w != img.shape[1]: - img = img[:new_h, :new_w] - return img diff --git a/src/utils/io.py b/src/utils/io.py index f930c480..29a7e008 100644 --- a/src/utils/io.py +++ b/src/utils/io.py @@ -40,7 +40,7 @@ def contiguous(obj): return obj -def _resize_to_limit(img: np.ndarray, max_dim=1920, n=2): +def resize_to_limit(img: np.ndarray, max_dim=1920, n=2): """ ajust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n. :param img: the image to be processed. @@ -87,7 +87,7 @@ def load_img_online(obj, mode="bgr", **kwargs): img = obj # Resize image to satisfy constraints - img = _resize_to_limit(img, max_dim=max_dim, n=n) + img = resize_to_limit(img, max_dim=max_dim, n=n) if mode.lower() == "bgr": return contiguous(img) diff --git a/src/utils/retargeting_utils.py b/src/utils/retargeting_utils.py index 2028590c..20a1bdd3 100644 --- a/src/utils/retargeting_utils.py +++ b/src/utils/retargeting_utils.py @@ -4,7 +4,6 @@ """ import numpy as np -import torch def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray: @@ -53,24 +52,3 @@ def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray: np.ndarray: Calculated lip-close ratio. """ return calculate_distance_ratio(lmk, 90, 102, 48, 66) - - -def compute_eye_delta(frame_idx, input_eye_ratios, source_landmarks, portrait_wrapper, kp_source): - input_eye_ratio = input_eye_ratios[frame_idx][0][0] - eye_close_ratio = calc_eye_close_ratio(source_landmarks[None]) - eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().cuda(portrait_wrapper.device_id) - input_eye_ratio_tensor = torch.Tensor([input_eye_ratio]).reshape(1, 1).cuda(portrait_wrapper.device_id) - combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1) - # print(combined_eye_ratio_tensor.mean()) - eye_delta = portrait_wrapper.retarget_eye(kp_source, combined_eye_ratio_tensor) - return eye_delta - - -def compute_lip_delta(frame_idx, input_lip_ratios, source_landmarks, portrait_wrapper, kp_source): - input_lip_ratio = input_lip_ratios[frame_idx][0] - lip_close_ratio = calc_lip_close_ratio(source_landmarks[None]) - lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().cuda(portrait_wrapper.device_id) - input_lip_ratio_tensor = torch.Tensor([input_lip_ratio]).cuda(portrait_wrapper.device_id) - combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1) - lip_delta = portrait_wrapper.retarget_lip(kp_source, combined_lip_ratio_tensor) - return lip_delta