Skip to content

Commit

Permalink
scripts/vsmlrt.py: support mutable backend object
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Jan 1, 2022
1 parent c3c056e commit b7b4549
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"RealESRGANv2", "RealESRGANv2Model"
]

import copy
from dataclasses import dataclass, field
import enum
import math
Expand Down Expand Up @@ -42,25 +43,25 @@ def get_plugins_path() -> str:


class Backend:
@dataclass(frozen=True)
@dataclass(frozen=False)
class ORT_CPU:
num_streams: int = 1
verbosity: int = 2
fp16: bool = False

@dataclass(frozen=True)
@dataclass(frozen=False)
class ORT_CUDA:
device_id: int = 0
cudnn_benchmark: bool = True
num_streams: int = 1
verbosity: int = 2
fp16: bool = False

@dataclass(frozen=True)
@dataclass(frozen=False)
class OV_CPU:
fp16: bool = False

@dataclass
@dataclass(frozen=False)
class TRT:
max_shapes: typing.Optional[typing.Tuple[int, int]] = None
opt_shapes: typing.Optional[typing.Tuple[int, int]] = None
Expand Down Expand Up @@ -568,6 +569,8 @@ def init_backend(
elif backend is Backend.TRT: # type: ignore
backend = Backend.TRT()

backend = copy.deepcopy(backend)

if isinstance(backend, Backend.TRT):
backend._channels = channels

Expand Down

0 comments on commit b7b4549

Please sign in to comment.