Skip to content

Commit

Permalink
backend: fix device filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuullll committed Nov 24, 2024
1 parent b2e9203 commit f3d24c7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
31 changes: 29 additions & 2 deletions service/device_detect.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
import subprocess
import re

def get_devices():
result = subprocess.run(["wmic", "path", "win32_VideoController", "get", "name,pnpdeviceid"], capture_output=True, text=True)
lines = result.stdout.split("\n")
devices = {} # {name: device_id}
for line in lines:
# vendor:8086 is an Intel device
if "VEN_8086" in line:
parts = line.split(" ")
name = parts[0].strip()
for part in parts[1:]:
if "VEN_8086" in part:
# regex match "VEN_8086&DEV_{device_id}"
if match := re.search(r"VEN_8086\&DEV_([0-9A-F]+)", part):
device_id = match.group(1)
devices[name] = device_id
break

return devices

devices = get_devices()

def is_supported(name):
return "arc" in name.lower() or (name in devices and devices[name].lower() == "e20b")

import torch

Check failure on line 28 in service/device_detect.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

service/device_detect.py:28:1: E402 Module level import not at top of file
import intel_extension_for_pytorch as ipex # noqa: F401

Check failure on line 29 in service/device_detect.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E402)

service/device_detect.py:29:1: E402 Module level import not at top of file

# filter out non-Arc devices
# filter out unsupported devices
supported_ids = []
for i in range(torch.xpu.device_count()):
props = torch.xpu.get_device_properties(i)
if "arc" in props.name.lower():
if is_supported(props.name):
supported_ids.append(str(i))

print(",".join(supported_ids))
1 change: 1 addition & 0 deletions service/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
os.environ["ONEAPI_DEVICE_SELECTOR"] = f"*:{supported_ids}"
print(f"Set ONEAPI_DEVICE_SELECTOR={os.environ['ONEAPI_DEVICE_SELECTOR']}")
except: # noqa: E722
print("Warning: Device detection failed, using all devices")
pass

# Credit to https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186
Expand Down

0 comments on commit f3d24c7

Please sign in to comment.