diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 80232898ce4707..01c5ede34ae83e 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -762,7 +762,7 @@ def torch_int(x): import torch - return x.to(torch.int64) if torch.jit.is_tracing() else int(x) + return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x) def torch_float(x): @@ -774,7 +774,7 @@ def torch_float(x): import torch - return x.to(torch.float32) if torch.jit.is_tracing() else int(x) + return x.to(torch.float32) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x) def filter_out_non_signature_kwargs(extra: Optional[list] = None):