diff --git a/scripts/vsmlrt.py b/scripts/vsmlrt.py index 428c68f..6073e06 100644 --- a/scripts/vsmlrt.py +++ b/scripts/vsmlrt.py @@ -7,6 +7,7 @@ "RealESRGANv2", "RealESRGANv2Model" ] +import copy from dataclasses import dataclass, field import enum import math @@ -42,13 +43,13 @@ def get_plugins_path() -> str: class Backend: - @dataclass(frozen=True) + @dataclass(frozen=False) class ORT_CPU: num_streams: int = 1 verbosity: int = 2 fp16: bool = False - @dataclass(frozen=True) + @dataclass(frozen=False) class ORT_CUDA: device_id: int = 0 cudnn_benchmark: bool = True @@ -56,11 +57,11 @@ class ORT_CUDA: verbosity: int = 2 fp16: bool = False - @dataclass(frozen=True) + @dataclass(frozen=False) class OV_CPU: fp16: bool = False - @dataclass + @dataclass(frozen=False) class TRT: max_shapes: typing.Optional[typing.Tuple[int, int]] = None opt_shapes: typing.Optional[typing.Tuple[int, int]] = None @@ -568,6 +569,8 @@ def init_backend( elif backend is Backend.TRT: # type: ignore backend = Backend.TRT() + backend = copy.deepcopy(backend) + if isinstance(backend, Backend.TRT): backend._channels = channels