From b7b4549e02ed3360ccb2e3d428f6abb368346a4d Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Sat, 1 Jan 2022 16:24:57 +0800 Subject: [PATCH] scripts/vsmlrt.py: support mutable backend object --- scripts/vsmlrt.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 428c68f..6073e06 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -7,6 +7,7 @@ "RealESRGANv2", "RealESRGANv2Model" ] +import copy from dataclasses import dataclass, field import enum import math @@ -42,13 +43,13 @@ def get_plugins_path() -> str: class Backend: - @dataclass(frozen=True) + @dataclass(frozen=False) class ORT_CPU: num_streams: int = 1 verbosity: int = 2 fp16: bool = False - @dataclass(frozen=True) + @dataclass(frozen=False) class ORT_CUDA: device_id: int = 0 cudnn_benchmark: bool = True @@ -56,11 +57,11 @@ class ORT_CUDA: verbosity: int = 2 fp16: bool = False - @dataclass(frozen=True) + @dataclass(frozen=False) class OV_CPU: fp16: bool = False - @dataclass + @dataclass(frozen=False) class TRT: max_shapes: typing.Optional[typing.Tuple[int, int]] = None opt_shapes: typing.Optional[typing.Tuple[int, int]] = None @@ -568,6 +569,8 @@ def init_backend( elif backend is Backend.TRT: # type: ignore backend = Backend.TRT() + backend = copy.deepcopy(backend) + if isinstance(backend, Backend.TRT): backend._channels = channels