Skip to content

Commit

Permalink
scripts/vsmlrt.py: add supprt for cugan model
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Jan 27, 2022
1 parent 390d0e5 commit 4c3693d
Showing 1 changed file with 88 additions and 2 deletions.
90 changes: 88 additions & 2 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
__version__ = "3.4.2"
__version__ = "3.5.0"

__all__ = [
"Backend",
"Waifu2x", "Waifu2xModel",
"DPIR", "DPIRModel",
"RealESRGANv2", "RealESRGANv2Model"
"RealESRGANv2", "RealESRGANv2Model",
"CUGAN"
]

import copy
Expand Down Expand Up @@ -397,6 +398,91 @@ def RealESRGANv2(
return clip


def CUGAN(
clip: vs.VideoNode,
noise: typing.Literal[-1, 0, 1, 2, 3] = -1,
scale: typing.Literal[2, 3, 4] = 2,
tiles: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
backend: backendT = Backend.OV_CPU(),
preprocess: bool = True
) -> vs.VideoNode:

func_name = "vsmlrt.CUGAN"

if not isinstance(clip, vs.VideoNode):
raise TypeError(f'{func_name}: "clip" must be a clip!')

if clip.format.sample_type != vs.FLOAT or clip.format.bits_per_sample != 32:
raise ValueError(f"{func_name}: only constant format 32 bit float input supported")

if not isinstance(noise, int) or noise not in range(-1, 4):
raise ValueError(f'{func_name}: "noise" must be -1, 0, 1, 2, or 3')

if not isinstance(scale, int) or scale not in (2, 3, 4):
raise ValueError(f'{func_name}: "scale" must be 2, 3 or 4')

if scale != 2 and noise in [1, 2]:
raise ValueError(
f'{func_name}: "scale={scale}" model'
f' does not support noise reduction level {noise}'
)

if clip.format.id != vs.RGBS:
raise ValueError(f'{func_name}: "clip" must be of RGBS format')

if overlap is None:
overlap_w = overlap_h = 4
elif isinstance(overlap, int):
overlap_w = overlap_h = overlap
else:
overlap_w, overlap_h = overlap

multiple = 2

width, height = clip.width, clip.height

(tile_w, tile_h), (overlap_w, overlap_h) = calc_tilesize(
tiles=tiles, tilesize=tilesize,
width=clip.width, height=clip.height,
multiple=multiple,
overlap_w=overlap_w, overlap_h=overlap_h
)

if tile_w % multiple != 0 or tile_h % multiple != 0:
raise ValueError(
f'{func_name}: tile size must be divisible by {multiple} ({tile_w}, {tile_h})'
)

channels = 3

backend = init_backend(
backend=backend,
channels=channels,
trt_max_shapes=(tile_w, tile_h)
)

folder_path = os.path.join(models_path, "cugan")

if noise == -1:
model_name = f"up{scale}x-latest-no-denoise.onnx"
elif noise == 0:
model_name = f"up{scale}x-latest-conservative.onnx"
else:
model_name = f"up{scale}x-latest-denoise{noise}x.onnx"

network_path = os.path.join(folder_path, model_name)

clip = inference(
clips=[clip], network_path=network_path,
overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
backend=backend
)

return clip


def get_engine_path(
network_path: str,
opt_shapes: typing.Tuple[int, int],
Expand Down

0 comments on commit 4c3693d

Please sign in to comment.