diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
diff --git a/README.md b/README.md
index 2c77fe9..764e56b 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,32 @@
-# vs-gmfss_fortuna
-GMFSS_Fortuna function for VapourSynth
+# GMFSS_Fortuna
+The All-In-One GMFSS: Dedicated for Anime Video Frame Interpolation, based on https://github.com/98mxr/GMFSS_Fortuna.
+
+
+## Dependencies
+- [CuPy](https://docs.cupy.dev/en/stable/install.html)
+- [NumPy](https://numpy.org/install)
+- [PyTorch](https://pytorch.org/get-started) 1.13.1
+- [VapourSynth](http://www.vapoursynth.com/) R55+
+
+`trt` requires additional runtime libraries:
+- [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit) 11.7
+- [cuDNN](https://developer.nvidia.com/cudnn) 8.6
+- [TensorRT](https://developer.nvidia.com/tensorrt) 8.6.0.12
+
+For ease of installation on Windows, you can download the 7z file from [Releases](https://github.com/HolyWu/vs-gmfss_fortuna/releases), which contains the required runtime libraries and the Python wheel file. Unzip the file to a location of your choice. Add `/bin` to your system `PATH`. Additionally, add `` to the `CUDA_PATH` environment variable because CuPy relies on it. Finally, pip install the Python wheel file.
+
+
+## Installation
+```
+pip install -U vsgmfss-fortuna
+```
+
+
+## Usage
+```python
+from vsgmfss_fortuna import gmfss_fortuna
+
+ret = gmfss_fortuna(clip)
+```
+
+See `__init__.py` for a description of the parameters.
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c3014b2
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "vsgmfss_fortuna"
+version = "1.0.0"
+description = "GMFSS_Fortuna function for VapourSynth"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {file = "LICENSE"}
+authors = [{name = "HolyWu", email = "holywu@gmail.com"}]
+keywords = ["GMFSS_Fortuna", "VapourSynth"]
+classifiers = [
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3.10",
+    "Topic :: Multimedia :: Video"
+]
+dependencies = [
+    "cupy-cuda11x>=12.0.0",
+    "numpy>=1.24.2",
+    "tensorrt>=8.6.0",
+    "torch>=1.13.1",
+    "torch-tensorrt-fx-only>=1.4.0.dev0",
+    "VapourSynth>=55"
+]
+
+[project.urls]
+"Homepage" = "https://github.com/HolyWu/vs-gmfss_fortuna"
+"Bug Tracker" = "https://github.com/HolyWu/vs-gmfss_fortuna/issues"
diff --git a/vsgmfss_fortuna/FeatureNet.py b/vsgmfss_fortuna/FeatureNet.py
new file mode 100644
index 0000000..b6efaf4
--- /dev/null
+++ b/vsgmfss_fortuna/FeatureNet.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+
+from .util import MyPReLU
+
+
+class FeatureNet(nn.Module):
+    """Three-level feature pyramid extractor (64/128/192 channels at 1/2, 1/4, 1/8 resolution)"""
+    def __init__(self):
+        super(FeatureNet, self).__init__()
+        self.block1 = nn.Sequential(
+            MyPReLU(),
+            nn.Conv2d(3, 64, 3, 2, 1),
+            MyPReLU(),
+            nn.Conv2d(64, 64, 3, 1, 1),
+        )
+        self.block2 = nn.Sequential(
+            MyPReLU(),
+            nn.Conv2d(64, 128, 3, 2, 1),
+            MyPReLU(),
+            nn.Conv2d(128, 128, 3, 1, 1),
+        )
+        self.block3 = nn.Sequential(
+            MyPReLU(),
+            nn.Conv2d(128, 192, 3, 2, 1),
+            MyPReLU(),
+            nn.Conv2d(192, 192, 3, 1, 1),
+        )
+
+    def forward(self, x):
+        x1 = self.block1(x)
+        x2 = self.block2(x1)
+        x3 = self.block3(x2)
+
+        return x1, x2, x3
diff --git a/vsgmfss_fortuna/FusionNet_b.py b/vsgmfss_fortuna/FusionNet_b.py
new file mode 100644
index 0000000..5c3448e
--- /dev/null
+++ b/vsgmfss_fortuna/FusionNet_b.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .util import MyPixelShuffle,
MyPReLU + + +# Residual Block +def ResidualBlock(in_channels, out_channels, stride=1): + return torch.nn.Sequential( + MyPReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) + ) + + +# downsample block +def DownsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + MyPReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True) + ) + + +# upsample block +def UpsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + MyPReLU(), + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True) + ) + + +class PixelShuffleBlcok(nn.Module): + def __init__(self, in_feat, num_feat, num_out_ch): + super(PixelShuffleBlcok, self).__init__() + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(in_feat, num_feat, 3, 1, 1), + MyPReLU() + ) + self.upsample = nn.Sequential( + nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), + MyPixelShuffle(2) + ) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + return x + +# grid network +class GridNet(nn.Module): + def __init__(self, in_channels=12, in_channels1=128, in_channels2=256, in_channels3=384, out_channels=3): + super(GridNet, self).__init__() + + self.residual_model_head = ResidualBlock(in_channels, 64) + self.residual_model_head1 = ResidualBlock(in_channels1, 64) + self.residual_model_head2 = ResidualBlock(in_channels2, 128) + self.residual_model_head3 = ResidualBlock(in_channels3, 192) + + self.residual_model_01=ResidualBlock(64, 64) + #self.residual_model_02=ResidualBlock(64, 64) + #self.residual_model_03=ResidualBlock(64, 64) + self.residual_model_04=ResidualBlock(64, 64) + self.residual_model_05=ResidualBlock(64, 64) + self.residual_model_tail=PixelShuffleBlcok(64, 64, out_channels) + + + self.residual_model_11=ResidualBlock(128, 128) + #self.residual_model_12=ResidualBlock(128, 128) + #self.residual_model_13=ResidualBlock(128, 128) + self.residual_model_14=ResidualBlock(128, 128) + self.residual_model_15=ResidualBlock(128, 128) + + self.residual_model_21=ResidualBlock(192, 192) + #self.residual_model_22=ResidualBlock(192, 192) + #self.residual_model_23=ResidualBlock(192, 192) + self.residual_model_24=ResidualBlock(192, 192) + self.residual_model_25=ResidualBlock(192, 192) + + # + + self.downsample_model_10=DownsampleBlock(64, 128) + self.downsample_model_20=DownsampleBlock(128, 192) + + self.downsample_model_11=DownsampleBlock(64, 128) + self.downsample_model_21=DownsampleBlock(128, 192) + + #self.downsample_model_12=DownsampleBlock(64, 128) + #self.downsample_model_22=DownsampleBlock(128, 192) + + # + + #self.upsample_model_03=UpsampleBlock(128, 64) + #self.upsample_model_13=UpsampleBlock(192, 128) + + self.upsample_model_04=UpsampleBlock(128, 64) + self.upsample_model_14=UpsampleBlock(192, 128) + + self.upsample_model_05=UpsampleBlock(128, 64) + self.upsample_model_15=UpsampleBlock(192, 128) + + def forward(self, x, x1, x2, x3): + X00=self.residual_model_head(x) + self.residual_model_head1(x1) #--- 182 ~ 185 + # X10 = self.residual_model_head1(x1) + + 
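        # GridNet dataflow: row 0 carries 64-channel features, row 1 128-channel
+        # at half resolution, row 2 192-channel at quarter resolution; moving
+        # right applies ResidualBlocks (with identity skips), moving down applies
+        # DownsampleBlocks, moving up applies UpsampleBlocks, and the
+        # PixelShuffle tail decodes row 0 back to the output channels.
+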
X01=self.residual_model_01(X00) + X00#--- 208 ~ 211 ,AddBackward1213 + + X10=self.downsample_model_10(X00) + self.residual_model_head2(x2) #--- 186 ~ 189 + X20=self.downsample_model_20(X10) + self.residual_model_head3(x3) #--- 190 ~ 193 + + residual_11=self.residual_model_11(X10) + X10 #201 ~ 204 , sum AddBackward1206 + downsample_11=self.downsample_model_11(X01) #214 ~ 217 + X11=residual_11 + downsample_11 #--- AddBackward1218 + + residual_21=self.residual_model_21(X20) + X20 #194 ~ 197 , sum AddBackward1199 + downsample_21=self.downsample_model_21(X11) #219 ~ 222 + X21=residual_21 + downsample_21 # AddBackward1223 + + + X24=self.residual_model_24(X21) + X21 #--- 224 ~ 227 , AddBackward1229 + X25=self.residual_model_25(X24) + X24 #--- 230 ~ 233 , AddBackward1235 + + + upsample_14=self.upsample_model_14(X24) #242 ~ 246 + residual_14=self.residual_model_14(X11) + X11 #248 ~ 251, AddBackward1253 + X14=upsample_14 + residual_14 #--- AddBackward1254 + + upsample_04=self.upsample_model_04(X14) #268 ~ 272 + residual_04=self.residual_model_04(X01) + X01 #274 ~ 277, AddBackward1279 + X04=upsample_04 + residual_04 #--- AddBackward1280 + + upsample_15=self.upsample_model_15(X25) #236 ~ 240 + residual_15=self.residual_model_15(X14) + X14 #255 ~ 258, AddBackward1260 + X15=upsample_15 + residual_15 # AddBackward1261 + + upsample_05=self.upsample_model_05(X15) # 262 ~ 266 + residual_05=self.residual_model_05(X04) + X04 #281 ~ 284,AddBackward1286 + X05=upsample_05 + residual_05 # AddBackward1287 + + X_tail=self.residual_model_tail(X05) #288 ~ 291 + + return X_tail diff --git a/vsgmfss_fortuna/FusionNet_u.py b/vsgmfss_fortuna/FusionNet_u.py new file mode 100644 index 0000000..ec0be7e --- /dev/null +++ b/vsgmfss_fortuna/FusionNet_u.py @@ -0,0 +1,148 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .util import MyPixelShuffle, MyPReLU + + +# Residual Block +def ResidualBlock(in_channels, out_channels, stride=1): + return torch.nn.Sequential( + MyPReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) + ) + + +# downsample block +def DownsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + MyPReLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True) + ) + + +# upsample block +def UpsampleBlock(in_channels, out_channels, stride=2): + return torch.nn.Sequential( + MyPReLU(), + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=True), + MyPReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True) + ) + + +class PixelShuffleBlcok(nn.Module): + def __init__(self, in_feat, num_feat, num_out_ch): + super(PixelShuffleBlcok, self).__init__() + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(in_feat, num_feat, 3, 1, 1), + MyPReLU() + ) + self.upsample = nn.Sequential( + nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), + MyPixelShuffle(2) + ) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + def forward(self, x): + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + return x + +# grid network +class GridNet(nn.Module): + def __init__(self, in_channels=9, in_channels1=128, in_channels2=256, in_channels3=384, out_channels=3): + super(GridNet, self).__init__() + 
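+        # Same grid as FusionNet_b.GridNet, but the head fuses 9 input channels
+        # (I1t, the RIFE-merged frame, I2t; see GMFSS.forward) instead of the
+        # base model's 12 (img0, I1t, I2t, img1).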
+ self.residual_model_head0 = ResidualBlock(in_channels, 64) + self.residual_model_head1 = ResidualBlock(in_channels1, 64) + self.residual_model_head2 = ResidualBlock(in_channels2, 128) + self.residual_model_head3 = ResidualBlock(in_channels3, 192) + + self.residual_model_01=ResidualBlock(64, 64) + #self.residual_model_02=ResidualBlock(64, 64) + #self.residual_model_03=ResidualBlock(64, 64) + self.residual_model_04=ResidualBlock(64, 64) + self.residual_model_05=ResidualBlock(64, 64) + self.residual_model_tail=PixelShuffleBlcok(64, 64, out_channels) + + + self.residual_model_11=ResidualBlock(128, 128) + #self.residual_model_12=ResidualBlock(128, 128) + #self.residual_model_13=ResidualBlock(128, 128) + self.residual_model_14=ResidualBlock(128, 128) + self.residual_model_15=ResidualBlock(128, 128) + + self.residual_model_21=ResidualBlock(192, 192) + #self.residual_model_22=ResidualBlock(192, 192) + #self.residual_model_23=ResidualBlock(192, 192) + self.residual_model_24=ResidualBlock(192, 192) + self.residual_model_25=ResidualBlock(192, 192) + + # + + self.downsample_model_10=DownsampleBlock(64, 128) + self.downsample_model_20=DownsampleBlock(128, 192) + + self.downsample_model_11=DownsampleBlock(64, 128) + self.downsample_model_21=DownsampleBlock(128, 192) + + #self.downsample_model_12=DownsampleBlock(64, 128) + #self.downsample_model_22=DownsampleBlock(128, 192) + + # + + #self.upsample_model_03=UpsampleBlock(128, 64) + #self.upsample_model_13=UpsampleBlock(192, 128) + + self.upsample_model_04=UpsampleBlock(128, 64) + self.upsample_model_14=UpsampleBlock(192, 128) + + self.upsample_model_05=UpsampleBlock(128, 64) + self.upsample_model_15=UpsampleBlock(192, 128) + + def forward(self, x, x1, x2, x3): + X00=self.residual_model_head0(x) + self.residual_model_head1(x1) #--- 182 ~ 185 + # X10 = self.residual_model_head1(x1) + + X01=self.residual_model_01(X00) + X00#--- 208 ~ 211 ,AddBackward1213 + + X10=self.downsample_model_10(X00) + self.residual_model_head2(x2) #--- 186 ~ 189 + X20=self.downsample_model_20(X10) + self.residual_model_head3(x3) #--- 190 ~ 193 + + residual_11=self.residual_model_11(X10) + X10 #201 ~ 204 , sum AddBackward1206 + downsample_11=self.downsample_model_11(X01) #214 ~ 217 + X11=residual_11 + downsample_11 #--- AddBackward1218 + + residual_21=self.residual_model_21(X20) + X20 #194 ~ 197 , sum AddBackward1199 + downsample_21=self.downsample_model_21(X11) #219 ~ 222 + X21=residual_21 + downsample_21 # AddBackward1223 + + + X24=self.residual_model_24(X21) + X21 #--- 224 ~ 227 , AddBackward1229 + X25=self.residual_model_25(X24) + X24 #--- 230 ~ 233 , AddBackward1235 + + + upsample_14=self.upsample_model_14(X24) #242 ~ 246 + residual_14=self.residual_model_14(X11) + X11 #248 ~ 251, AddBackward1253 + X14=upsample_14 + residual_14 #--- AddBackward1254 + + upsample_04=self.upsample_model_04(X14) #268 ~ 272 + residual_04=self.residual_model_04(X01) + X01 #274 ~ 277, AddBackward1279 + X04=upsample_04 + residual_04 #--- AddBackward1280 + + upsample_15=self.upsample_model_15(X25) #236 ~ 240 + residual_15=self.residual_model_15(X14) + X14 #255 ~ 258, AddBackward1260 + X15=upsample_15 + residual_15 # AddBackward1261 + + upsample_05=self.upsample_model_05(X15) # 262 ~ 266 + residual_05=self.residual_model_05(X04) + X04 #281 ~ 284,AddBackward1286 + X05=upsample_05 + residual_05 # AddBackward1287 + + X_tail=self.residual_model_tail(X05) #288 ~ 291 + + return X_tail diff --git a/vsgmfss_fortuna/GMFSS.py b/vsgmfss_fortuna/GMFSS.py new file mode 100644 index 0000000..4a25312 --- /dev/null 
+++ b/vsgmfss_fortuna/GMFSS.py @@ -0,0 +1,99 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .FeatureNet import FeatureNet +from .gmflow.gmflow import GMFlow +from .IFNet_HDv3 import IFNet +from .MetricNet import MetricNet +from .softsplat import softsplat as warp + +torch.fx.wrap('warp') + + +class GMFSS(nn.Module): + def __init__(self, model_dir, model_type, scale, ensemble): + super(GMFSS, self).__init__() + if model_type == 'base': + from .FusionNet_b import GridNet + else: + from .FusionNet_u import GridNet + self.ifnet = IFNet(ensemble) + self.ifnet.load_state_dict(torch.load(os.path.join(model_dir, 'rife.pkl'), map_location='cpu')) + self.flownet = GMFlow() + self.metricnet = MetricNet() + self.feat_ext = FeatureNet() + self.fusionnet = GridNet() + self.flownet.load_state_dict(torch.load(os.path.join(model_dir, 'flownet.pkl'), map_location='cpu')) + self.metricnet.load_state_dict(torch.load(os.path.join(model_dir, f'metric_{model_type}.pkl'), map_location='cpu')) + self.feat_ext.load_state_dict(torch.load(os.path.join(model_dir, f'feat_{model_type}.pkl'), map_location='cpu')) + self.fusionnet.load_state_dict(torch.load(os.path.join(model_dir, f'fusionnet_{model_type}.pkl'), map_location='cpu')) + self.model_type = model_type + self.scale = scale + + def reuse(self, img0, img1): + feat11, feat12, feat13 = self.feat_ext(img0) + feat21, feat22, feat23 = self.feat_ext(img1) + feat_ext0 = [feat11, feat12, feat13] + feat_ext1 = [feat21, feat22, feat23] + + img0 = F.interpolate(img0, scale_factor = 0.5, mode="bilinear") + img1 = F.interpolate(img1, scale_factor = 0.5, mode="bilinear") + + if self.scale != 1.0: + imgf0 = F.interpolate(img0, scale_factor = self.scale, mode="bilinear") + imgf1 = F.interpolate(img1, scale_factor = self.scale, mode="bilinear") + else: + imgf0 = img0 + imgf1 = img1 + flow01 = self.flownet(imgf0, imgf1) + flow10 = self.flownet(imgf1, imgf0) + if self.scale != 1.0: + flow01 = F.interpolate(flow01, scale_factor = 1. / self.scale, mode="bilinear") / self.scale + flow10 = F.interpolate(flow10, scale_factor = 1. 
/ self.scale, mode="bilinear") / self.scale + + metric0, metric1 = self.metricnet(img0, img1, flow01, flow10) + + return flow01, flow10, metric0, metric1, feat_ext0, feat_ext1 + + def forward(self, img0, img1, timestep): + reuse_things = self.reuse(img0, img1) + flow01, metric0, feat11, feat12, feat13 = reuse_things[0], reuse_things[2], reuse_things[4][0], reuse_things[4][1], reuse_things[4][2] + flow10, metric1, feat21, feat22, feat23 = reuse_things[1], reuse_things[3], reuse_things[5][0], reuse_things[5][1], reuse_things[5][2] + + F1t = timestep * flow01 + F2t = (1-timestep) * flow10 + + Z1t = timestep * metric0 + Z2t = (1-timestep) * metric1 + + img0 = F.interpolate(img0, scale_factor = 0.5, mode="bilinear") + I1t = warp(img0, F1t, Z1t, strMode='soft') + img1 = F.interpolate(img1, scale_factor = 0.5, mode="bilinear") + I2t = warp(img1, F2t, Z2t, strMode='soft') + + if self.model_type == 'union': + rife = self.ifnet(img0, img1, timestep) + + feat1t1 = warp(feat11, F1t, Z1t, strMode='soft') + feat2t1 = warp(feat21, F2t, Z2t, strMode='soft') + + F1td = F.interpolate(F1t, scale_factor = 0.5, mode="bilinear") * 0.5 + Z1d = F.interpolate(Z1t, scale_factor = 0.5, mode="bilinear") + feat1t2 = warp(feat12, F1td, Z1d, strMode='soft') + F2td = F.interpolate(F2t, scale_factor = 0.5, mode="bilinear") * 0.5 + Z2d = F.interpolate(Z2t, scale_factor = 0.5, mode="bilinear") + feat2t2 = warp(feat22, F2td, Z2d, strMode='soft') + + F1tdd = F.interpolate(F1t, scale_factor = 0.25, mode="bilinear") * 0.25 + Z1dd = F.interpolate(Z1t, scale_factor = 0.25, mode="bilinear") + feat1t3 = warp(feat13, F1tdd, Z1dd, strMode='soft') + F2tdd = F.interpolate(F2t, scale_factor = 0.25, mode="bilinear") * 0.25 + Z2dd = F.interpolate(Z2t, scale_factor = 0.25, mode="bilinear") + feat2t3 = warp(feat23, F2tdd, Z2dd, strMode='soft') + + out = self.fusionnet(torch.cat([img0, I1t, I2t, img1] if self.model_type == 'base' else [I1t, rife, I2t], dim=1), torch.cat([feat1t1, feat2t1], dim=1), torch.cat([feat1t2, feat2t2], dim=1), torch.cat([feat1t3, feat2t3], dim=1)) + + return torch.clamp(out, 0, 1) diff --git a/vsgmfss_fortuna/IFNet_HDv3.py b/vsgmfss_fortuna/IFNet_HDv3.py new file mode 100644 index 0000000..cf663e1 --- /dev/null +++ b/vsgmfss_fortuna/IFNet_HDv3.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .util import MyPixelShuffle +from .warplayer import warp + +torch.fx.wrap('warp') + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + nn.LeakyReLU(0.2, True) + ) + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c//2, 3, 2, 1), + conv(c//2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential( + nn.ConvTranspose2d(c, 4*6, 4, 2, 1), + MyPixelShuffle(2) + ) + + def forward(self, x, flow=None, scale=1): + x 
= F.interpolate(x, scale_factor= 1. / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear") * 1. / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear") + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + return flow, mask + +class IFNet(nn.Module): + def __init__(self, ensemble=False): + super(IFNet, self).__init__() + self.block0 = IFBlock(7, c=192) + self.block1 = IFBlock(8+4, c=128) + self.block2 = IFBlock(8+4, c=96) + self.block3 = IFBlock(8+4, c=64) + self.scale_list = [8, 4, 2, 1] + self.ensemble = ensemble + + def forward(self, img0, img1, timestep): + timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3]) + flow = None + block = [self.block0, self.block1, self.block2, self.block3] + for i in range(4): + if flow is None: + flow, mask = block[i](torch.cat((img0[:, :3], img1[:, :3], timestep), 1), None, scale=self.scale_list[i]) + if self.ensemble: + f1, m1 = block[i](torch.cat((img1[:, :3], img0[:, :3], 1-timestep), 1), None, scale=self.scale_list[i]) + flow = (flow + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = (mask + (-m1)) / 2 + else: + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], timestep, mask), 1), flow, scale=self.scale_list[i]) + if self.ensemble: + f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], 1-timestep, -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=self.scale_list[i]) + f0 = (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + m0 = (m0 + (-m1)) / 2 + flow = flow + f0 + mask = mask + m0 + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + mask = torch.sigmoid(mask) + merged = warped_img0 * mask + warped_img1 * (1 - mask) + return merged diff --git a/vsgmfss_fortuna/MetricNet.py b/vsgmfss_fortuna/MetricNet.py new file mode 100644 index 0000000..e2a5838 --- /dev/null +++ b/vsgmfss_fortuna/MetricNet.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gmflow.geometry import forward_backward_consistency_check +from .util import MyPReLU + +torch.fx.wrap('backwarp') +torch.fx.wrap('forward_backward_consistency_check') + +backwarp_tenGrid = {} + + +def backwarp(tenIn, tenflow): + if str(tenflow.shape) not in backwarp_tenGrid: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[3], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, 1, -1).repeat(1, 1, tenflow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenflow.shape[2], dtype=tenflow.dtype, device=tenflow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenflow.shape[3]) + + backwarp_tenGrid[str(tenflow.shape)] = torch.cat([tenHor, tenVer], 1) + # end + + tenflow = torch.cat([tenflow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenflow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1) + + return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenflow.shape)] + tenflow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) + + +class MetricNet(nn.Module): + def __init__(self): + super(MetricNet, self).__init__() + self.metric_in = nn.Conv2d(14, 64, 3, 1, 1) + self.metric_net1 = nn.Sequential( + MyPReLU(), + nn.Conv2d(64, 64, 3, 1, 1) + ) + self.metric_net2 = nn.Sequential( + MyPReLU(), + nn.Conv2d(64, 64, 3, 1, 1) + ) + self.metric_net3 = nn.Sequential( + MyPReLU(), + nn.Conv2d(64, 64, 3, 1, 1) + ) + 
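        # metric_out produces two maps which, after the tanh * 10 squashing in
+        # forward, become the per-pixel importance metrics (Z) that GMFSS.forward
+        # scales by timestep and passes as the weights for softmax splatting.
+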
        self.metric_out = nn.Sequential(
+            MyPReLU(),
+            nn.Conv2d(64, 2, 3, 1, 1)
+        )
+
+    def forward(self, img0, img1, flow01, flow10):
+        metric0 = F.l1_loss(img0, backwarp(img1, flow01), reduction='none').mean([1], True)
+        metric1 = F.l1_loss(img1, backwarp(img0, flow10), reduction='none').mean([1], True)
+
+        fwd_occ, bwd_occ = forward_backward_consistency_check(flow01, flow10)
+
+        flow01 = torch.cat([flow01[:, 0:1, :, :] / ((flow01.shape[3] - 1.0) / 2.0), flow01[:, 1:2, :, :] / ((flow01.shape[2] - 1.0) / 2.0)], 1)
+        flow10 = torch.cat([flow10[:, 0:1, :, :] / ((flow10.shape[3] - 1.0) / 2.0), flow10[:, 1:2, :, :] / ((flow10.shape[2] - 1.0) / 2.0)], 1)
+
+        img = torch.cat((img0, img1), 1)
+        metric = torch.cat((-metric0, -metric1), 1)
+        flow = torch.cat((flow01, flow10), 1)
+        occ = torch.cat((fwd_occ.unsqueeze(1), bwd_occ.unsqueeze(1)), 1)
+
+        feat = self.metric_in(torch.cat((img, metric, flow, occ), 1))
+        feat = self.metric_net1(feat) + feat
+        feat = self.metric_net2(feat) + feat
+        feat = self.metric_net3(feat) + feat
+        metric = self.metric_out(feat)
+
+        metric = torch.tanh(metric) * 10
+
+        return metric[:, :1], metric[:, 1:2]
diff --git a/vsgmfss_fortuna/__init__.py b/vsgmfss_fortuna/__init__.py
new file mode 100644
index 0000000..c6d1a1f
--- /dev/null
+++ b/vsgmfss_fortuna/__init__.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import math
+import os
+from fractions import Fraction
+from threading import Lock
+
+import numpy as np
+import tensorrt
+import torch
+import torch.nn.functional as F
+import vapoursynth as vs
+from torch.nn import InstanceNorm2d
+from torch_tensorrt.fx import LowerSetting
+from torch_tensorrt.fx.lower import Lowerer
+from torch_tensorrt.fx.utils import LowerPrecision
+
+from .gmflow.transformer import FeatureTransformer
+from .GMFSS import GMFSS
+
+__version__ = "1.0.0"
+
+os.environ["CUDA_MODULE_LOADING"] = "LAZY"
+
+model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models")
+
+
+@torch.inference_mode()
+def gmfss_fortuna(
+    clip: vs.VideoNode,
+    device_index: int | None = None,
+    num_streams: int = 1,
+    trt: bool = False,
+    trt_max_workspace_size: int = 1 << 30,
+    trt_cache_path: str = model_dir,
+    model: int = 0,
+    factor_num: int = 2,
+    factor_den: int = 1,
+    fps_num: int | None = None,
+    fps_den: int | None = None,
+    scale: float = 1.0,
+    ensemble: bool = False,
+    sc: bool = True,
+    sc_threshold: float | None = None,
+) -> vs.VideoNode:
+    """The All-In-One GMFSS: Dedicated for Anime Video Frame Interpolation
+
+    :param clip: Clip to process. Only RGBH and RGBS formats are supported.
+        RGBH performs inference in FP16 mode while RGBS performs inference in FP32 mode.
+    :param device_index: Device ordinal of the GPU.
+    :param num_streams: Number of CUDA streams to enqueue the kernels.
+    :param trt: Use TensorRT for high-performance inference.
+    :param trt_max_workspace_size: Maximum workspace size for the TensorRT engine.
+    :param trt_cache_path: Path for the TensorRT engine file. The engine will be cached after it is built for the
+        first time. Note that each engine is created for specific settings such as model path/name,
+        precision, and workspace, and for a specific GPU, so it is not portable.
+    :param model: Model to use.
+        0 = base model
+        1 = union model
+    :param factor_num: Numerator of the factor for the target frame rate.
+        For example, `factor_num=5, factor_den=2` will multiply the frame rate by 2.5.
+    :param factor_den: Denominator of the factor for the target frame rate.
+    :param fps_num: Numerator of the target frame rate.
+        Overrides `factor_num` and `factor_den` if specified.
+    :param fps_den: Denominator of the target frame rate.
+    :param scale: Controls the processing resolution for the optical flow model. Try scale=0.5 for 4K video.
+        Must be 0.25, 0.5, 1.0, 2.0, or 4.0.
+    :param ensemble: Smooth predictions in areas where the estimation is uncertain.
+    :param sc: Avoid interpolating frames over scene changes.
+    :param sc_threshold: Threshold for scene change detection. Must be between 0.0 and 1.0.
+        Leave it as None if the clip already has _SceneChangeNext properly set.
+    """
+    if not isinstance(clip, vs.VideoNode):
+        raise vs.Error("gmfss_fortuna: this is not a clip")
+
+    if clip.format.id not in [vs.RGBH, vs.RGBS]:
+        raise vs.Error("gmfss_fortuna: only RGBH and RGBS formats are supported")
+
+    if clip.num_frames < 2:
+        raise vs.Error("gmfss_fortuna: clip's number of frames must be at least 2")
+
+    if not torch.cuda.is_available():
+        raise vs.Error("gmfss_fortuna: CUDA is not available")
+
+    if num_streams < 1:
+        raise vs.Error("gmfss_fortuna: num_streams must be at least 1")
+
+    if num_streams > vs.core.num_threads:
+        raise vs.Error("gmfss_fortuna: setting num_streams greater than `core.num_threads` is useless")
+
+    if model not in range(2):
+        raise vs.Error("gmfss_fortuna: model must be 0 or 1")
+
+    if factor_num < 1:
+        raise vs.Error("gmfss_fortuna: factor_num must be at least 1")
+
+    if factor_den < 1:
+        raise vs.Error("gmfss_fortuna: factor_den must be at least 1")
+
+    if fps_num is not None and fps_num < 1:
+        raise vs.Error("gmfss_fortuna: fps_num must be at least 1")
+
+    if fps_den is not None and fps_den < 1:
+        raise vs.Error("gmfss_fortuna: fps_den must be at least 1")
+
+    if fps_num is not None and fps_den is not None and clip.fps == 0:
+        raise vs.Error(
+            "gmfss_fortuna: clip does not have a valid frame rate and hence fps_num and fps_den cannot be used"
+        )
+
+    if scale not in [0.25, 0.5, 1.0, 2.0, 4.0]:
+        raise vs.Error("gmfss_fortuna: scale must be 0.25, 0.5, 1.0, 2.0, or 4.0")
+
+    torch.set_float32_matmul_precision("high")
+
+    fp16 = clip.format.bits_per_sample == 16
+    dtype = torch.half if fp16 else torch.float
+
+    device = torch.device("cuda", device_index)
+
+    stream = [torch.cuda.Stream(device=device) for _ in range(num_streams)]
+    stream_lock = [Lock() for _ in range(num_streams)]
+
+    model_type = "base" if model == 0 else "union"
+
+    module = GMFSS(model_dir, model_type, scale, ensemble)
+    module.eval().to(device, memory_format=torch.channels_last)
+    if fp16:
+        module.half()
+
+    tmp = max(64, int(64 / scale))
+    pw = math.ceil(clip.width / tmp) * tmp
+    ph = math.ceil(clip.height / tmp) * tmp
+
+    if trt:
+        device_name = torch.cuda.get_device_name(device)
+        trt_version = tensorrt.__version__
+        dimensions = f"{pw}x{ph}"
+        precision = "fp16" if fp16 else "fp32"
+        trt_engine_path = os.path.join(
+            os.path.realpath(trt_cache_path),
+            (
+                f"gmfss_fortuna-{model_type}"
+                + f"_{device_name}"
+                + f"_trt-{trt_version}"
+                + f"_{dimensions}"
+                + f"_{precision}"
+                + f"_workspace-{trt_max_workspace_size}"
+                + f"_scale-{scale}"
+                + f"_ensemble-{ensemble}"
+                + ".pt"
+            ),
+        )
+
+        if not os.path.isfile(trt_engine_path):
+            lower_setting = LowerSetting(
+                lower_precision=LowerPrecision.FP16 if fp16 else LowerPrecision.FP32,
+                min_acc_module_size=1,
+                leaf_module_list={FeatureTransformer, InstanceNorm2d},
+                max_workspace_size=trt_max_workspace_size,
+                dynamic_batch=False,
+                tactic_sources=1 << int(tensorrt.TacticSource.EDGE_MASK_CONVOLUTIONS)
+                | 1 << int(tensorrt.TacticSource.JIT_CONVOLUTIONS),
+            )
+            lowerer =
Lowerer.create(lower_setting=lower_setting) + module = lowerer( + module, + [ + torch.zeros((1, 3, ph, pw), dtype=dtype, device=device).to(memory_format=torch.channels_last), + torch.zeros((1, 3, ph, pw), dtype=dtype, device=device).to(memory_format=torch.channels_last), + torch.zeros((1,), dtype=dtype, device=device), + ], + ) + torch.save(module, trt_engine_path) + + del module + torch.cuda.empty_cache() + module = [torch.load(trt_engine_path) for _ in range(num_streams)] + + if fps_num is not None and fps_den is not None: + factor = Fraction(fps_num, fps_den) / clip.fps + factor_num, factor_den = factor.as_integer_ratio() + + if sc_threshold is not None: + clip = sc_detect(clip, sc_threshold) + + index = -1 + index_lock = Lock() + + @torch.inference_mode() + def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: + remainder = n * factor_den % factor_num + + if remainder == 0 or (sc and f[0].props.get("_SceneChangeNext")): + return f[0] + + nonlocal index + with index_lock: + index = (index + 1) % num_streams + local_index = index + + with stream_lock[local_index], torch.cuda.stream(stream[local_index]): + img0 = frame_to_tensor(f[0], device) + img1 = frame_to_tensor(f[1], device) + img0 = F.interpolate(img0, (ph, pw), mode="bilinear") + img1 = F.interpolate(img1, (ph, pw), mode="bilinear") + + timestep = torch.tensor([remainder / factor_num], dtype=dtype, device=device) + + if trt: + output = module[local_index](img0, img1, timestep) + else: + output = module(img0, img1, timestep) + + output = F.interpolate(output, (clip.height, clip.width), mode="bilinear") + return tensor_to_frame(output, f[0].copy()) + + clip0 = vs.core.std.Interleave([clip] * factor_num) + if factor_den > 1: + clip0 = clip0.std.SelectEvery(cycle=factor_den, offsets=0) + + clip1 = clip.std.DuplicateFrames(frames=clip.num_frames - 1).std.Trim(first=1) + clip1 = vs.core.std.Interleave([clip1] * factor_num) + if factor_den > 1: + clip1 = clip1.std.SelectEvery(cycle=factor_den, offsets=0) + + return clip0.std.FrameEval(lambda n: clip0.std.ModifyFrame([clip0, clip1], inference), clip_src=[clip0, clip1]) + + +def sc_detect(clip: vs.VideoNode, threshold: float) -> vs.VideoNode: + def copy_property(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: + fout = f[0].copy() + fout.props["_SceneChangePrev"] = f[1].props["_SceneChangePrev"] + fout.props["_SceneChangeNext"] = f[1].props["_SceneChangeNext"] + return fout + + sc_clip = clip.resize.Bicubic(format=vs.GRAY8, matrix_s="709").misc.SCDetect(threshold) + return clip.std.FrameEval(lambda n: clip.std.ModifyFrame([clip, sc_clip], copy_property), clip_src=[clip, sc_clip]) + + +def frame_to_tensor(frame: vs.VideoFrame, device: torch.device) -> torch.Tensor: + array = np.stack([np.asarray(frame[plane]) for plane in range(frame.format.num_planes)]) + return torch.from_numpy(array).unsqueeze(0).to(device, memory_format=torch.channels_last).clamp(0.0, 1.0) + + +def tensor_to_frame(tensor: torch.Tensor, frame: vs.VideoFrame) -> vs.VideoFrame: + array = tensor.squeeze(0).detach().cpu().numpy() + for plane in range(frame.format.num_planes): + np.copyto(np.asarray(frame[plane]), array[plane, :, :]) + return frame diff --git a/vsgmfss_fortuna/gmflow/__init__.py b/vsgmfss_fortuna/gmflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vsgmfss_fortuna/gmflow/backbone.py b/vsgmfss_fortuna/gmflow/backbone.py new file mode 100644 index 0000000..a30942e --- /dev/null +++ b/vsgmfss_fortuna/gmflow/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from 
.trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/vsgmfss_fortuna/gmflow/geometry.py b/vsgmfss_fortuna/gmflow/geometry.py new file mode 100644 index 0000000..f76c137 --- /dev/null +++ b/vsgmfss_fortuna/gmflow/geometry.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F + + 
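+# coords_grid builds per-pixel coordinates, flow_warp backward-warps a tensor
+# by a flow field, and forward_backward_consistency_check (UnFlow-style, with
+# alpha=0.01 and beta=0.5) marks a pixel occluded where
+# |F_fwd + warp(F_bwd)| > alpha * (|F_fwd| + |F_bwd|) + beta.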
+def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + indexing='ij') + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.tensor([(w - 1) / 2., (h - 1) / 2.], dtype=coords.dtype, device=coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).to(fwd_flow.dtype) # [B, H, W] + bwd_occ = (diff_bwd > threshold).to(bwd_flow.dtype) + + return fwd_occ, bwd_occ diff --git a/vsgmfss_fortuna/gmflow/gmflow.py b/vsgmfss_fortuna/gmflow/gmflow.py new file mode 100644 index 0000000..89d33ce --- /dev/null +++ b/vsgmfss_fortuna/gmflow/gmflow.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .geometry import flow_warp +from .matching import global_correlation_softmax, local_correlation_softmax +from .transformer import FeatureFlowAttention, FeatureTransformer +from .utils import feature_add_position, normalize_img + +torch.fx.wrap('feature_add_position') +torch.fx.wrap('flow_warp') +torch.fx.wrap('global_correlation_softmax') 
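+# torch.fx.wrap registers these helpers as leaf functions for FX symbolic
+# tracing, so the torch_tensorrt.fx lowering in __init__.py records them as
+# opaque call nodes instead of tracing through their data-dependent logic.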
+torch.fx.wrap('local_correlation_softmax') +torch.fx.wrap('normalize_img') + + +class GMFlow(nn.Module): + def __init__(self, + num_scales=2, + upsample_factor=4, + feature_channels=128, + attention_type='swin', + num_transformer_layers=6, + ffn_dim_expansion=4, + num_head=1, + **kwargs, + ): + super(GMFlow, self).__init__() + + self.num_scales = num_scales + self.feature_channels = feature_channels + self.upsample_factor = upsample_factor + self.attention_type = attention_type + self.num_transformer_layers = num_transformer_layers + + # CNN backbone + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + ): + if bilinear: + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * upsample_factor + + else: + # convex upsampling + concat = torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h, + self.upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + def forward(self, img0, img1, + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + pred_bidir_flow=False, + **kwargs, + ): + + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if scale_idx > 0: + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', 
align_corners=True) * 2 + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = attn_splits_list[scale_idx] + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if self.training: # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor) + + # flow propagation with self-attn + if pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor) + + if scale_idx == self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + + return flow_up diff --git a/vsgmfss_fortuna/gmflow/matching.py b/vsgmfss_fortuna/gmflow/matching.py new file mode 100644 index 0000000..a69fc5a --- /dev/null +++ b/vsgmfss_fortuna/gmflow/matching.py @@ -0,0 +1,83 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, 
local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device).to(feature0.dtype) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 if feature0.dtype == torch.float else -1e4 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob diff --git a/vsgmfss_fortuna/gmflow/position.py b/vsgmfss_fortuna/gmflow/position.py new file mode 100644 index 0000000..a84b8b8 --- /dev/null +++ b/vsgmfss_fortuna/gmflow/position.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import math + +import torch +import torch.nn as nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
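+    Each (y, x) position (a cumulative sum over an all-ones mask, rescaled to
+    [0, 2*pi] when normalize=True) is mapped to interleaved sine/cosine
+    channels at temperature-spaced frequencies, num_pos_feats per axis.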
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = x.new_ones((b, h, w), dtype=torch.float32) # [B, H, W] + y_embed = mask.cumsum(1) + x_embed = mask.cumsum(2) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2).to(x.dtype) + return pos diff --git a/vsgmfss_fortuna/gmflow/transformer.py b/vsgmfss_fortuna/gmflow/transformer.py new file mode 100644 index 0000000..bd11074 --- /dev/null +++ b/vsgmfss_fortuna/gmflow/transformer.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import merge_splits, split_feature + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1), device=device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = 
q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=self.with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=256, + 
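+                 # note: only single-head attention is actually implemented;
+                 # nhead > 1 is rejected in TransformerLayer on the swin path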
nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + **kwargs, + ): + super(FeatureTransformer, self).__init__() + + self.attention_type = attention_type + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer(concat0, concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(FeatureFlowAttention, 
self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + # assert flow.size(1) == 2 + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/vsgmfss_fortuna/gmflow/trident_conv.py b/vsgmfss_fortuna/gmflow/trident_conv.py new file mode 100644 index 0000000..166b1b9 --- /dev/null +++ b/vsgmfss_fortuna/gmflow/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
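+# MultiScaleTridentConv below is a single weight-shared convolution evaluated
+# over several branches with per-branch stride/padding (presumably used here
+# for multi-scale feature extraction); adapted from: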
+# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = _pair(dilation) + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/vsgmfss_fortuna/gmflow/utils.py b/vsgmfss_fortuna/gmflow/utils.py new file mode 100644 index 0000000..4dd861e --- /dev/null +++ b/vsgmfss_fortuna/gmflow/utils.py @@ -0,0 +1,87 @@ +import torch + +from .position import PositionEmbeddingSine + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + 
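+                               # -> [B, C, K, H/K, K, W/K]; the permute below moves the two
+                               # K axes ahead of C so they fold into the batch dimension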
).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406], dtype=img1.dtype, device=img1.device).view(1, 3, 1, 1) + std = torch.tensor([0.229, 0.224, 0.225], dtype=img1.dtype, device=img1.device).view(1, 3, 1, 1) + img0 = (img0 - mean) / std + img1 = (img1 - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 diff --git a/vsgmfss_fortuna/models/feat_base.pkl b/vsgmfss_fortuna/models/feat_base.pkl new file mode 100644 index 0000000..3808555 Binary files /dev/null and b/vsgmfss_fortuna/models/feat_base.pkl differ diff --git a/vsgmfss_fortuna/models/feat_union.pkl b/vsgmfss_fortuna/models/feat_union.pkl new file mode 100644 index 0000000..5d7b84a Binary files /dev/null and b/vsgmfss_fortuna/models/feat_union.pkl differ diff --git a/vsgmfss_fortuna/models/flownet.pkl b/vsgmfss_fortuna/models/flownet.pkl new file mode 100644 index 0000000..c25668b Binary files /dev/null and b/vsgmfss_fortuna/models/flownet.pkl differ diff --git a/vsgmfss_fortuna/models/fusionnet_base.pkl b/vsgmfss_fortuna/models/fusionnet_base.pkl new file mode 100644 index 0000000..fc424ba Binary files /dev/null and b/vsgmfss_fortuna/models/fusionnet_base.pkl differ diff --git a/vsgmfss_fortuna/models/fusionnet_union.pkl b/vsgmfss_fortuna/models/fusionnet_union.pkl new file mode 100644 index 0000000..6e9d468 Binary files /dev/null and b/vsgmfss_fortuna/models/fusionnet_union.pkl differ diff --git a/vsgmfss_fortuna/models/metric_base.pkl b/vsgmfss_fortuna/models/metric_base.pkl new file mode 100644 index 0000000..37e2089 Binary files /dev/null and b/vsgmfss_fortuna/models/metric_base.pkl differ diff --git a/vsgmfss_fortuna/models/metric_union.pkl b/vsgmfss_fortuna/models/metric_union.pkl new file mode 100644 index 0000000..8de672f Binary files /dev/null and b/vsgmfss_fortuna/models/metric_union.pkl differ diff --git a/vsgmfss_fortuna/models/rife.pkl b/vsgmfss_fortuna/models/rife.pkl new file mode 100644 index 0000000..0da5211 Binary files /dev/null and 
b/vsgmfss_fortuna/models/rife.pkl differ diff --git a/vsgmfss_fortuna/softsplat.py b/vsgmfss_fortuna/softsplat.py new file mode 100644 index 0000000..fbaf58e --- /dev/null +++ b/vsgmfss_fortuna/softsplat.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python + +import collections +import os +import re +import typing + +import cupy +import torch + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 
1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path() + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + orig_dtype = tenIn.dtype + tenIn = tenIn.float() + tenFlow = tenFlow.float() + tenMetric = tenMetric.float() + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 
0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut.to(orig_dtype) +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + 
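+                # CuPy only reads the .ptr attribute of the object passed as
+                # stream, so a lightweight namedtuple is enough to launch the
+                # kernel on PyTorch's current CUDA stream: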
stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 
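+                # these tensors are not extra kernel arguments; cuda_kernel() reads
+                # their dtype/shape/stride to substitute {{type}} and expand the
+                # SIZE_/OFFSET_/VALUE_ macros in the kernel source above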
'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, 
intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/vsgmfss_fortuna/util.py b/vsgmfss_fortuna/util.py new file mode 100644 index 0000000..f43da96 --- /dev/null +++ b/vsgmfss_fortuna/util.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + +class MyPixelShuffle(nn.Module): + def __init__(self, upscale_factor): + super(MyPixelShuffle, self).__init__() + self.upscale_factor = upscale_factor + + def forward(self, input): + b, c, hh, hw = input.size() + out_channel = c // (self.upscale_factor**2) + h = hh * self.upscale_factor + w = hw * self.upscale_factor + x_view = input.view(b, out_channel, self.upscale_factor, self.upscale_factor, hh, hw) + return x_view.permute(0, 1, 4, 2, 5, 3).reshape(b, out_channel, h, w) + + +class MyPReLU(nn.Module): + def __init__(self, num_parameters=1, init=0.25): + super(MyPReLU, self).__init__() + self.weight = Parameter(torch.empty(num_parameters).fill_(init)) + + def forward(self, input): + return F.relu(input) - self.weight.reshape(1, -1, 1, 1) * F.relu(-input) diff --git a/vsgmfss_fortuna/warplayer.py b/vsgmfss_fortuna/warplayer.py new file mode 100644 index 0000000..3566f92 --- /dev/null +++ b/vsgmfss_fortuna/warplayer.py @@ -0,0 +1,24 @@ +import torch + +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + orig_dtype = tenInput.dtype + tenInput = tenInput.float() + tenFlow = tenFlow.float() + + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], dtype=torch.float, device=tenFlow.device).view( + 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], dtype=torch.float, device=tenFlow.device).view( + 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + backwarp_tenGrid[k] = torch.cat( + [tenHorizontal, tenVertical], 1) + + tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True).to(orig_dtype)
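As a quick sanity check for the backward-warping helper above, here is a minimal usage sketch (illustrative only; it assumes the package is importable as `vsgmfss_fortuna` per the paths in this patch, and relies on the fact that a zero flow yields an identity sampling grid under `align_corners=True`):

```python
import torch

from vsgmfss_fortuna.warplayer import warp

# Minimal sketch: backward-warp a frame by a per-pixel flow field (in pixels).
# With zero flow the sampling grid is the identity, so the output should match
# the input up to interpolation error.
frame = torch.rand(1, 3, 64, 64)
flow = torch.zeros(1, 2, 64, 64)

warped = warp(frame, flow)
assert warped.shape == frame.shape
assert torch.allclose(warped, frame, atol=1e-5)
```

`softsplat`, by contrast, requires a CUDA device and CuPy at runtime, since its forward-warping kernels are compiled and launched through `cupy` and the autograd function asserts that its inputs are CUDA tensors.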