Skip to content

Commit

Permalink
Fix: use torch.from_numpy() to create tensors for np.ndarrays (#33201)
Browse files Browse the repository at this point in the history
use torch.from_numpy for np.ndarrays
  • Loading branch information
shinyano authored Sep 2, 2024
1 parent 2895224 commit cff06aa
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 8 deletions.
7 changes: 5 additions & 2 deletions src/transformers/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, value):
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, np.ndarray):
- self._tensor = torch.tensor(value)
+ self._tensor = torch.from_numpy(value)
else:
raise TypeError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")

Expand Down Expand Up @@ -192,7 +192,10 @@ def __init__(self, value, samplerate=16_000):
self._tensor = value
elif isinstance(value, tuple):
self.samplerate = value[0]
- self._tensor = torch.tensor(value[1])
+ if isinstance(value[1], np.ndarray):
+     self._tensor = torch.from_numpy(value[1])
+ else:
+     self._tensor = torch.tensor(value[1])
else:
raise ValueError(f"Unsupported audio type: {type(value)}")

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/agents/document_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def encode(self, document: "Image", question: str):
if isinstance(document, str):
img = Image.open(document).convert("RGB")
img_array = np.array(img).transpose(2, 0, 1)
- document = torch.tensor(img_array)
+ document = torch.from_numpy(img_array)
pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values

return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
- batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+ batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
else:
batch[k] = torch.tensor([f[k] for f in features])

Expand Down
5 changes: 4 additions & 1 deletion src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def as_tensor(value):
and isinstance(value[0][0], np.ndarray)
):
value = np.array(value)
- return torch.tensor(value)
+ if isinstance(value, np.ndarray):
+     return torch.from_numpy(value)
+ else:
+     return torch.tensor(value)

is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX:
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,15 @@ def normalize(self, image, mean, std, rescale=False):
import torch

if not isinstance(mean, torch.Tensor):
- mean = torch.tensor(mean)
+ if isinstance(mean, np.ndarray):
+     mean = torch.from_numpy(mean)
+ else:
+     mean = torch.tensor(mean)
if not isinstance(std, torch.Tensor):
- std = torch.tensor(std)
+ if isinstance(std, np.ndarray):
+     std = torch.from_numpy(std)
+ else:
+     std = torch.tensor(std)

if image.ndim == 3 and image.shape[0] in [1, 3]:
return (image - mean[:, None, None]) / std[:, None, None]
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def convert_to_tensors(

def as_tensor(value, dtype=None):
if isinstance(value, list) and isinstance(value[0], np.ndarray):
- return torch.tensor(np.array(value))
+ return torch.from_numpy(np.array(value))
return torch.tensor(value)

elif tensor_type == TensorType.JAX:
Expand Down

0 comments on commit cff06aa

Please sign in to comment.