Skip to content

Commit

Permalink
scripts/vsmlrt.py: improve backend specification
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Dec 16, 2021
1 parent a661365 commit be75f95
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
__version__ = "3.0.0"

from dataclasses import dataclass
import enum
import math
import os.path
Expand All @@ -7,8 +10,22 @@
from vapoursynth import core


def Version() -> str:
return "3.0.0"
class Backend:
@dataclass(frozen=True)
class ORT_CPU:
num_streams: int = 1
verbosity: int = 2

@dataclass(frozen=True)
class ORT_CUDA:
device_id: int = 0
cudnn_benchmark: bool = True
num_streams: int = 1
verbosity: int = 2

@dataclass(frozen=True)
class OV_CPU:
pass


def calcSize(width: int, tiles: int, overlap: int) -> int:
Expand All @@ -34,10 +51,7 @@ def Waifu2x(
tilesize: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
overlap: typing.Optional[typing.Union[int, typing.Tuple[int, int]]] = None,
model: typing.Literal[0, 1, 2, 3, 4, 5, 6] = 6,
backend: typing.Literal["ort-cpu", "ort-cuda", "ov-cpu"] = "ort-cpu",
# parameters for "ort-cuda"
device_id: int = 0,
cudnn_benchmark: bool = True
backend: typing.Union[Backend.OV_CPU, Backend.ORT_CPU, Backend.ORT_CUDA] = Backend.OV_CPU()
) -> vs.VideoNode:

funcName = "vsmlrt.Waifu2x"
Expand Down Expand Up @@ -96,6 +110,13 @@ def Waifu2x(
elif clip.format.id != vs.RGBS:
raise ValueError(f'{funcName}: input should be of RGBS format')

if backend is Backend.ORT_CPU:
backend = Backend.ORT_CPU()
elif backend is Backend.ORT_CUDA:
backend = Backend.ORT_CUDA()
elif backend is Backend.OV_CPU:
backend = Backend.OV_CPU()

folder_path = os.path.join("waifu2x", tuple(Waifu2xModel.__members__)[model])

if model in (0, 1, 2):
Expand Down Expand Up @@ -130,20 +151,25 @@ def Waifu2x(
filter_param_a=0, filter_param_b=0.75
)

if backend == "ort-cpu":
if isinstance(backend, Backend.ORT_CPU):
clip = core.ort.Model(
clip, network_path,
overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
provider="CPU", builtin=1
provider="CPU", builtin=1,
num_streams=backend.num_streams,
verbosity=backend.verbosity
)
elif backend == "ort-cuda":
elif isinstance(backend, Backend.ORT_CUDA):
clip = core.ort.Model(
clip, network_path,
overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
provider="CUDA", device_id=device_id, cudnn_benchmark=cudnn_benchmark,
builtin=1
provider="CUDA", builtin=1,
device_id=backend.device_id,
num_streams=backend.num_streams,
verbosity=backend.verbosity,
cudnn_benchmark=backend.cudnn_benchmark
)
elif backend == "ov-cpu":
elif isinstance(backend, Backend.OV_CPU):
clip = core.ov.Model(
clip, network_path,
overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
Expand Down

0 comments on commit be75f95

Please sign in to comment.