diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index bfa8e2262ec8d4..9f30665e590d7d 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -41,6 +41,7 @@ is_tf_available, is_torch_available, is_torch_cuda_available, + is_torch_npu_available, is_torch_xpu_available, logging, ) @@ -852,6 +853,8 @@ def __init__( self.device = torch.device("cpu") elif is_torch_cuda_available(): self.device = torch.device(f"cuda:{device}") + elif is_torch_npu_available(): + self.device = torch.device(f"npu:{device}") elif is_torch_xpu_available(check_device=True): self.device = torch.device(f"xpu:{device}") else: