diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..dfe0770
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a81c8ee
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..de92ae9
--- /dev/null
+++ b/README.md
@@ -0,0 +1,29 @@
+# BasicVSR++
+BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment
+
+Ported from https://github.com/open-mmlab/mmediting
+
+
+## Dependencies
+- [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive), required by `mmcv-full` to compile CUDA ops. Install the same CUDA version that your `PyTorch` build was compiled against.
+- [mmcv-full](https://github.com/open-mmlab/mmcv#installation)
+- [NumPy](https://numpy.org/install)
+- [PyTorch](https://pytorch.org/get-started), preferably with CUDA. Note that `torchaudio` is not required and hence can be omitted from the command.
+- [VapourSynth](http://www.vapoursynth.com/)
+
+
+## Installation
+```
+pip install --upgrade vsbasicvsrpp
+python -m vsbasicvsrpp
+```
+
+
+## Usage
+```python
+from vsbasicvsrpp import BasicVSRPP
+
+ret = BasicVSRPP(clip)
+```
+
+See `__init__.py` for the description of the parameters.
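+
+A fuller script sketch (hedged: the source filter and color matrix below are assumptions — substitute whatever your pipeline uses):
+```python
+import vapoursynth as vs
+from vsbasicvsrpp import BasicVSRPP
+
+core = vs.core
+
+# any source filter works; LWLibavSource and the BT.709 matrix are just examples
+clip = core.lsmas.LWLibavSource('input.mkv')
+
+# BasicVSRPP expects RGBS (planar 32-bit float RGB)
+clip = core.resize.Bicubic(clip, format=vs.RGBS, matrix_in_s='709')
+
+# enable tiling if you run out of VRAM (tile_x/tile_y of 0 disables it)
+ret = BasicVSRPP(clip, model=1, tile_x=384, tile_y=384, tile_pad=16)
+ret.set_output()
+```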
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..9787c3b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..7510471
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,26 @@
+[metadata]
+name = vsbasicvsrpp
+version = 1.0.0
+author = HolyWu
+description = BasicVSR++ function for VapourSynth
+long_description = file: README.md
+long_description_content_type = text/markdown
+url = https://github.com/HolyWu/vs-basicvsrpp
+classifiers =
+    License :: OSI Approved :: Apache Software License
+    Operating System :: OS Independent
+    Programming Language :: Python :: 3
+    Programming Language :: Python :: 3 :: Only
+    Topic :: Multimedia :: Video
+
+[options]
+zip_safe = False
+packages = find:
+python_requires = >=3.7
+install_requires =
+    mmcv-full
+    numpy
+    torch
+
+[options.package_data]
+* = *.pth
diff --git a/vsbasicvsrpp/__init__.py b/vsbasicvsrpp/__init__.py
new file mode 100644
index 0000000..bf50490
--- /dev/null
+++ b/vsbasicvsrpp/__init__.py
@@ -0,0 +1,207 @@
+import math
+import mmcv
+import numpy as np
+import os
+import torch
+import vapoursynth as vs
+from .basicvsr import BasicVSR
+from .basicvsr_pp import BasicVSRPlusPlus
+from .builder import build_model
+
+
+def BasicVSRPP(clip: vs.VideoNode, model: int = 1, interval: int = 30, tile_x: int = 0, tile_y: int = 0, tile_pad: int = 16,
+               device_type: str = 'cuda', device_index: int = 0, fp16: bool = False) -> vs.VideoNode:
+    '''
+    BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment
+
+    Supports either x4 upsampling (models 0-2) or same-size output (models 3-5).
+    For models 0-2, the input resolution must be at least 64.
+    For models 3-5, the input resolution must be at least 256 and mod-4.
+
+    Parameters:
+        clip: Clip to process. Only planar format with float sample type of 32 bit depth is supported.
+
+        model: Model to use.
+            0 = REDS
+            1 = Vimeo-90K (BI)
+            2 = Vimeo-90K (BD)
+            3 = NTIRE 2021 Quality enhancement of heavily compressed videos Challenge - Track 1
+            4 = NTIRE 2021 Quality enhancement of heavily compressed videos Challenge - Track 2
+            5 = NTIRE 2021 Quality enhancement of heavily compressed videos Challenge - Track 3
+
+        interval: Interval size.
+
+        tile_x, tile_y: Tile width and height respectively, 0 for no tiling.
+            It's recommended that the input's width and height are divisible by the tile's width and height respectively.
+            Set it to the maximum value that your GPU supports to reduce its impact on the output.
+
+        tile_pad: Tile padding.
+
+        device_type: Device type on which the tensor is allocated. Must be 'cuda' or 'cpu'.
+
+        device_index: Device ordinal for the device type.
+
+        fp16: fp16 mode for faster and more lightweight inference on cards with Tensor Cores.
+    '''
+    if not isinstance(clip, vs.VideoNode):
+        raise vs.Error('BasicVSR++: this is not a clip')
+
+    if clip.format.id != vs.RGBS:
+        raise vs.Error('BasicVSR++: only RGBS format is supported')
+
+    if model not in [0, 1, 2, 3, 4, 5]:
+        raise vs.Error('BasicVSR++: model must be 0, 1, 2, 3, 4, or 5')
+
+    if interval < 1:
+        raise vs.Error('BasicVSR++: interval must be at least 1')
+
+    device_type = device_type.lower()
+
+    if device_type not in ['cuda', 'cpu']:
+        raise vs.Error("BasicVSR++: device_type must be 'cuda' or 'cpu'")
+
+    if device_type == 'cuda' and not torch.cuda.is_available():
+        raise vs.Error('BasicVSR++: CUDA is not available')
+
+    device = torch.device(device_type, device_index)
+    if device_type == 'cuda':
+        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.benchmark = True
+
+    if model == 0:
+        model_name = 'basicvsr_plusplus_reds4.pth'
+    elif model == 1:
+        model_name = 'basicvsr_plusplus_vimeo90k_bi.pth'
+    elif model == 2:
+        model_name = 'basicvsr_plusplus_vimeo90k_bd.pth'
+    elif model == 3:
+        model_name = 'basicvsr_plusplus_ntire_decompress_track1.pth'
+    elif model == 4:
+        model_name = 'basicvsr_plusplus_ntire_decompress_track2.pth'
+    else:
+        model_name = 'basicvsr_plusplus_ntire_decompress_track3.pth'
+    model_path = os.path.join(os.path.dirname(__file__), model_name)
+
+    if model < 3:
+        config_name = 'config012.py'
+        scale = 4
+    else:
+        config_name = 'config345.py'
+        scale = 1
+
+    config = mmcv.Config.fromfile(os.path.join(os.path.dirname(__file__), config_name))
+
+    model = build_model(config.model)
+    mmcv.runner.load_checkpoint(model, model_path, strict=True)
+    model.to(device)
+    model.eval()
+    if fp16:
+        model.half()
+
+    cache = {}
+
+    def basicvsrpp(n: int, f: vs.VideoFrame) -> vs.VideoFrame:
+        nonlocal cache
+
+        if str(n) not in cache:
+            cache.clear()
+
+            imgs = [frame_to_tensor(f[0])]
+            for i in range(1, interval):
+                if (n + i) >= clip.num_frames:
+                    break
+                imgs.append(frame_to_tensor(clip.get_frame(n + i)))
+
+            imgs = torch.stack(imgs)
+            imgs = imgs.unsqueeze(0).to(device)
+            if fp16:
+                imgs = imgs.half()
+
+            with torch.no_grad():
+                if tile_x > 0 and tile_y > 0:
+                    output = tile_process(imgs, scale, tile_x, tile_y, tile_pad, model)
+                else:
+                    output = model(imgs)
+
+            output = output.squeeze(0).detach().cpu().numpy()
+            for i in range(output.shape[0]):
+                cache[str(n + i)] = output[i, :, :, :]
+
+            del imgs
+            torch.cuda.empty_cache()
+
+        return ndarray_to_frame(cache[str(n)], f[1])
+
+    new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale)
+    return new_clip.std.ModifyFrame(clips=[clip, new_clip], selector=basicvsrpp)
+
+
+def frame_to_tensor(f: vs.VideoFrame) -> torch.Tensor:
+    arr = np.stack([np.asarray(f.get_read_array(plane)) for plane in range(f.format.num_planes)])
+    return torch.from_numpy(arr)
+
+
+def ndarray_to_frame(arr: np.ndarray, f: vs.VideoFrame) -> vs.VideoFrame:
+    fout = f.copy()
+    for plane in range(fout.format.num_planes):
+        np.copyto(np.asarray(fout.get_write_array(plane)), arr[plane, :, :])
+    return fout
+
+
+def tile_process(img: torch.Tensor, scale: int, tile_x: int, tile_y: int, tile_pad: int, model: BasicVSR) -> torch.Tensor:
+    batch, num, channel, height, width = img.shape
+    output_height = height * scale
+    output_width = width * scale
+    output_shape = (batch, num, channel, output_height, output_width)
+
+    # start with black image
+    output = img.new_zeros(output_shape)
+
+    tiles_x = math.ceil(width / tile_x)
+    tiles_y = math.ceil(height / tile_y)
+
+    # loop over all tiles
+    for y in range(tiles_y):
+        for x in range(tiles_x):
+            # extract tile from input image
+            ofs_x = x * tile_x
+            ofs_y = y * tile_y
+
+            # input tile area on total image
+            input_start_x = ofs_x
+            input_end_x = min(ofs_x + tile_x, width)
+            input_start_y = ofs_y
+            input_end_y = min(ofs_y + tile_y, height)
+
+            # input tile area on total image with padding
+            input_start_x_pad = max(input_start_x - tile_pad, 0)
+            input_end_x_pad = min(input_end_x + tile_pad, width)
+            input_start_y_pad = max(input_start_y - tile_pad, 0)
+            input_end_y_pad = min(input_end_y + tile_pad, height)
+
+            # input tile dimensions
+            input_tile_width = input_end_x - input_start_x
+            input_tile_height = input_end_y - input_start_y
+
+            input_tile = img[:, :, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+            # upscale tile
+            output_tile = model(input_tile)
+
+            # output tile area on total image
+            output_start_x = input_start_x * scale
+            output_end_x = input_end_x * scale
+            output_start_y = input_start_y * scale
+            output_end_y = input_end_y * scale
+
+            # output tile area without padding
+            output_start_x_tile = (input_start_x - input_start_x_pad) * scale
+            output_end_x_tile = output_start_x_tile + input_tile_width * scale
+            output_start_y_tile = (input_start_y - input_start_y_pad) * scale
+            output_end_y_tile = output_start_y_tile + input_tile_height * scale
+
+            # put tile into output image
+            output[:, :, :, output_start_y:output_end_y, output_start_x:output_end_x] = \
+                output_tile[:, :, :, output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile]
+
+    return output
diff --git a/vsbasicvsrpp/__main__.py b/vsbasicvsrpp/__main__.py
new file mode 100644
index 0000000..c151ac4
--- /dev/null
+++ b/vsbasicvsrpp/__main__.py
@@ -0,0 +1,21 @@
+import os
+import requests
+from tqdm import tqdm
+
+def download_model(url: str) -> None:
+    filename = url.split('/')[-1]
+    r = requests.get(url, stream=True)
+    with open(os.path.join(os.path.dirname(__file__), filename), 'wb') as f:
+        with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename, total=int(r.headers.get('content-length', 0))) as pbar:
+            for chunk in r.iter_content(chunk_size=4096):
+                f.write(chunk)
+                pbar.update(len(chunk))
+
+if __name__ == '__main__':
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_ntire_decompress_track1.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_ntire_decompress_track2.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_ntire_decompress_track3.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_reds4.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_vimeo90k_bd.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/basicvsr_plusplus_vimeo90k_bi.pth')
+    download_model('https://github.com/HolyWu/vs-basicvsrpp/releases/download/model/spynet.pth')
diff --git a/vsbasicvsrpp/basicvsr.py b/vsbasicvsrpp/basicvsr.py
new file mode 100644
index 0000000..4662f54
--- /dev/null
+++ b/vsbasicvsrpp/basicvsr.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .builder import build_backbone
+from .registry import MODELS
+
+
+@MODELS.register_module()
+class BasicVSR(nn.Module):
+    """BasicVSR model for video super-resolution.
+
+    Note that this model is also used for IconVSR.
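+
+    A minimal forward-pass sketch (illustrative shapes only; the generator
+    config mirrors config012.py and weights are assumed to be loaded
+    separately)::
+
+        >>> model = build_model(dict(type='BasicVSR',
+        ...                          generator=dict(type='BasicVSRPlusPlus')))
+        >>> out = model(torch.rand(1, 5, 3, 64, 64))  # (1, 5, 3, 256, 256)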
+
+    Paper:
+        BasicVSR: The Search for Essential Components in Video Super-Resolution
+        and Beyond, CVPR, 2021
+
+    Args:
+        generator (dict): Config for the generator structure.
+    """
+
+    def __init__(self, generator):
+        super().__init__()
+
+        # generator
+        self.generator = build_backbone(generator)
+
+        # count training steps
+        self.register_buffer('step_counter', torch.zeros(1))
+
+    def forward(self, lq):
+        """Testing forward function.
+
+        Args:
+            lq (Tensor): LQ Tensor with shape (n, t, c, h, w).
+
+        Returns:
+            dict: Output results.
+        """
+        with torch.no_grad():
+            output = self.generator(lq)
+
+        return output
diff --git a/vsbasicvsrpp/basicvsr_net.py b/vsbasicvsrpp/basicvsr_net.py
new file mode 100644
index 0000000..1d05fc6
--- /dev/null
+++ b/vsbasicvsrpp/basicvsr_net.py
@@ -0,0 +1,245 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import load_checkpoint
+
+from .flow_warp import flow_warp
+from .sr_backbone_utils import ResidualBlockNoBN, make_layer
+
+
+class ResidualBlocksWithInputConv(nn.Module):
+    """Residual blocks with a convolution in front.
+
+    Args:
+        in_channels (int): Number of input channels of the first conv.
+        out_channels (int): Number of channels of the residual blocks.
+            Default: 64.
+        num_blocks (int): Number of residual blocks. Default: 30.
+    """
+
+    def __init__(self, in_channels, out_channels=64, num_blocks=30):
+        super().__init__()
+
+        main = []
+
+        # a convolution used to match the channels of the residual blocks
+        main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
+        main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
+
+        # residual blocks
+        main.append(
+            make_layer(
+                ResidualBlockNoBN, num_blocks, mid_channels=out_channels))
+
+        self.main = nn.Sequential(*main)
+
+    def forward(self, feat):
+        """
+        Forward function for ResidualBlocksWithInputConv.
+
+        Args:
+            feat (Tensor): Input feature with shape (n, in_channels, h, w)
+
+        Returns:
+            Tensor: Output feature with shape (n, out_channels, h, w)
+        """
+        return self.main(feat)
+
+
+class SPyNet(nn.Module):
+    """SPyNet network structure.
+
+    The differences from the SPyNet in [tof.py] are that
+    1. more SPyNetBasicModules are used in this version, and
+    2. no batch normalization is used in this version.
+
+    Paper:
+        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
+
+    Args:
+        pretrained (str): Path for the pre-trained SPyNet. Default: None.
+    """
+
+    def __init__(self, pretrained):
+        super().__init__()
+
+        self.basic_module = nn.ModuleList(
+            [SPyNetBasicModule() for _ in range(6)])
+
+        if isinstance(pretrained, str):
+            load_checkpoint(self, pretrained, strict=True)
+        elif pretrained is not None:
+            raise TypeError('[pretrained] should be str or None, '
+                            f'but got {type(pretrained)}.')
+
+        self.register_buffer(
+            'mean',
+            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer(
+            'std',
+            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def compute_flow(self, ref, supp):
+        """Compute flow from ref to supp.
+
+        Note that in this function, the images are already resized to a
+        multiple of 32.
+
+        Args:
+            ref (Tensor): Reference image with shape of (n, 3, h, w).
+            supp (Tensor): Supporting image with shape of (n, 3, h, w).
+
+        Returns:
+            Tensor: Estimated optical flow: (n, 2, h, w).
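+
+        A shape sketch (hedged; inputs are assumed to be pre-resized to a
+        multiple of 32)::
+
+            >>> ref = supp = torch.rand(1, 3, 64, 64)
+            >>> flow = spynet.compute_flow(ref, supp)  # (1, 2, 64, 64)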
+ """ + n, _, h, w = ref.size() + + # normalize the input images + ref = [(ref - self.mean) / self.std] + supp = [(supp - self.mean) / self.std] + + # generate downsampled frames + for level in range(5): + ref.append( + F.avg_pool2d( + input=ref[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + supp.append( + F.avg_pool2d( + input=supp[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + ref = ref[::-1] + supp = supp[::-1] + + # flow computation + flow = ref[0].new_zeros(n, 2, h // 32, w // 32) + for level in range(len(ref)): + if level == 0: + flow_up = flow + else: + flow_up = F.interpolate( + input=flow, + scale_factor=2, + mode='bilinear', + align_corners=True) * 2.0 + + # add the residue to the upsampled flow + flow = flow_up + self.basic_module[level]( + torch.cat([ + ref[level], + flow_warp( + supp[level], + flow_up.permute(0, 2, 3, 1), + padding_mode='border'), flow_up + ], 1)) + + return flow + + def forward(self, ref, supp): + """Forward function of SPyNet. + + This function computes the optical flow from ref to supp. + + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + + # upsize to a multiple of 32 + h, w = ref.shape[2:4] + w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) + h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) + ref = F.interpolate( + input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False) + supp = F.interpolate( + input=supp, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + + # compute flow, and resize back to the original resolution + flow = F.interpolate( + input=self.compute_flow(ref, supp), + size=(h, w), + mode='bilinear', + align_corners=False) + + # adjust the flow values + flow[:, 0, :, :] *= float(w) / float(w_up) + flow[:, 1, :, :] *= float(h) / float(h_up) + + return flow + + +class SPyNetBasicModule(nn.Module): + """Basic Module for SPyNet. + + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + """ + + def __init__(self): + super().__init__() + + self.basic_module = nn.Sequential( + ConvModule( + in_channels=8, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=32, + out_channels=64, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=64, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=32, + out_channels=16, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule( + in_channels=16, + out_channels=2, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=None)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. 
+ + Returns: + Tensor: Refined flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) diff --git a/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track1.pth b/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track1.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track2.pth b/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track2.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track3.pth b/vsbasicvsrpp/basicvsr_plusplus_ntire_decompress_track3.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_plusplus_reds4.pth b/vsbasicvsrpp/basicvsr_plusplus_reds4.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_plusplus_vimeo90k_bd.pth b/vsbasicvsrpp/basicvsr_plusplus_vimeo90k_bd.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_plusplus_vimeo90k_bi.pth b/vsbasicvsrpp/basicvsr_plusplus_vimeo90k_bi.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/basicvsr_pp.py b/vsbasicvsrpp/basicvsr_pp.py new file mode 100644 index 0000000..a952ad4 --- /dev/null +++ b/vsbasicvsrpp/basicvsr_pp.py @@ -0,0 +1,401 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import constant_init +from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d + +from .basicvsr_net import ResidualBlocksWithInputConv, SPyNet +from .flow_warp import flow_warp +from .registry import BACKBONES +from .upsample import PixelShufflePack + + +@BACKBONES.register_module() +class BasicVSRPlusPlus(nn.Module): + """BasicVSR++ network structure. + + Support either x4 upsampling or same size output. Since DCN is used in this + model, it can only be used with CUDA enabled. If CUDA is not enabled, + feature alignment will be skipped. + + Paper: + BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation + and Alignment + + Args: + mid_channels (int, optional): Channel number of the intermediate + features. Default: 64. + num_blocks (int, optional): The number of residual blocks in each + propagation branch. Default: 7. + max_residue_magnitude (int): The maximum magnitude of the offset + residue (Eq. 6 in paper). Default: 10. + is_low_res_input (bool, optional): Whether the input is low-resolution + or not. If False, the output resolution is equal to the input + resolution. Default: True. + spynet_pretrained (str, optional): Pre-trained model path of SPyNet. + Default: None. + cpu_cache_length (int, optional): When the length of sequence is larger + than this value, the intermediate features are sent to CPU. This + saves GPU memory, but slows down the inference speed. You can + increase this number if you have a GPU with large memory. + Default: 100. 
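+
+    Example (a rough sketch; pretrained weights are assumed to be loaded
+        separately, e.g. via mmcv.runner.load_checkpoint as in this
+        package's __init__.py)::
+
+        >>> net = BasicVSRPlusPlus(mid_channels=64, num_blocks=7)
+        >>> lqs = torch.rand(1, 5, 3, 64, 64)  # (n, t, c, h, w)
+        >>> out = net(lqs)                     # (1, 5, 3, 256, 256)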
+    """
+
+    def __init__(self,
+                 mid_channels=64,
+                 num_blocks=7,
+                 max_residue_magnitude=10,
+                 is_low_res_input=True,
+                 spynet_pretrained=None,
+                 cpu_cache_length=100):
+
+        super().__init__()
+        self.mid_channels = mid_channels
+        self.is_low_res_input = is_low_res_input
+        self.cpu_cache_length = cpu_cache_length
+
+        # optical flow
+        self.spynet = SPyNet(pretrained=spynet_pretrained)
+
+        # feature extraction module
+        if is_low_res_input:
+            self.feat_extract = ResidualBlocksWithInputConv(3, mid_channels, 5)
+        else:
+            self.feat_extract = nn.Sequential(
+                nn.Conv2d(3, mid_channels, 3, 2, 1),
+                nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1),
+                nn.LeakyReLU(negative_slope=0.1, inplace=True),
+                ResidualBlocksWithInputConv(mid_channels, mid_channels, 5))
+
+        # propagation branches
+        self.deform_align = nn.ModuleDict()
+        self.backbone = nn.ModuleDict()
+        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+        for i, module in enumerate(modules):
+            if torch.cuda.is_available():
+                self.deform_align[module] = SecondOrderDeformableAlignment(
+                    2 * mid_channels,
+                    mid_channels,
+                    3,
+                    padding=1,
+                    deform_groups=16,
+                    max_residue_magnitude=max_residue_magnitude)
+            self.backbone[module] = ResidualBlocksWithInputConv(
+                (2 + i) * mid_channels, mid_channels, num_blocks)
+
+        # upsampling module
+        self.reconstruction = ResidualBlocksWithInputConv(
+            5 * mid_channels, mid_channels, 5)
+        self.upsample1 = PixelShufflePack(
+            mid_channels, mid_channels, 2, upsample_kernel=3)
+        self.upsample2 = PixelShufflePack(
+            mid_channels, 64, 2, upsample_kernel=3)
+        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+        self.img_upsample = nn.Upsample(
+            scale_factor=4, mode='bilinear', align_corners=False)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+        if len(self.deform_align) > 0:
+            self.is_with_alignment = True
+        else:
+            self.is_with_alignment = False
+            warnings.warn(
+                'Deformable alignment module is not added. '
+                'Probably your CUDA is not configured correctly. DCN can only '
+                'be used with CUDA enabled. Alignment is skipped now.')
+
+    def compute_flow(self, lqs):
+        """Compute optical flow using SPyNet for feature alignment.
+
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+        is not needed, since it is equal to 'flows_backward.flip(1)'.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+
+        Return:
+            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
+                flows used for forward-time propagation (current to previous).
+                'flows_backward' corresponds to the flows used for
+                backward-time propagation (current to next).
+        """
+
+        n, t, c, h, w = lqs.size()
+        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
+        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
+        flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
+
+        if self.cpu_cache:
+            flows_backward = flows_backward.cpu()
+            flows_forward = flows_forward.cpu()
+
+        return flows_forward, flows_backward
+
+    def propagate(self, feats, flows, module_name):
+        """Propagate the latent features throughout the sequence.
+
+        Args:
+            feats (dict[list[tensor]]): Features from previous branches. Each
+                component is a list of tensors with shape (n, c, h, w).
+            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
+            module_name (str): The name of the propagation branches. Can either
+                be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
+
+        Return:
+            dict(list[tensor]): A dictionary containing all the propagated
+                features. Each key in the dictionary corresponds to a
+                propagation branch, which is represented by a list of tensors.
+        """
+
+        n, t, _, h, w = flows.size()
+
+        frame_idx = range(0, t + 1)
+        flow_idx = range(-1, t)
+        mapping_idx = list(range(0, len(feats['spatial'])))
+        mapping_idx += mapping_idx[::-1]
+
+        if 'backward' in module_name:
+            frame_idx = frame_idx[::-1]
+            flow_idx = frame_idx
+
+        feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
+        for i, idx in enumerate(frame_idx):
+            feat_current = feats['spatial'][mapping_idx[idx]]
+            if self.cpu_cache:
+                feat_current = feat_current.cuda()
+                feat_prop = feat_prop.cuda()
+            # second-order deformable alignment
+            if i > 0 and self.is_with_alignment:
+                flow_n1 = flows[:, flow_idx[i], :, :, :]
+                if self.cpu_cache:
+                    flow_n1 = flow_n1.cuda()
+
+                cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+                # initialize second-order features
+                feat_n2 = torch.zeros_like(feat_prop)
+                flow_n2 = torch.zeros_like(flow_n1)
+                cond_n2 = torch.zeros_like(cond_n1)
+
+                if i > 1:  # second-order features
+                    feat_n2 = feats[module_name][-2]
+                    if self.cpu_cache:
+                        feat_n2 = feat_n2.cuda()
+
+                    flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+                    if self.cpu_cache:
+                        flow_n2 = flow_n2.cuda()
+
+                    flow_n2 = flow_n1 + flow_warp(flow_n2,
+                                                  flow_n1.permute(0, 2, 3, 1))
+                    cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
+
+                # flow-guided deformable convolution
+                cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+                feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+                feat_prop = self.deform_align[module_name](feat_prop, cond,
+                                                           flow_n1, flow_n2)
+
+            # concatenate and residual blocks
+            feat = [feat_current] + [
+                feats[k][idx]
+                for k in feats if k not in ['spatial', module_name]
+            ] + [feat_prop]
+            if self.cpu_cache:
+                feat = [f.cuda() for f in feat]
+
+            feat = torch.cat(feat, dim=1)
+            feat_prop = feat_prop + self.backbone[module_name](feat)
+            feats[module_name].append(feat_prop)
+
+            if self.cpu_cache:
+                feats[module_name][-1] = feats[module_name][-1].cpu()
+                torch.cuda.empty_cache()
+
+        if 'backward' in module_name:
+            feats[module_name] = feats[module_name][::-1]
+
+        return feats
+
+    def upsample(self, lqs, feats):
+        """Compute the output image given the features.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+            feats (dict): The features from the propagation branches.
+
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+
+        """
+
+        outputs = []
+        num_outputs = len(feats['spatial'])
+
+        mapping_idx = list(range(0, num_outputs))
+        mapping_idx += mapping_idx[::-1]
+
+        for i in range(0, lqs.size(1)):
+            hr = [feats[k].pop(0) for k in feats if k != 'spatial']
+            hr.insert(0, feats['spatial'][mapping_idx[i]])
+            hr = torch.cat(hr, dim=1)
+            if self.cpu_cache:
+                hr = hr.cuda()
+
+            hr = self.reconstruction(hr)
+            hr = self.lrelu(self.upsample1(hr))
+            hr = self.lrelu(self.upsample2(hr))
+            hr = self.lrelu(self.conv_hr(hr))
+            hr = self.conv_last(hr)
+            if self.is_low_res_input:
+                hr += self.img_upsample(lqs[:, i, :, :, :])
+            else:
+                hr += lqs[:, i, :, :, :]
+
+            if self.cpu_cache:
+                hr = hr.cpu()
+                torch.cuda.empty_cache()
+
+            outputs.append(hr)
+
+        return torch.stack(outputs, dim=1)
+
+    def forward(self, lqs):
+        """Forward function for BasicVSR++.
+
+        Args:
+            lqs (tensor): Input low quality (LQ) sequence with
+                shape (n, t, c, h, w).
+
+        Returns:
+            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+        """
+
+        n, t, c, h, w = lqs.size()
+
+        # whether to cache the features in CPU (disabled in this port)
+        self.cpu_cache = False  # originally: t > self.cpu_cache_length
+
+        if self.is_low_res_input:
+            lqs_downsample = lqs.clone()
+        else:
+            lqs_downsample = F.interpolate(
+                lqs.view(-1, c, h, w), scale_factor=0.25,
+                mode='bicubic').view(n, t, c, h // 4, w // 4)
+
+        feats = {}
+        # compute spatial features
+        if self.cpu_cache:
+            feats['spatial'] = []
+            for i in range(0, t):
+                feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
+                feats['spatial'].append(feat)
+                torch.cuda.empty_cache()
+        else:
+            feats_ = self.feat_extract(lqs.view(-1, c, h, w))
+            h, w = feats_.shape[2:]
+            feats_ = feats_.view(n, t, -1, h, w)
+            feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
+
+        # compute optical flow using the low-res inputs
+        assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
+            'The height and width of low-res inputs must be at least 64, '
+            f'but got {lqs_downsample.size(3)} and {lqs_downsample.size(4)}.')
+        flows_forward, flows_backward = self.compute_flow(lqs_downsample)
+
+        # feature propagation
+        for iter_ in [1, 2]:
+            for direction in ['backward', 'forward']:
+                module = f'{direction}_{iter_}'
+
+                feats[module] = []
+
+                if direction == 'backward':
+                    flows = flows_backward
+                elif flows_forward is not None:
+                    flows = flows_forward
+                else:
+                    flows = flows_backward.flip(1)
+
+                feats = self.propagate(feats, flows, module)
+                if self.cpu_cache:
+                    del flows
+                    torch.cuda.empty_cache()
+
+        return self.upsample(lqs, feats)
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
+    """Second-order deformable alignment module.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+        max_residue_magnitude (int): The maximum magnitude of the offset
+            residue (Eq. 6 in paper). Default: 10.
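+
+    A shape sketch (hedged; channel counts follow the 2 * mid_channels
+        convention used by BasicVSRPlusPlus, and the deformable conv op
+        requires CUDA)::
+
+        >>> align = SecondOrderDeformableAlignment(
+        ...     128, 64, 3, padding=1, deform_groups=16).cuda()
+        >>> x = torch.rand(1, 128, 64, 64).cuda()     # cat(feat_prop, feat_n2)
+        >>> cond = torch.rand(1, 192, 64, 64).cuda()  # cond_n1 + feat_current + cond_n2
+        >>> flow = torch.rand(1, 2, 64, 64).cuda()
+        >>> out = align(x, cond, flow, flow)          # (1, 64, 64, 64)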
+ + """ + + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat, flow_1, flow_2): + extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh( + torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset_1 = offset_1 + flow_1.flip(1).repeat(1, + offset_1.size(1) // 2, 1, + 1) + offset_2 = offset_2 + flow_2.flip(1).repeat(1, + offset_2.size(1) // 2, 1, + 1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) diff --git a/vsbasicvsrpp/builder.py b/vsbasicvsrpp/builder.py new file mode 100644 index 0000000..1256dee --- /dev/null +++ b/vsbasicvsrpp/builder.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmcv import build_from_cfg + +from .registry import BACKBONES, COMPONENTS, LOSSES, MODELS + + +def build(cfg, registry, default_args=None): + """Build module function. + + Args: + cfg (dict): Configuration for building modules. + registry (obj): ``registry`` object. + default_args (dict, optional): Default arguments. Defaults to None. + """ + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return nn.Sequential(*modules) + + return build_from_cfg(cfg, registry, default_args) + + +def build_backbone(cfg): + """Build backbone. + + Args: + cfg (dict): Configuration for building backbone. + """ + return build(cfg, BACKBONES) + + +def build_component(cfg): + """Build component. + + Args: + cfg (dict): Configuration for building component. + """ + return build(cfg, COMPONENTS) + + +def build_loss(cfg): + """Build loss. + + Args: + cfg (dict): Configuration for building loss. + """ + return build(cfg, LOSSES) + + +def build_model(cfg): + """Build model. + + Args: + cfg (dict): Configuration for building model. 
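+
+    Example (mirrors the structure of config012.py)::
+
+        >>> cfg = dict(type='BasicVSR',
+        ...            generator=dict(type='BasicVSRPlusPlus'))
+        >>> model = build_model(cfg)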
+ """ + return build(cfg, MODELS) diff --git a/vsbasicvsrpp/config012.py b/vsbasicvsrpp/config012.py new file mode 100644 index 0000000..ace90bc --- /dev/null +++ b/vsbasicvsrpp/config012.py @@ -0,0 +1,14 @@ +import importlib.resources + +with importlib.resources.path('vsbasicvsrpp', 'spynet.pth') as p: + spynet_path = str(p) + +# model settings +model = dict( + type='BasicVSR', + generator=dict( + type='BasicVSRPlusPlus', + mid_channels=64, + num_blocks=7, + is_low_res_input=True, + spynet_pretrained=spynet_path)) diff --git a/vsbasicvsrpp/config345.py b/vsbasicvsrpp/config345.py new file mode 100644 index 0000000..a6a4979 --- /dev/null +++ b/vsbasicvsrpp/config345.py @@ -0,0 +1,14 @@ +import importlib.resources + +with importlib.resources.path('vsbasicvsrpp', 'spynet.pth') as p: + spynet_path = str(p) + +# model settings +model = dict( + type='BasicVSR', + generator=dict( + type='BasicVSRPlusPlus', + mid_channels=128, + num_blocks=25, + is_low_res_input=False, + spynet_pretrained=spynet_path)) diff --git a/vsbasicvsrpp/flow_warp.py b/vsbasicvsrpp/flow_warp.py new file mode 100644 index 0000000..b38f4bc --- /dev/null +++ b/vsbasicvsrpp/flow_warp.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F + + +def flow_warp(x, + flow, + interpolation='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or a feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is + a two-channel, denoting the width and height relative offsets. + Note that the values are not normalized to [-1, 1]. + interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. + Default: 'bilinear'. + padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Whether align corners. Default: True. + + Returns: + Tensor: Warped image or feature map. + """ + if x.size()[-2:] != flow.size()[1:3]: + raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' + f'flow ({flow.size()[1:3]}) are not the same.') + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) + grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) + grid.requires_grad = False + + grid_flow = grid + flow + # scale grid_flow to [-1,1] + grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 + grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) + output = F.grid_sample( + x, + grid_flow, + mode=interpolation, + padding_mode=padding_mode, + align_corners=align_corners) + return output diff --git a/vsbasicvsrpp/registry.py b/vsbasicvsrpp/registry.py new file mode 100644 index 0000000..41fae9d --- /dev/null +++ b/vsbasicvsrpp/registry.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import Registry + +MODELS = Registry('model') +BACKBONES = Registry('backbone') +COMPONENTS = Registry('component') +LOSSES = Registry('loss') diff --git a/vsbasicvsrpp/spynet.pth b/vsbasicvsrpp/spynet.pth new file mode 100644 index 0000000..e69de29 diff --git a/vsbasicvsrpp/sr_backbone_utils.py b/vsbasicvsrpp/sr_backbone_utils.py new file mode 100644 index 0000000..b4b0aad --- /dev/null +++ b/vsbasicvsrpp/sr_backbone_utils.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch.nn as nn
+from mmcv.cnn import constant_init, kaiming_init
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+
+def default_init_weights(module, scale=1):
+    """Initialize network weights.
+
+    Args:
+        module (nn.Module): Module to be initialized.
+        scale (float): Scale initialized weights, especially for residual
+            blocks. Default: 1.
+    """
+    for m in module.modules():
+        if isinstance(m, nn.Conv2d):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+        elif isinstance(m, nn.Linear):
+            kaiming_init(m, a=0, mode='fan_in', bias=0)
+            m.weight.data *= scale
+        elif isinstance(m, _BatchNorm):
+            constant_init(m.weight, val=1, bias=0)
+
+
+def make_layer(block, num_blocks, **kwargs):
+    """Make layers by stacking the same blocks.
+
+    Args:
+        block (nn.Module): nn.Module class for the basic block.
+        num_blocks (int): Number of blocks.
+
+    Returns:
+        nn.Sequential: Stacked blocks in nn.Sequential.
+    """
+    layers = []
+    for _ in range(num_blocks):
+        layers.append(block(**kwargs))
+    return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+    """Residual block without BN.
+
+    It has a style of:
+
+    ::
+
+        ---Conv-ReLU-Conv-+-
+         |________________|
+
+    Args:
+        mid_channels (int): Channel number of intermediate features.
+            Default: 64.
+        res_scale (float): Used to scale the residual before addition.
+            Default: 1.0.
+    """
+
+    def __init__(self, mid_channels=64, res_scale=1.0):
+        super().__init__()
+        self.res_scale = res_scale
+        self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
+        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
+
+        self.relu = nn.ReLU(inplace=True)
+
+        # if res_scale < 1.0, use the default initialization, as in EDSR.
+        # if res_scale = 1.0, use scaled kaiming_init, as in MSRResNet.
+        if res_scale == 1.0:
+            self.init_weights()
+
+    def init_weights(self):
+        """Initialize weights for ResidualBlockNoBN.
+
+        Initialization methods like `kaiming_init` are for VGG-style
+        modules. For modules with residual paths, using smaller std is
+        better for stability and performance. We empirically use 0.1.
+        See more details in "ESRGAN: Enhanced Super-Resolution Generative
+        Adversarial Networks".
+        """
+
+        for m in [self.conv1, self.conv2]:
+            default_init_weights(m, 0.1)
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (Tensor): Input tensor with shape (n, c, h, w).
+
+        Returns:
+            Tensor: Forward results.
+        """
+
+        identity = x
+        out = self.conv2(self.relu(self.conv1(x)))
+        return identity + out * self.res_scale
diff --git a/vsbasicvsrpp/upsample.py b/vsbasicvsrpp/upsample.py
new file mode 100644
index 0000000..f39ec1a
--- /dev/null
+++ b/vsbasicvsrpp/upsample.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .sr_backbone_utils import default_init_weights
+
+
+class PixelShufflePack(nn.Module):
+    """Pixel Shuffle upsample layer.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        scale_factor (int): Upsample ratio.
+        upsample_kernel (int): Kernel size of the Conv layer used to expand channels.
+
+    Returns:
+        Upsampled feature map.
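+
+    Example (illustrative shapes)::
+
+        >>> import torch
+        >>> up = PixelShufflePack(64, 64, scale_factor=2, upsample_kernel=3)
+        >>> out = up(torch.rand(1, 64, 32, 32))  # (1, 64, 64, 64)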
+ """ + + def __init__(self, in_channels, out_channels, scale_factor, + upsample_kernel): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + self.upsample_kernel = upsample_kernel + self.upsample_conv = nn.Conv2d( + self.in_channels, + self.out_channels * scale_factor * scale_factor, + self.upsample_kernel, + padding=(self.upsample_kernel - 1) // 2) + self.init_weights() + + def init_weights(self): + """Initialize weights for PixelShufflePack. + """ + default_init_weights(self, 1) + + def forward(self, x): + """Forward function for PixelShufflePack. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + x = self.upsample_conv(x) + x = F.pixel_shuffle(x, self.scale_factor) + return x