From 7a731483a5ed70ffd1658798c61f9cb0c18ba1d8 Mon Sep 17 00:00:00 2001 From: SWHL Date: Wed, 15 May 2024 22:24:38 +0800 Subject: [PATCH] fix(config): Remove the default installation onnxruntime-directml restriction on Windows. --- python/rapidocr_onnxruntime/utils.py | 46 +++++++++++++++++++--------- python/requirements_ort.txt | 4 +-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/rapidocr_onnxruntime/utils.py b/python/rapidocr_onnxruntime/utils.py index 3a345f4d5..7030d7f77 100644 --- a/python/rapidocr_onnxruntime/utils.py +++ b/python/rapidocr_onnxruntime/utils.py @@ -2,15 +2,16 @@ # @Author: SWHL # @Contact: liekkaskono@163.com import argparse +import importlib import math import os +import platform import random import traceback import warnings from io import BytesIO from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import platform import cv2 import numpy as np @@ -67,10 +68,9 @@ def _get_ep_list(self) -> List[Tuple[str, str]]: cpu_provider_opts = { "arena_extend_strategy": "kSameAsRequested", } - EP_list = [(CPU_EP, cpu_provider_opts)] - use_cuda = ( + self.use_cuda = ( self.cfg_use_cuda and get_device() == "GPU" and CUDA_EP in had_providers ) cuda_provider_opts = { @@ -79,21 +79,39 @@ def _get_ep_list(self) -> List[Tuple[str, str]]: "cudnn_conv_algo_search": "EXHAUSTIVE", "do_copy_in_default_stream": True, } + # check windows 10 or above - # use_directml = platform and DIRECTML_EP in had_providers and self.cfg_use_cuda("use_dml") - use_directml = platform.system() == "Windows" and platform.release().split(".")[0] >= "10" and DIRECTML_EP in had_providers and self.cfg_use_dml - if use_cuda: + self.use_directml = ( + platform.system() == "Windows" + and platform.release().split(".")[0] >= "10" + and DIRECTML_EP in had_providers + and self.cfg_use_dml + ) + if self.use_cuda: EP_list.insert(0, (CUDA_EP, cuda_provider_opts)) - elif use_directml: - print("Windows 10 or above detected, try to use DirectML as primary provider") - directml_options = cuda_provider_opts if use_cuda else cpu_provider_opts + elif self.use_directml: + self._verfiy_dml() + print( + "Windows 10 or above detected, try to use DirectML as primary provider" + ) + directml_options = ( + cuda_provider_opts if self.use_cuda else cpu_provider_opts + ) EP_list.insert(0, (DIRECTML_EP, directml_options)) return EP_list + def _verfiy_dml(self): + try: + importlib.import_module("onnxruntime-directml") + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "If there are other onnxruntime packages installed, please use pip uninstall onnxruntime to uninstall them first. \nThen install the package using DirectML through pip install onnxruntime-directml." + ) from exc + def _verify_providers(self) -> None: session_providers = self.session.get_providers() - - if self.cfg_use_cuda and CUDA_EP not in session_providers: + + if self.use_cuda and session_providers[0] != CUDA_EP: warnings.warn( f"{CUDA_EP} is not avaiable for current env, the inference part is automatically shifted to be executed under {CPU_EP}.\n" "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " @@ -101,14 +119,12 @@ def _verify_providers(self) -> None: "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", RuntimeWarning, ) - - use_directml = platform.system() == "Windows" and platform.release().split(".")[0] >= "10" and self.cfg_use_dml - if use_directml and session_providers[0] != DIRECTML_EP: + + if self.use_directml and session_providers[0] != DIRECTML_EP: warnings.warn( "DirectML is not available for the current environment, the inference part is automatically shifted to be executed under other EP.\n" ) - def __call__(self, input_content: np.ndarray) -> np.ndarray: input_dict = dict(zip(self.get_input_names(), [input_content])) try: diff --git a/python/requirements_ort.txt b/python/requirements_ort.txt index 9b629440c..5cc0e39d9 100644 --- a/python/requirements_ort.txt +++ b/python/requirements_ort.txt @@ -5,6 +5,4 @@ six>=1.15.0 Shapely>=1.7.1 PyYAML Pillow -# install the onnxruntime-directml if on windows platform, notice that the onnxruntime-directml is conflict with onnxruntime, we can only install one of them -onnxruntime-directml;platform_system=="Windows" -onnxruntime>=1.7.0;platform_system!="Windows" \ No newline at end of file +onnxruntime>=1.7.0 \ No newline at end of file