From cff06aac6fad28019930be03f5d467055bf62177 Mon Sep 17 00:00:00 2001
From: ZM <12236590+shinyano@users.noreply.github.com>
Date: Tue, 3 Sep 2024 00:45:55 +0800
Subject: [PATCH] Fix: use `torch.from_numpy()` to create tensors for
 np.ndarrays (#33201)

use torch.from_numpy for np.ndarrays
---
 src/transformers/agents/agent_types.py                 |  7 +++++--
 src/transformers/agents/document_question_answering.py |  2 +-
 src/transformers/data/data_collator.py                 |  2 +-
 src/transformers/feature_extraction_utils.py           |  5 ++++-
 src/transformers/image_utils.py                        | 10 ++++++++--
 src/transformers/tokenization_utils_base.py            |  2 +-
 6 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/src/transformers/agents/agent_types.py b/src/transformers/agents/agent_types.py
index 114b6de01c3333..4a36eaaee05122 100644
--- a/src/transformers/agents/agent_types.py
+++ b/src/transformers/agents/agent_types.py
@@ -105,7 +105,7 @@ def __init__(self, value):
         elif isinstance(value, torch.Tensor):
             self._tensor = value
         elif isinstance(value, np.ndarray):
-            self._tensor = torch.tensor(value)
+            self._tensor = torch.from_numpy(value)
         else:
             raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
 
@@ -192,7 +192,10 @@ def __init__(self, value, samplerate=16_000):
             self._tensor = value
         elif isinstance(value, tuple):
             self.samplerate = value[0]
-            self._tensor = torch.tensor(value[1])
+            if isinstance(value[1], np.ndarray):
+                self._tensor = torch.from_numpy(value[1])
+            else:
+                self._tensor = torch.tensor(value[1])
         else:
             raise ValueError(f"Unsupported audio type: {type(value)}")
 
diff --git a/src/transformers/agents/document_question_answering.py b/src/transformers/agents/document_question_answering.py
index 061dac199fc5b5..030120ac6c7f1e 100644
--- a/src/transformers/agents/document_question_answering.py
+++ b/src/transformers/agents/document_question_answering.py
@@ -60,7 +60,7 @@ def encode(self, document: "Image", question: str):
         if isinstance(document, str):
             img = Image.open(document).convert("RGB")
             img_array = np.array(img).transpose(2, 0, 1)
-            document = torch.tensor(img_array)
+            document = torch.from_numpy(img_array)
         pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
 
         return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 7f982c49cf13ea..696cedf47d98a0 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -153,7 +153,7 @@ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
             if isinstance(v, torch.Tensor):
                 batch[k] = torch.stack([f[k] for f in features])
             elif isinstance(v, np.ndarray):
-                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+                batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
             else:
                 batch[k] = torch.tensor([f[k] for f in features])
 
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index cda7808f34853e..3590d9da98870b 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -146,7 +146,10 @@ def as_tensor(value):
                     and isinstance(value[0][0], np.ndarray)
                 ):
                     value = np.array(value)
-                return torch.tensor(value)
+                if isinstance(value, np.ndarray):
+                    return torch.from_numpy(value)
+                else:
+                    return torch.tensor(value)
 
             is_tensor = torch.is_tensor
         elif tensor_type == TensorType.JAX:
diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py
index 4449b602491ad9..1a70ef05638379 100644
--- a/src/transformers/image_utils.py
+++ b/src/transformers/image_utils.py
@@ -579,9 +579,15 @@ def normalize(self, image, mean, std, rescale=False):
         import torch
 
         if not isinstance(mean, torch.Tensor):
-            mean = torch.tensor(mean)
+            if isinstance(mean, np.ndarray):
+                mean = torch.from_numpy(mean)
+            else:
+                mean = torch.tensor(mean)
         if not isinstance(std, torch.Tensor):
-            std = torch.tensor(std)
+            if isinstance(std, np.ndarray):
+                std = torch.from_numpy(std)
+            else:
+                std = torch.tensor(std)
 
         if image.ndim == 3 and image.shape[0] in [1, 3]:
             return (image - mean[:, None, None]) / std[:, None, None]
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 608c6516668b29..dc0af00cedeb3b 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -721,7 +721,7 @@ def convert_to_tensors(
 
             def as_tensor(value, dtype=None):
                 if isinstance(value, list) and isinstance(value[0], np.ndarray):
-                    return torch.tensor(np.array(value))
+                    return torch.from_numpy(np.array(value))
                 return torch.tensor(value)
 
         elif tensor_type == TensorType.JAX:
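
As context for the change (not part of the patch itself), here is a minimal standalone Python sketch of the behavioural difference the patch relies on: torch.from_numpy() wraps the existing NumPy buffer without copying, while torch.tensor() always copies its input.

    import numpy as np
    import torch

    arr = np.zeros((2, 3), dtype=np.float32)

    copied = torch.tensor(arr)      # always copies the data
    shared = torch.from_numpy(arr)  # zero-copy: shares memory with `arr`

    arr[0, 0] = 1.0
    print(copied[0, 0].item())  # 0.0 -- owns its own copy, unaffected
    print(shared[0, 0].item())  # 1.0 -- sees the in-place NumPy change

Note that torch.from_numpy() only accepts np.ndarray inputs, which is why the patch keeps an isinstance(..., np.ndarray) guard with a torch.tensor() fallback wherever the value may also be a list, tuple, or scalar.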