From 0be998bc7f9393c1672a7ff93fbccd15d5c39f3c Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Wed, 3 Jul 2024 11:12:51 +0200 Subject: [PATCH] Requires for torch.tensor before casting (#31755) --- src/transformers/utils/generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 80232898ce4707..01c5ede34ae83e 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -762,7 +762,7 @@ def torch_int(x): import torch - return x.to(torch.int64) if torch.jit.is_tracing() else int(x) + return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x) def torch_float(x): @@ -774,7 +774,7 @@ def torch_float(x): import torch - return x.to(torch.float32) if torch.jit.is_tracing() else int(x) + return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x) def filter_out_non_signature_kwargs(extra: Optional[list] = None):