diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index da9ca6660be1d5..80d8b05e04e0a3 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -133,13 +133,14 @@ def run(self): "JaxLib version": f"{jaxlib_version}", "Using distributed or parallel set-up in script?": "", } - if pt_cuda_available: - info["Using GPU in script?"] = "" - info["GPU type"] = torch.cuda.get_device_name() - elif pt_npu_available: - info["Using NPU in script?"] = "" - info["NPU type"] = torch.npu.get_device_name() - info["CANN version"] = torch.version.cann + if is_torch_available(): + if pt_cuda_available: + info["Using GPU in script?"] = "" + info["GPU type"] = torch.cuda.get_device_name() + elif pt_npu_available: + info["Using NPU in script?"] = "" + info["NPU type"] = torch.npu.get_device_name() + info["CANN version"] = torch.version.cann print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n") print(self.format_dict(info))