diff --git a/flair/__init__.py b/flair/__init__.py index 6dfdead37..ac9770d34 100644 --- a/flair/__init__.py +++ b/flair/__init__.py @@ -24,12 +24,12 @@ """ # Get the device from the environment variable -flair_device = os.environ.get("FLAIR_DEVICE") +device_id = os.environ.get("FLAIR_DEVICE") # global variable: device -if torch.cuda.is_available() and flair_device != "cpu": +if torch.cuda.is_available() and device_id != "cpu": # No need for correctness checks, torch is doing it - device = torch.device(f"cuda:{flair_device}") if flair_device else torch.device("cuda:0") + device = torch.device(f"cuda:{device_id}") if device_id else torch.device("cuda:0") else: device = torch.device("cpu")