scripts/vsmlrt.py: add flexible_inference() for arbitrary number of output planes; fix typo
WolframRhodium committed May 13, 2024
1 parent 6cfc179 commit 0581970
Showing 1 changed file with 128 additions and 30 deletions.
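In short: _inference() and every backend plugin call gain a flexible_output_prop argument, and two new helpers, flexible_inference_with_fallback() and flexible_inference(), unpack a model with an arbitrary number of output planes into a list of clips, one per plane. A minimal usage sketch (the model file, its plane count, and the backend choice are assumptions for illustration, not part of the commit):

import vapoursynth as vs
from vsmlrt import Backend, flexible_inference

core = vs.core

# Hypothetical model emitting two output planes; any plane count works.
src = core.std.BlankClip(format=vs.GRAYS, width=1920, height=1080)
planes = flexible_inference(
    src, "two_plane_model.onnx",  # assumed model path
    backend=Backend.ORT_CPU(),
)
plane0, plane1 = planes  # one GRAY clip per model output plane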
scripts/vsmlrt.py
@@ -1,4 +1,4 @@
__version__ = "3.20.13"
__version__ = "3.21.0"

__all__ = [
"Backend", "BackendV2",
@@ -1705,19 +1705,24 @@ def ArtCNN(
         else:
             clip = core.std.Expr(clip, ["", "x 0.5 +"])
 
-    clip = inference_with_fallback(
-        clips=[clip], network_path=network_path,
-        overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
-        backend=backend
-    )
+        clip_u, clip_v = flexible_inference_with_fallback(
+            clips=[clip], network_path=network_path,
+            overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
+            backend=backend
+        )
 
-    if model in range(4, 6):
-        clip = core.std.ShufflePlanes(clip, [0, 1, 2], vs.YUV)
+        clip = core.std.ShufflePlanes([clip, clip_u, clip_v], [0, 0, 0], vs.YUV)
 
         if clip.format.bits_per_sample == 16:
             clip = core.akarin.Expr(clip, ["", "x 0.5 -"])
         else:
             clip = core.std.Expr(clip, ["", "x 0.5 -"])
+    else:
+        clip = inference_with_fallback(
+            clips=[clip], network_path=network_path,
+            overlap=(overlap_w, overlap_h), tilesize=(tile_w, tile_h),
+            backend=backend
+        )
 
     return clip

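Why ArtCNN changed: the chroma models (model 4 and 5) emit two output planes, U and V, which a single returned clip cannot carry; flexible inference hands each plane back as its own GRAY clip, and ShufflePlanes merges them with the source luma. A self-contained sketch of that merge, with BlankClips standing in for the diff's clip/clip_u/clip_v (illustration only):

import vapoursynth as vs

core = vs.core

# Stand-ins for the luma source and the two chroma planes returned by
# flexible_inference_with_fallback() in the hunk above.
clip = core.std.BlankClip(format=vs.GRAYS, width=1920, height=1080)
clip_u = core.std.BlankClip(format=vs.GRAYS, width=1920, height=1080)
clip_v = core.std.BlankClip(format=vs.GRAYS, width=1920, height=1080)

# Plane 0 of each input becomes Y, U and V of a YUV444 output.
merged = core.std.ShufflePlanes([clip, clip_u, clip_v], [0, 0, 0], vs.YUV)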
@@ -2259,8 +2264,9 @@ def _inference(
     tilesize: typing.Tuple[int, int],
     backend: backendT,
     path_is_serialization: bool = False,
-    input_name: str = "input"
-) -> vs.VideoNode:
+    input_name: str = "input",
+    flexible_output_prop: typing.Optional[str] = None
+) -> typing.Union[vs.VideoNode, typing.Dict[str, typing.Any]]:
 
     if not path_is_serialization:
         network_path = typing.cast(str, network_path)
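Note the default: with flexible_output_prop=None, _inference() keeps its old contract and returns a plain vs.VideoNode; passing a property name switches every backend call below into flexible-output mode, whose dict result is unpacked by flexible_inference_with_fallback() later in this diff. Hence the widened typing.Union return annotation.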
@@ -2274,18 +2280,19 @@
     )
 
     if isinstance(backend, Backend.ORT_CPU):
-        clip = core.ort.Model(
+        ret = core.ort.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             provider="CPU", builtin=False,
             num_streams=backend.num_streams,
             verbosity=backend.verbosity,
             fp16=backend.fp16,
             path_is_serialization=path_is_serialization,
-            fp16_blacklist_ops=backend.fp16_blacklist_ops
+            fp16_blacklist_ops=backend.fp16_blacklist_ops,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.ORT_DML):
-        clip = core.ort.Model(
+        ret = core.ort.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             provider="DML", builtin=False,
@@ -2294,7 +2301,8 @@
             verbosity=backend.verbosity,
             fp16=backend.fp16,
             path_is_serialization=path_is_serialization,
-            fp16_blacklist_ops=backend.fp16_blacklist_ops
+            fp16_blacklist_ops=backend.fp16_blacklist_ops,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.ORT_CUDA):
         kwargs = dict()
@@ -2310,7 +2318,7 @@
         kwargs["output_format"] = backend.output_format
         kwargs["tf32"] = backend.tf32
 
-        clip = core.ort.Model(
+        ret = core.ort.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             provider="CUDA", builtin=False,
@@ -2322,6 +2330,7 @@
             path_is_serialization=path_is_serialization,
             use_cuda_graph=backend.use_cuda_graph,
             fp16_blacklist_ops=backend.fp16_blacklist_ops,
+            flexible_output_prop=flexible_output_prop,
             **kwargs
         )
     elif isinstance(backend, Backend.OV_CPU):
@@ -2349,14 +2358,15 @@
             ENFORCE_BF16="YES" if backend.bf16 else "NO"
         )
 
-        clip = core.ov.Model(
+        ret = core.ov.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             device="CPU", builtin=False,
             fp16=False, # use ov's internal quantization
             config=config,
             path_is_serialization=path_is_serialization,
-            fp16_blacklist_ops=backend.fp16_blacklist_ops # disabled since fp16 = False
+            fp16_blacklist_ops=backend.fp16_blacklist_ops, # disabled since fp16 = False
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.OV_GPU):
         version = tuple(map(int, core.ov.Version().get("openvino_version", b"0.0.0").split(b'-')[0].split(b'.')))
@@ -2376,14 +2386,15 @@
             GPU_THROUGHPUT_STREAMS=backend.num_streams
         )
 
-        clip = core.ov.Model(
+        ret = core.ov.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             device=f"GPU.{backend.device_id}", builtin=False,
             fp16=False, # use ov's internal quantization
             config=config,
             path_is_serialization=path_is_serialization,
-            fp16_blacklist_ops=backend.fp16_blacklist_ops
+            fp16_blacklist_ops=backend.fp16_blacklist_ops,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.TRT):
         if path_is_serialization:
@@ -2428,24 +2439,26 @@
             custom_args=backend.custom_args,
             engine_folder=backend.engine_folder,
         )
-        clip = core.trt.Model(
+        ret = core.trt.Model(
             clips, engine_path,
             overlap=overlap,
             tilesize=tilesize,
             device_id=backend.device_id,
             use_cuda_graph=backend.use_cuda_graph,
             num_streams=backend.num_streams,
-            verbosity=4 if backend.verbose else 2
+            verbosity=4 if backend.verbose else 2,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.NCNN_VK):
-        clip = core.ncnn.Model(
+        ret = core.ncnn.Model(
             clips, network_path,
             overlap=overlap, tilesize=tilesize,
             device_id=backend.device_id,
             num_streams=backend.num_streams,
             builtin=False,
             fp16=backend.fp16,
             path_is_serialization=path_is_serialization,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.MIGX):
         if path_is_serialization:
@@ -2470,24 +2483,26 @@
             custom_env=backend.custom_env,
             custom_args=backend.custom_args
         )
-        clip = core.migx.Model(
+        ret = core.migx.Model(
             clips, mxr_path,
             overlap=overlap,
             tilesize=tilesize,
-            device_id=backend.device_id
+            device_id=backend.device_id,
+            flexible_output_prop=flexible_output_prop,
         )
     elif isinstance(backend, Backend.OV_NPU):
-        clip = core.ov.Model(
+        ret = core.ov.Model(
            clips, network_path,
            overlap=overlap, tilesize=tilesize,
            device="NPU", builtin=False,
            fp16=False, # use ov's internal quantization
            path_is_serialization=path_is_serialization,
+           flexible_output_prop=flexible_output_prop,
         )
     else:
         raise TypeError(f'unknown backend {backend}')
 
-    return clip
+    return ret
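For reference, a sketch of what the flexible-output form looks like to a caller. The dict keys "clip" and "num_planes" and the per-plane frame props f"{prop}{i}" are taken from flexible_inference_with_fallback() below; the helper itself is illustrative, not part of the commit:

import typing
import vapoursynth as vs

def unpack_flexible_output(
    ret: typing.Dict[str, typing.Any],
    prop: str = "vsmlrt_flexible"
) -> typing.List[vs.VideoNode]:
    # Plane i of the model output rides on the frame prop f"{prop}{i}"
    # of ret["clip"]; std.PropToClip materialises it as a separate clip.
    clip = ret["clip"]
    return [
        clip.std.PropToClip(prop=f"{prop}{i}")
        for i in range(ret["num_planes"])
    ]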


 def inference_with_fallback(
@@ -2501,7 +2516,7 @@ def inference_with_fallback(
 ) -> vs.VideoNode:
 
     try:
-        return _inference(
+        ret = _inference(
             clips=clips, network_path=network_path,
             overlap=overlap, tilesize=tilesize,
             backend=backend,
@@ -2514,7 +2529,7 @@
             logger = logging.getLogger("vsmlrt")
             logger.warning(f'"{backend}" fails, trying fallback backend "{fallback_backend}"')
 
-            return _inference(
+            ret = _inference(
                 clips=clips, network_path=network_path,
                 overlap=overlap, tilesize=tilesize,
                 backend=fallback_backend,
@@ -2524,6 +2539,8 @@
         else:
             raise e
 
+    return typing.cast(vs.VideoNode, ret)
+
 
 def inference(
     clips: typing.Union[vs.VideoNode, typing.List[vs.VideoNode]],
@@ -2535,7 +2552,6 @@ def inference(
 ) -> vs.VideoNode:
 
     if isinstance(clips, vs.VideoNode):
-        clips = typing.cast(vs.VideoNode, clips)
         clips = [clips]
 
     if tilesize is None:
@@ -2557,6 +2573,88 @@
     )
 
 
+def flexible_inference_with_fallback(
+    clips: typing.List[vs.VideoNode],
+    network_path: typing.Union[bytes, str],
+    overlap: typing.Tuple[int, int],
+    tilesize: typing.Tuple[int, int],
+    backend: backendT,
+    path_is_serialization: bool = False,
+    input_name: str = "input",
+    flexible_output_prop: str = "vsmlrt_flexible"
+) -> typing.List[vs.VideoNode]:
+
+    try:
+        ret = _inference(
+            clips=clips, network_path=network_path,
+            overlap=overlap, tilesize=tilesize,
+            backend=backend,
+            path_is_serialization=path_is_serialization,
+            input_name=input_name,
+            flexible_output_prop=flexible_output_prop
+        )
+    except Exception as e:
+        if fallback_backend is not None:
+            import logging
+            logger = logging.getLogger("vsmlrt")
+            logger.warning(f'"{backend}" fails, trying fallback backend "{fallback_backend}"')
+
+            ret = _inference(
+                clips=clips, network_path=network_path,
+                overlap=overlap, tilesize=tilesize,
+                backend=fallback_backend,
+                path_is_serialization=path_is_serialization,
+                input_name=input_name,
+                flexible_output_prop=flexible_output_prop
+            )
+        else:
+            raise e
+
+    ret = typing.cast(typing.Dict[str, typing.Any], ret)
+    clip = ret["clip"]
+    num_planes = ret["num_planes"]
+
+    planes = [
+        clip.std.PropToClip(prop=f"{flexible_output_prop}{i}")
+        for i in range(num_planes)
+    ]
+
+    return planes
+
+
+def flexible_inference(
+    clips: typing.Union[vs.VideoNode, typing.List[vs.VideoNode]],
+    network_path: str,
+    overlap: typing.Tuple[int, int] = (0, 0),
+    tilesize: typing.Optional[typing.Tuple[int, int]] = None,
+    backend: backendT = Backend.OV_CPU(),
+    input_name: typing.Optional[str] = "input",
+    flexible_output_prop: str = "vsmlrt_flexible"
+) -> typing.List[vs.VideoNode]:
+
+    if isinstance(clips, vs.VideoNode):
+        clips = [clips]
+
+    if tilesize is None:
+        tilesize = (clips[0].width, clips[0].height)
+
+    backend = init_backend(backend=backend, trt_opt_shapes=tilesize)
+
+    if input_name is None:
+        input_name = get_input_name(network_path)
+
+    return flexible_inference_with_fallback(
+        clips=clips,
+        network_path=network_path,
+        overlap=overlap,
+        tilesize=tilesize,
+        backend=backend,
+        path_is_serialization=False,
+        input_name=input_name,
+        flexible_output_prop=flexible_output_prop
+    )
+
+
 def get_input_name(network_path: str) -> str:
     import onnx
     model = onnx.load(network_path)
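Two reading notes on the new helpers: fallback_backend inside flexible_inference_with_fallback() is not a parameter; as in inference_with_fallback(), it presumably refers to the module-level vsmlrt.fallback_backend setting. And flexible_inference() mirrors inference() step for step, down to init_backend() and get_input_name(). A usage sketch with a custom property prefix (model path and backend are assumptions):

import vapoursynth as vs
from vsmlrt import Backend, flexible_inference

core = vs.core

src = core.std.BlankClip(format=vs.GRAYS, width=512, height=512)
planes = flexible_inference(
    src, "n_plane_model.onnx",       # assumed model path
    backend=Backend.TRT(fp16=True),  # any supported backend works
    flexible_output_prop="my_prop",  # prefix of the per-plane frame props
)
for i, plane in enumerate(planes):
    print(i, plane.format.name)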
@@ -2709,7 +2807,7 @@ def MIGX(*,
 
         return Backend.MIGX(
             fp16=fp16,
-            opt_shapes=opt_shapes
+            opt_shapes=opt_shapes,
             **kwargs
         )

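The "fix typo" half of the commit title is this last hunk: without the trailing comma, opt_shapes=opt_shapes and the **kwargs on the next line parse as a single expression, opt_shapes ** kwargs (exponentiation), so the MIGX() helper raised a TypeError at call time. A minimal reproduction (illustrative only):

def f(**kw):
    return kw

opt_shapes = (64, 64)
kwargs = {"device_id": 0}

# f(opt_shapes=opt_shapes **kwargs)   # parses as opt_shapes ** kwargs -> TypeError
f(opt_shapes=opt_shapes, **kwargs)    # fixed: the comma restores kwargs unpacking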