diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 57532c0c711b85..7672df0b9a0e46 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -52,6 +52,7 @@
     find_pruneable_heads_and_indices,
     id_tensor_storage,
     is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_4,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -5005,6 +5006,8 @@ def tensor_parallel(self, device_mesh):
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
+        if not is_torch_greater_or_equal_than_2_4:
+            raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
         # No op if `_tp_plan` attribute does not exist under the module.
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index a595f8bc9e1af6..6757f72350ba29 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -20,11 +20,6 @@
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
-from torch.distributed.tensor import Replicate
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    RowwiseParallel,
-)

 from .utils import is_torch_xla_available, logging

@@ -44,6 +39,14 @@
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")


+if is_torch_greater_or_equal_than_2_4:
+    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+    )
+
+
 def softmax_backward_data(parent, grad_output, output, dim, self):
     """
     A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
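
For reviewers, a minimal sketch of how the guarded path is exercised after this change. The checkpoint name, the 1-D mesh layout, and a `torchrun` multi-GPU launch are illustrative assumptions, not part of the diff; only `tensor_parallel(device_mesh)` and the version guard come from the patch itself.

# Illustrative usage sketch only -- not part of this patch. The checkpoint name,
# mesh shape, and a `torchrun` multi-process launch are assumptions.
import torch
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16
)

# 1-D mesh over the GPUs of this job. With a recent torch, `tensor_parallel`
# shards modules according to the model's `_tp_plan` (a no-op if the attribute
# is absent); on older torch versions the new guard raises EnvironmentError
# early instead of failing at import time on the DTensor symbols.
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
model.tensor_parallel(mesh)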