Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Jul 25, 2021
1 parent 8ae6737 commit 9992f9b
Show file tree
Hide file tree
Showing 12 changed files with 1,392 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
13 changes: 11 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
Expand Down Expand Up @@ -50,6 +49,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
Expand All @@ -72,6 +72,7 @@ instance/
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
Expand All @@ -82,7 +83,9 @@ profile_default/
ipython_config.py

# pyenv
.python-version
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
Expand Down Expand Up @@ -127,3 +130,9 @@ dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# DPIR
DPIR: Deep Plug-and-Play Image Restoration

Ported from https://github.com/cszn/DPIR


## Dependencies
- [NumPy](https://numpy.org/install)
- [PyTorch](https://pytorch.org/get-started), preferably with CUDA. Note that `torchvision` and `torchaudio` are not required and hence can be omitted from the command.
- [VapourSynth](http://www.vapoursynth.com/)


## Installation
```
pip install --upgrade vsdpir
python -m vsdpir
```


## Usage
```python
from vsdpir import DPIR

ret = DPIR(clip)
```

See `__init__.py` for the description of the parameters.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
27 changes: 27 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[metadata]
name = vsdpir
version = 1.0.0
author = HolyWu
description = DPIR function for VapourSynth
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/HolyWu/vs-dpir
classifiers =
License :: OSI Approved :: MIT License
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3 :: Only
Topic :: Multimedia :: Video

[options]
zip_safe = False
packages = vsdpir
python_requires = >=3.6
install_requires =
numpy
requests
torch
tqdm

[options.package_data]
vsdpir = *.pth
105 changes: 105 additions & 0 deletions vsdpir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np
import os.path
import torch
import vapoursynth as vs
from . import utils_model
from .network_unet import UNetRes as net


def DPIR(clip: vs.VideoNode, strength: float=None, task: str='denoise', device_type: str='cuda', device_index: int=0) -> vs.VideoNode:
'''
DPIR: Deep Plug-and-Play Image Restoration
Parameters:
clip: Clip to process. Only planar format with float sample type of 32 bit depth is supported.
strength: Strength for deblocking or denoising. Must be greater than 0. Defaults to 50.0 for 'deblock' task, 5.0 for 'denoise' task.
task: Task to perform. Must be 'deblock' or 'denoise'.
device_type: Device type on which the tensor is allocated. Must be 'cuda' or 'cpu'.
device_index: Device ordinal for the device type.
'''
if not isinstance(clip, vs.VideoNode):
raise vs.Error('DPIR: This is not a clip')

if clip.format.id != vs.RGBS:
raise vs.Error('DPIR: Only RGBS format is supported')

if strength is not None and strength <= 0:
raise vs.Error('DPIR: strength must be greater than 0')

task = task.lower()
device_type = device_type.lower()

if task not in ['deblock', 'denoise']:
raise vs.Error("DPIR: task must be 'deblock' or 'denoise'")

if device_type not in ['cuda', 'cpu']:
raise vs.Error("DPIR: device_type must be 'cuda' or 'cpu'")

if device_type == 'cuda' and not torch.cuda.is_available():
raise vs.Error('DPIR: CUDA is not available')

if task == 'deblock':
if strength is None:
strength = 50.0
strength /= 100
model_name = 'drunet_deblocking_color.pth'
else:
if strength is None:
strength = 5.0
strength /= 255
model_name = 'drunet_color.pth'

model_path = os.path.join(os.path.dirname(__file__), model_name)

device = torch.device(device_type, device_index)

model = net(in_nc=4, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)

torch.cuda.empty_cache()

def deblock(n: int, f: vs.VideoFrame) -> vs.VideoFrame:
img_L = frame_to_tensor(f)
noise_level = torch.FloatTensor([strength])
noise_level_map = torch.ones((1, 1, img_L.shape[2], img_L.shape[3])).mul_(noise_level).float()
img_L = torch.cat((img_L, noise_level_map), dim=1)
img_L = img_L.to(device)

img_E = model(img_L)

return tensor_to_frame(img_E, f)

def denoise(n: int, f: vs.VideoFrame) -> vs.VideoFrame:
img_L = frame_to_tensor(f)
img_L = torch.cat((img_L, torch.FloatTensor([strength]).repeat(1, 1, img_L.shape[2], img_L.shape[3])), dim=1)
img_L = img_L.to(device)

if img_L.size(2) // 8 == 0 and img_L.size(3) // 8 == 0:
img_E = model(img_L)
else:
img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)

return tensor_to_frame(img_E, f)

return clip.std.ModifyFrame(clips=clip, selector=eval(task))


def frame_to_tensor(f: vs.VideoFrame) -> torch.Tensor:
arr = np.stack([np.asarray(f.get_read_array(plane)) for plane in range(f.format.num_planes)])
return torch.from_numpy(arr).float().unsqueeze(0)


def tensor_to_frame(t: torch.Tensor, f: vs.VideoFrame) -> vs.VideoFrame:
arr = t.data.squeeze().float().cpu().numpy()
fout = f.copy()
for plane in range(fout.format.num_planes):
np.copyto(np.asarray(fout.get_write_array(plane)), arr[plane, ...])
return fout
16 changes: 16 additions & 0 deletions vsdpir/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os.path
import requests
from tqdm import tqdm

def download_model(url) -> None:
filename = url.split('/')[-1]
r = requests.get(url, stream=True)
with open(os.path.join(os.path.dirname(__file__), filename), 'wb') as f:
with tqdm(unit='B', unit_scale=True, unit_divisor=1024, miniters=1, desc=filename, total=int(r.headers.get('content-length', 0))) as pbar:
for chunk in r.iter_content(chunk_size=4096):
f.write(chunk)
pbar.update(len(chunk))

if __name__ == '__main__':
download_model('https://github.com/HolyWu/vs-dpir/releases/download/model/drunet_color.pth')
download_model('https://github.com/HolyWu/vs-dpir/releases/download/model/drunet_deblocking_color.pth')
Loading

0 comments on commit 9992f9b

Please sign in to comment.