Skip to content

Commit

Permalink
[TPU] Support PyTorch/XLA FSDP via SPMD (#28949)
Browse files Browse the repository at this point in the history
* Initial commit

* Add guards for the global mesh

* Address more comments

* Move the dataloader into integrations/tpu.py

* Fix linters

* Make karg more explicitly

* Remove the move device logic

* Fix the CI

* Fix linters

* Re-enable checkpointing
  • Loading branch information
alanwaketan authored Feb 14, 2024
1 parent 0199a48 commit 5f06053
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 7 deletions.
36 changes: 36 additions & 0 deletions src/transformers/integrations/tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import DataLoader

from ..utils import is_torch_tpu_available


def tpu_spmd_dataloader(dataloader: DataLoader):
if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl

assert isinstance(
dataloader, pl.MpDeviceLoader
), "The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`."

# This is to support PyTorch/XLA FSDP via SPMD.
# Here we shard the input data's 0th dim across the fsdp axis.
import torch_xla.distributed.spmd as xs

sharding_spec = xs.ShardingSpec(xs.get_global_mesh(), ("fsdp", None))
dataloader._parallel_loader_kwargs["input_sharding"] = sharding_spec
return dataloader
else:
return dataloader
65 changes: 58 additions & 7 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from .integrations.tpu import tpu_spmd_dataloader
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
Expand Down Expand Up @@ -170,6 +171,8 @@
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr


if is_sagemaker_mp_enabled():
Expand Down Expand Up @@ -635,6 +638,13 @@ def __init__(
if args.torch_compile and not is_torch_compile_available():
raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")

self.is_fsdp_xla_v2_enabled = args.fsdp_config["xla_fsdp_v2"]
if self.is_fsdp_xla_v2_enabled:
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))

def _activate_neftune(self, model):
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
Expand Down Expand Up @@ -1385,6 +1395,11 @@ def _wrap_model(self, model, training=True, dataloader=None):
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)

if self.is_fsdp_xla_v2_enabled:
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
SpmdFullyShardedDataParallel as FSDPv2,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
Expand Down Expand Up @@ -1416,15 +1431,40 @@ def _wrap_model(self, model, training=True, dataloader=None):
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
return target_cls(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
if self.is_fsdp_xla_v2_enabled:

def shard_output(output, mesh):
from .modeling_outputs import CausalLMOutputWithPast

real_output = None
if isinstance(output, torch.Tensor):
real_output = output
elif isinstance(output, tuple):
real_output = output[0]
elif isinstance(output, CausalLMOutputWithPast):
real_output = output.logits

if real_output is None:
raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
xs.mark_sharding(real_output, mesh, ("fsdp", None, None))

self.model = model = FSDPv2(
model,
shard_output=shard_output,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
)
else:
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)

# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
Expand Down Expand Up @@ -1593,6 +1633,8 @@ def _inner_training_loop(
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
if self.is_fsdp_xla_v2_enabled:
train_dataloader = tpu_spmd_dataloader(train_dataloader)

# Setting up training control variables:
# number of training epochs: num_train_epochs
Expand Down Expand Up @@ -1962,6 +2004,11 @@ def _inner_training_loop(
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_tpu_available():
xm.mark_step()
break
if step < 0:
logger.warning(
Expand Down Expand Up @@ -2945,6 +2992,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa

def _save_tpu(self, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else self.args.output_dir

logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model
model.to("cpu")
Expand Down Expand Up @@ -3143,6 +3191,9 @@ def evaluate(
self._memory_tracker.start()

eval_dataloader = self.get_eval_dataloader(eval_dataset)
if self.is_fsdp_xla_v2_enabled:
eval_dataloader = tpu_spmd_dataloader(eval_dataloader)

start_time = time.time()

eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
Expand Down
1 change: 1 addition & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,7 @@ def __post_init__(self):
):
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
Expand Down

0 comments on commit 5f06053

Please sign in to comment.