Skip to content

Commit

Permalink
Explicitly specify cuda device for onnxruntime
Browse files: browse the repository at this point in the history
  • Loading branch information
BreezeWhite committed Nov 16, 2024
1 parent dd00ca1 commit 76894b4
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion oemer/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -46,13 +46,17 @@ def inference(
output_shape = model.output_shape
else:
import onnxruntime as rt
import torch

onnx_path = os.path.join(model_path, "model.onnx")
metadata = pickle.load(open(os.path.join(model_path, "metadata.pkl"), "rb"))
if sys.platform == "darwin":
providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
providers = [
("CUDAExecutionProvider", {"device_id": 0}),
"CPUExecutionProvider",
]
sess = rt.InferenceSession(onnx_path, providers=providers)
output_names = metadata["output_names"]
input_shape = metadata["input_shape"]
Expand Down

0 comments on commit 76894b4

Please sign in to comment.