diff --git a/README.md b/README.md
index 54b949e..b57a1b7 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,10 @@ docker compose -f docker/docker-compose.yaml up
 Update the [`docker-compose.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/docker/docker-compose.yaml) to match your environment if you're not using an Nvidia GPU.
-## ❀️ Citation and Thanks
+### 🌐 Translation
+Any PRs for language translation for [`translation.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/i18n/translation.yaml) would be greatly appreciated!
+
+## ❀️ Acknowledgement
 1. LivePortrait paper comes from
 ```bibtex
 @article{guo2024liveportrait,
@@ -65,8 +68,6 @@ Update the [`docker-compose.yaml`](https://github.com/jhj0517/AdvancedLivePortra
 2. The models are safetensors that have been converted by kijai. : https://github.com/kijai/ComfyUI-LivePortraitKJ
 3. [ultralytics](https://github.com/ultralytics/ultralytics) is used to detect the face.
 4. This WebUI is started from [ComfyUI-AdvancedLivePortrait](https://github.com/PowerHouseMan/ComfyUI-AdvancedLivePortrait), various facial expressions like AAA, EEE, Eyebrow, Wink are found by PowerHouseMan.
-
-### 🌐 Translation
-Any PRs for language translation for [`translation.yaml`](https://github.com/jhj0517/AdvancedLivePortrait-WebUI/blob/master/i18n/translation.yaml) would be greatly appreciated!
+5. [RealESRGAN](https://github.com/xinntao/Real-ESRGAN) is used for image restoration.
diff --git a/app.py b/app.py
index a7875d6..0f60d31 100644
--- a/app.py
+++ b/app.py
@@ -41,7 +41,9 @@ def create_expression_parameters():
             gr.Slider(label=_("Sample Ratio"), minimum=-0.2, maximum=1.2, step=0.01, value=1, visible=False),
             gr.Dropdown(label=_("Sample Parts"), visible=False, choices=[part.value for part in SamplePart], value=SamplePart.ALL.value),
-            gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2)
+            gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2),
+            gr.Checkbox(label=_("Enable Image Restoration"),
+                        info=_("This enables image restoration with RealESRGAN but slows down the speed"), value=False)
         ]
 
     @staticmethod
@@ -53,6 +55,8 @@ def create_video_parameters():
             gr.Slider(label=_("First frame eyes alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
             gr.Slider(label=_("First frame mouth alignment factor"), minimum=0, maximum=1, step=0.01, value=1),
             gr.Slider(label=_("Face Crop Factor"), minimum=1.5, maximum=2.5, step=0.1, value=2),
+            gr.Checkbox(label=_("Enable Image Restoration"),
+                        info=_("This enables image restoration with RealESRGAN but slows down the speed"), value=False)
         ]
 
     def launch(self):
diff --git a/i18n/translation.yaml b/i18n/translation.yaml
index 3681c03..90483ac 100644
--- a/i18n/translation.yaml
+++ b/i18n/translation.yaml
@@ -32,6 +32,8 @@ en: # English
   First frame mouth alignment factor: First frame mouth alignment factor
   First frame eyes alignment factor: First frame eyes alignment factor
   Face Crop Factor: Face Crop Factor
+  Enable Image Restoration: Enable Image Restoration
+  This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed
 ko: # Korean
   Language: μ–Έμ–΄
@@ -67,6 +69,8 @@ ko: # Korean
   First frame mouth alignment factor: 첫 ν”„λ ˆμž„ μž… 반영 λΉ„μœ¨
   First frame eyes alignment factor: 첫 ν”„λ ˆμž„ 눈 반영 λΉ„μœ¨
   Face Crop Factor: μ–Όκ΅΄ 크둭 λΉ„μœ¨
+  Enable Image Restoration: ν™”μ§ˆ ν–₯상
+  This enables image restoration with RealESRGAN but slows down the 
speed: RealESRGAN 으둜 ν™”μ§ˆμ„ ν–₯상 μ‹œν‚΅λ‹ˆλ‹€. μ†λ„λŠ” λŠλ €μ§‘λ‹ˆλ‹€. ja: # Japanese Language: 言θͺž @@ -102,6 +106,8 @@ ja: # Japanese First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed es: # Spanish Language: Idioma @@ -137,6 +143,8 @@ es: # Spanish First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed fr: # French Language: Langue @@ -172,6 +180,8 @@ fr: # French First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed de: # German Language: Sprache @@ -207,6 +217,8 @@ de: # German First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed zh: # Chinese Language: 语言 @@ -242,6 +254,8 @@ zh: # Chinese First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed uk: # Ukrainian Language: Мова @@ -277,6 +291,8 @@ uk: # Ukrainian First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed ru: # Russian Language: Π―Π·Ρ‹ΠΊ @@ -312,6 +328,8 @@ ru: # Russian First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with RealESRGAN but slows down the speed tr: # Turkish Language: Dil @@ -347,3 +365,5 @@ tr: # Turkish First frame mouth alignment factor: First frame mouth alignment factor First frame eyes alignment factor: First frame eyes alignment factor Face Crop Factor: Face Crop Factor + Enable Image Restoration: Enable Image Restoration + This enables image restoration with RealESRGAN but slows down the speed: This enables image restoration with 
RealESRGAN but slows down the speed diff --git a/modules/image_restoration/__init__.py b/modules/image_restoration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/image_restoration/real_esrgan/__init__.py b/modules/image_restoration/real_esrgan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/image_restoration/real_esrgan/model_downloader.py b/modules/image_restoration/real_esrgan/model_downloader.py new file mode 100644 index 0000000..076da25 --- /dev/null +++ b/modules/image_restoration/real_esrgan/model_downloader.py @@ -0,0 +1,15 @@ +from modules.live_portrait.model_downloader import download_model + +MODELS_REALESRGAN_URL = { + "realesr-general-x4v3": "https://huggingface.co/jhj0517/realesr-general-x4v3/resolve/main/realesr-general-x4v3.pth", + "RealESRGAN_x2": "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth", +} + +MODELS_REALESRGAN_SCALABILITY = { + "realesr-general-x4v3": [1, 2, 4], + "RealESRGAN_x2": [2] +} + + +def download_resrgan_model(file_path, url): + return download_model(file_path, url) diff --git a/modules/image_restoration/real_esrgan/real_esrgan_inferencer.py b/modules/image_restoration/real_esrgan/real_esrgan_inferencer.py new file mode 100644 index 0000000..e4562fc --- /dev/null +++ b/modules/image_restoration/real_esrgan/real_esrgan_inferencer.py @@ -0,0 +1,120 @@ +import os.path +import gradio as gr +import torch +import cv2 +from typing import Optional, Literal + +from modules.utils.paths import * +from modules.utils.image_helper import save_image +from .model_downloader import download_resrgan_model, MODELS_REALESRGAN_URL, MODELS_REALESRGAN_SCALABILITY +from .wrapper.rrdb_net import RRDBNet +from .wrapper.real_esrganer import RealESRGANer +from .wrapper.srvgg_net_compact import SRVGGNetCompact + + +class RealESRGANInferencer: + def __init__(self, + model_dir: str = MODELS_REAL_ESRGAN_DIR, + output_dir: str = OUTPUTS_DIR): + self.model_dir = model_dir + self.output_dir = output_dir + self.device = self.get_device() + self.arc = None + self.model = None + self.face_enhancer = None + + self.available_models = list(MODELS_REALESRGAN_URL.keys()) + self.default_model = self.available_models[0] + self.model_config = { + "model_name": self.default_model, + "scale": 1, + "half_precision": True + } + + def load_model(self, + model_name: Optional[str] = None, + scale: Literal[1, 2, 4] = 1, + half_precision: bool = True, + progress: gr.Progress = gr.Progress()): + model_config = { + "model_name": model_name, + "scale": scale, + "half_precision": half_precision + } + if model_config == self.model_config and self.model is not None: + return + else: + self.model_config = model_config + + if model_name is None: + model_name = self.default_model + + model_path = os.path.join(self.model_dir, model_name) + if not model_name.endswith(".pth"): + model_path += ".pth" + + if not os.path.exists(model_path): + progress(0, f"Downloading RealESRGAN model to : {model_path}") + download_resrgan_model(model_path, MODELS_REALESRGAN_URL[model_name]) + + name, ext = os.path.splitext(model_name) + assert scale in MODELS_REALESRGAN_SCALABILITY[name] + if name == 'RealESRGAN_x2': # x4 RRDBNet model + arc = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) + netscale = 4 + else: # x4 VGG-style model (S size) : "realesr-general-x4v3" + arc = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + netscale = 4 + + self.model = 
RealESRGANer(
+            scale=netscale,
+            model_path=model_path,
+            model=arc,
+            half=half_precision,
+            device=torch.device(self.get_device())
+        )
+
+    def restore_image(self,
+                      img_path: str,
+                      model_name: Optional[str] = None,
+                      scale: int = 1,
+                      half_precision: Optional[bool] = None,
+                      overwrite: bool = True):
+        # Resolve defaults before building the config so the cache comparison
+        # below matches what load_model() stores in self.model_config.
+        if model_name is None:
+            model_name = self.default_model
+        if half_precision is None:
+            half_precision = self.device == "cuda"
+
+        model_config = {
+            "model_name": model_name,
+            "scale": scale,
+            "half_precision": half_precision
+        }
+
+        if self.model is None or self.model_config != model_config:
+            self.load_model(
+                model_name=model_name,
+                scale=scale,
+                half_precision=half_precision
+            )
+
+        with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
+            output, img_mode = self.model.enhance(img_path, outscale=scale)
+            if img_mode == "RGB":
+                output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+
+        if overwrite:
+            output_path = img_path
+        else:
+            output_path = get_auto_incremental_file_path(self.output_dir, extension="png")
+
+        output_path = save_image(output, output_path=output_path)
+        return output_path
+
+    @staticmethod
+    def get_device():
+        if torch.cuda.is_available():
+            return "cuda"
+        elif torch.backends.mps.is_available():
+            return "mps"
+        else:
+            return "cpu"
diff --git a/modules/image_restoration/real_esrgan/wrapper/__init__.py b/modules/image_restoration/real_esrgan/wrapper/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/image_restoration/real_esrgan/wrapper/real_esrganer.py b/modules/image_restoration/real_esrgan/wrapper/real_esrganer.py
new file mode 100644
index 0000000..2fba524
--- /dev/null
+++ b/modules/image_restoration/real_esrgan/wrapper/real_esrganer.py
@@ -0,0 +1,303 @@
+import cv2
+import math
+import numpy as np
+import queue
+import threading
+import torch
+from torch.nn import functional as F
+
+
+class RealESRGANer():
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can also be a URL (the weights will be downloaded automatically).
+        model (nn.Module): The defined network. Default: None.
+        tile (int): Images that are too large can cause GPU out-of-memory errors, so this tile option first crops
+            the input image into tiles, processes each tile, and then merges the results back into one image.
+            0 means tiling is not used. Default: 0.
+        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+ """ + + def __init__(self, + scale, + model_path, + dni_weight=None, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'): + """Deep network interpolation. + + ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` + """ + net_a = torch.load(net_a, map_location=torch.device(loc)) + net_b = torch.load(net_b, map_location=torch.device(loc)) + for k, v_a in net_a[key].items(): + net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] + return net_a + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. 
+ + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + if isinstance(img, str): + img = cv2.imread(img) + + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image 
(without the alpha channel) ------------------- # + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img = self.post_process() + output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') diff --git a/modules/image_restoration/real_esrgan/wrapper/rrdb_net.py b/modules/image_restoration/real_esrgan/wrapper/rrdb_net.py new file mode 100644 index 0000000..d0be16a --- /dev/null +++ b/modules/image_restoration/real_esrgan/wrapper/rrdb_net.py @@ -0,0 +1,182 @@ +from torch import nn as nn +import torch +from torch.nn import init as init +from torch.nn import functional as F +from torch.nn.modules.batchnorm import _BatchNorm + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. 
+ """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. 
+ """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. 
+ """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) diff --git a/modules/image_restoration/real_esrgan/wrapper/srvgg_net_compact.py b/modules/image_restoration/real_esrgan/wrapper/srvgg_net_compact.py new file mode 100644 index 0000000..7882a1b --- /dev/null +++ b/modules/image_restoration/real_esrgan/wrapper/srvgg_net_compact.py @@ -0,0 +1,67 @@ +from torch import nn as nn +from torch.nn import functional as F + + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + + It is a compact network structure, which performs upsampling in the last layer and no convolution is + conducted on the HR feature space. + + Args: + num_in_ch (int): Channel number of inputs. Default: 3. + num_out_ch (int): Channel number of outputs. Default: 3. + num_feat (int): Channel number of intermediate features. Default: 64. + num_conv (int): Number of convolution layers in the body network. Default: 16. + upscale (int): Upsampling factor. Default: 4. + act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out diff --git a/modules/live_portrait/live_portrait_inferencer.py b/modules/live_portrait/live_portrait_inferencer.py index 007be35..d915710 100644 --- a/modules/live_portrait/live_portrait_inferencer.py +++ b/modules/live_portrait/live_portrait_inferencer.py @@ -1,17 +1,11 @@ import logging -import os -import cv2 import time import copy import dill -import torch from ultralytics import YOLO 
import safetensors.torch import gradio as gr -from gradio_i18n import Translate, gettext as _ from ultralytics.utils import LOGGER as ultralytics_logger -from enum import Enum -from typing import Union, List, Dict, Tuple from modules.utils.paths import * from modules.utils.image_helper import * @@ -27,6 +21,7 @@ from modules.live_portrait.motion_extractor import MotionExtractor from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork +from modules.image_restoration.real_esrgan.real_esrgan_inferencer import RealESRGANInferencer class LivePortraitInferencer: @@ -69,6 +64,11 @@ def __init__(self, self.psi_list = None self.d_info = None + self.resrgan_inferencer = RealESRGANInferencer( + model_dir=os.path.join(self.model_dir, "RealESRGAN"), + output_dir=self.output_dir + ) + def load_models(self, model_type: str = ModelType.HUMAN.value, progress=gr.Progress()): @@ -161,6 +161,7 @@ def edit_expression(self, sample_ratio: float = 1, sample_parts: str = SamplePart.ALL.value, crop_factor: float = 2.3, + enable_image_restoration: bool = False, src_image: Optional[str] = None, sample_image: Optional[str] = None,) -> None: if isinstance(model_type, ModelType): @@ -232,8 +233,11 @@ def edit_expression(self, out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8) temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png") - save_image(numpy_array=crop_out, output_path=temp_out_img_path) - save_image(numpy_array=out, output_path=out_img_path) + cropped_out_img_path = save_image(numpy_array=crop_out, output_path=temp_out_img_path) + out_img_path = save_image(numpy_array=out, output_path=out_img_path) + + if enable_image_restoration: + out = self.resrgan_inferencer.restore_image(out_img_path) return out except Exception as e: @@ -244,6 +248,7 @@ def create_video(self, retargeting_eyes: float = 1, retargeting_mouth: float = 1, crop_factor: float = 2.3, + enable_image_restoration: bool = False, src_image: Optional[str] = None, driving_vid_path: Optional[str] = None, progress: gr.Progress = gr.Progress() @@ -317,11 +322,18 @@ def create_video(self, np.uint8) out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png") - save_image(out, out_frame_path) + out_frame_path = save_image(out, out_frame_path) + + if enable_image_restoration: + out_frame_path = self.resrgan_inferencer.restore_image(out_frame_path) progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..") - video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR, frame_rate=vid_info.frame_rate, output_dir=os.path.join(self.output_dir, "videos")) + video_path = create_video_from_frames( + TEMP_VIDEO_OUT_FRAMES_DIR, + frame_rate=vid_info.frame_rate, + output_dir=os.path.join(self.output_dir, "videos") + ) return video_path except Exception as e: diff --git a/modules/utils/paths.py b/modules/utils/paths.py index 2f9e694..8d2fa53 100644 --- a/modules/utils/paths.py +++ b/modules/utils/paths.py @@ -2,9 +2,10 @@ import os -PROJECT_ROOT_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "..") +PROJECT_ROOT_DIR = os.path.normpath(os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "..")) MODELS_DIR = os.path.join(PROJECT_ROOT_DIR, "models") MODELS_ANIMAL_DIR = os.path.join(MODELS_DIR, "animal") 
+MODELS_REAL_ESRGAN_DIR = os.path.join(MODELS_DIR, "RealESRGAN") OUTPUTS_DIR = os.path.join(PROJECT_ROOT_DIR, "outputs") OUTPUTS_VIDEOS_DIR = os.path.join(OUTPUTS_DIR, "videos") TEMP_DIR = os.path.join(OUTPUTS_DIR, "temp") @@ -29,6 +30,9 @@ # Just animal detection model not the face, needs better model "yolo_v5s_animal_det": os.path.join(MODELS_ANIMAL_DIR, "yolo_v5s_animal_det.n2x") } +MODEL_REAL_ESRGAN_PATH = { + "realesr-general-x4v3": os.path.join(MODELS_REAL_ESRGAN_DIR, "realesr-general-x4v3.pth") +} MASK_TEMPLATES = os.path.join(PROJECT_ROOT_DIR, "modules", "utils", "resources", "mask_template.png") I18N_YAML_PATH = os.path.join(PROJECT_ROOT_DIR, "i18n", "translation.yaml") @@ -52,6 +56,7 @@ def init_dirs(): for dir_path in [ MODELS_DIR, MODELS_ANIMAL_DIR, + MODELS_REAL_ESRGAN_DIR, OUTPUTS_DIR, EXP_OUTPUT_DIR, TEMP_DIR, diff --git a/requirements.txt b/requirements.txt index 3e17359..30f32ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# AdvancedLivePortrait --extra-index-url https://download.pytorch.org/whl/cu124 torch torchvision @@ -15,7 +16,6 @@ dill gradio gradio-i18n - # Tests # pytest # scikit-image diff --git a/tests/test_image_restoration.py b/tests/test_image_restoration.py new file mode 100644 index 0000000..5625aa6 --- /dev/null +++ b/tests/test_image_restoration.py @@ -0,0 +1,31 @@ +import os +import pytest + +from test_config import * +from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer + + +@pytest.mark.parametrize( + "input_image", + [ + TEST_IMAGE_PATH + ] +) +def test_image_restoration( + input_image: str, +): + if not os.path.exists(TEST_IMAGE_PATH): + download_image( + TEST_IMAGE_URL, + TEST_IMAGE_PATH + ) + + inferencer = LivePortraitInferencer() + + restored_output = inferencer.resrgan_inferencer.restore_image( + input_image, + overwrite=False + ) + + assert os.path.exists(restored_output) + assert are_images_different(input_image, restored_output)
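
For anyone wanting to try the new restoration module outside the WebUI, here is a minimal usage sketch, not part of the patch. It assumes the repository root is on `PYTHONPATH`, that the default `realesr-general-x4v3` weights can be downloaded on the first call, and uses a hypothetical input image path:

```python
# Minimal sketch based on the RealESRGANInferencer API added in this patch.
from modules.image_restoration.real_esrgan.real_esrgan_inferencer import RealESRGANInferencer

inferencer = RealESRGANInferencer()  # defaults: models/RealESRGAN for weights, outputs/ for results

restored_path = inferencer.restore_image(
    "portrait.png",   # hypothetical input image path
    scale=2,          # must be supported by the model, see MODELS_REALESRGAN_SCALABILITY
    overwrite=False,  # False writes to an auto-incremented file under the output dir
)
print(restored_path)
```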