From f074b983beb40573758eaac103433e911cca28a1 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 8 Mar 2024 15:30:14 -0500 Subject: [PATCH 1/2] feat: allow cpu device even if gpu available respect if `FLAIR_DEVICE="cpu"` even on system with cuda available --- flair/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/flair/__init__.py b/flair/__init__.py index ee32069b1..6dfdead37 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -18,17 +18,18 @@ device: torch.device """Flair is using a single device for everything. You can set this device by overwriting this variable. -This value will be automatically set to the first found GPU if available and to CPU otherwise. -You can choose a specific GPU, by setting the `FLAIR_DEVICE` environment variable to its index. +The device will be automatically set to the first available GPU if a GPU is present and the 'FLAIR_DEVICE' environment +variable is not set to 'cpu', otherwise it will default to the CPU, and a specific GPU can be chosen by setting the 'FLAIR_DEVICE' +environment variable to its index. """ +# Get the device from the environment variable +flair_device = os.environ.get("FLAIR_DEVICE") # global variable: device -if torch.cuda.is_available(): - device_id = os.environ.get("FLAIR_DEVICE") - +if torch.cuda.is_available() and flair_device != "cpu": # No need for correctness checks, torch is doing it - device = torch.device(f"cuda:{device_id}") if device_id else torch.device("cuda:0") + device = torch.device(f"cuda:{flair_device}") if flair_device else torch.device("cuda:0") else: device = torch.device("cpu") From cc1d938e0fd0b36a6aba15e3903f63882beee10a Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 8 Mar 2024 15:33:37 -0500 Subject: [PATCH 2/2] fix: prefer `device_id` name --- flair/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flair/__init__.py b/flair/__init__.py index 6dfdead37..ac9770d34 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -24,12 +24,12 @@ """ # Get the device from the environment variable -flair_device = os.environ.get("FLAIR_DEVICE") +device_id = os.environ.get("FLAIR_DEVICE") # global variable: device -if torch.cuda.is_available() and flair_device != "cpu": +if torch.cuda.is_available() and device_id != "cpu": # No need for correctness checks, torch is doing it - device = torch.device(f"cuda:{flair_device}") if flair_device else torch.device("cuda:0") + device = torch.device(f"cuda:{device_id}") if device_id else torch.device("cuda:0") else: device = torch.device("cpu")