From 9fab561b9cba87324cf14ae2412cff4a461d642d Mon Sep 17 00:00:00 2001 From: Huazhong Ji Date: Thu, 8 Feb 2024 01:27:01 +0800 Subject: [PATCH] Add npu device for pipeline (#28885) add npu device for pipeline Co-authored-by: unit_test --- src/transformers/pipelines/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index bfa8e2262ec8d4..9f30665e590d7d 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -41,6 +41,7 @@ is_tf_available, is_torch_available, is_torch_cuda_available, + is_torch_npu_available, is_torch_xpu_available, logging, ) @@ -852,6 +853,8 @@ def __init__( self.device = torch.device("cpu") elif is_torch_cuda_available(): self.device = torch.device(f"cuda:{device}") + elif is_torch_npu_available(): + self.device = torch.device(f"npu:{device}") elif is_torch_xpu_available(check_device=True): self.device = torch.device(f"xpu:{device}") else: