Skip to content

Commit

Permalink
fix(config): Remove the default installation onnxruntime-directml res…
Browse files Browse the repository at this point in the history
…triction on Windows.
  • Loading branch information
SWHL committed May 15, 2024
1 parent 8523a2e commit 7a73148
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
46 changes: 31 additions & 15 deletions python/rapidocr_onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
# @Author: SWHL
# @Contact: [email protected]
import argparse
import importlib
import math
import os
import platform
import random
import traceback
import warnings
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import platform

import cv2
import numpy as np
Expand Down Expand Up @@ -67,10 +68,9 @@ def _get_ep_list(self) -> List[Tuple[str, str]]:
cpu_provider_opts = {
"arena_extend_strategy": "kSameAsRequested",
}

EP_list = [(CPU_EP, cpu_provider_opts)]

use_cuda = (
self.use_cuda = (
self.cfg_use_cuda and get_device() == "GPU" and CUDA_EP in had_providers
)
cuda_provider_opts = {
Expand All @@ -79,36 +79,52 @@ def _get_ep_list(self) -> List[Tuple[str, str]]:
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": True,
}

# check windows 10 or above
# use_directml = platform and DIRECTML_EP in had_providers and self.cfg_use_cuda("use_dml")
use_directml = platform.system() == "Windows" and platform.release().split(".")[0] >= "10" and DIRECTML_EP in had_providers and self.cfg_use_dml
if use_cuda:
self.use_directml = (
platform.system() == "Windows"
and platform.release().split(".")[0] >= "10"
and DIRECTML_EP in had_providers
and self.cfg_use_dml
)
if self.use_cuda:
EP_list.insert(0, (CUDA_EP, cuda_provider_opts))
elif use_directml:
print("Windows 10 or above detected, try to use DirectML as primary provider")
directml_options = cuda_provider_opts if use_cuda else cpu_provider_opts
elif self.use_directml:
self._verfiy_dml()
print(
"Windows 10 or above detected, try to use DirectML as primary provider"
)
directml_options = (
cuda_provider_opts if self.use_cuda else cpu_provider_opts
)
EP_list.insert(0, (DIRECTML_EP, directml_options))
return EP_list

def _verfiy_dml(self):
try:
importlib.import_module("onnxruntime-directml")
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"If there are other onnxruntime packages installed, please use pip uninstall onnxruntime to uninstall them first. \nThen install the package using DirectML through pip install onnxruntime-directml."
) from exc

def _verify_providers(self) -> None:
session_providers = self.session.get_providers()
if self.cfg_use_cuda and CUDA_EP not in session_providers:

if self.use_cuda and session_providers[0] != CUDA_EP:
warnings.warn(
f"{CUDA_EP} is not avaiable for current env, the inference part is automatically shifted to be executed under {CPU_EP}.\n"
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
"you can check their relations from the offical web site: "
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
RuntimeWarning,
)

use_directml = platform.system() == "Windows" and platform.release().split(".")[0] >= "10" and self.cfg_use_dml
if use_directml and session_providers[0] != DIRECTML_EP:

if self.use_directml and session_providers[0] != DIRECTML_EP:
warnings.warn(
"DirectML is not available for the current environment, the inference part is automatically shifted to be executed under other EP.\n"
)


def __call__(self, input_content: np.ndarray) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), [input_content]))
try:
Expand Down
4 changes: 1 addition & 3 deletions python/requirements_ort.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@ six>=1.15.0
Shapely>=1.7.1
PyYAML
Pillow
# install the onnxruntime-directml if on windows platform, notice that the onnxruntime-directml is conflict with onnxruntime, we can only install one of them
onnxruntime-directml;platform_system=="Windows"
onnxruntime>=1.7.0;platform_system!="Windows"
onnxruntime>=1.7.0

0 comments on commit 7a73148

Please sign in to comment.