Skip to content

Commit

Permalink
support sequence parallel (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh authored Apr 26, 2024
1 parent 18addcb commit fbf37a4
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 4 deletions.
32 changes: 32 additions & 0 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import numpy as np
import torch
import torch.distributed as dist
from modelscope import BitsAndBytesConfig, GenerationConfig
from transformers import IntervalStrategy
from transformers.integrations import is_deepspeed_zero3_enabled
Expand All @@ -26,6 +27,17 @@
print_example, set_generation_config, sort_by_max_length,
stat_dataset)

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import *
# datasets is required in Xtuner
from datasets import Dataset
from xtuner.dataset.huggingface import pack_dataset
SUPPORT_XTUNER = True
except ImportError:
pass

logger = get_logger()


Expand Down Expand Up @@ -196,6 +208,25 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
dataset_info['train_dataset'] = stat_dataset(train_dataset)
if val_dataset is not None:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
if args.pack_to_max_length:
assert SUPPORT_XTUNER, \
('Please install XTuner first to pack dataset to `max_length`.'
'`pip install -U \'xtuner[deepspeed]\'`')
if dist.get_rank() == 0:
ds = [i[0] for i in train_dataset.data]
train_dataset = Dataset.from_list(ds)
train_dataset = pack_dataset(
train_dataset,
max_length=args.max_length,
use_varlen_attn=False,
shuffle_before_pack=True,
map_num_proc=16)
objects = [train_dataset]
train_dataset.save_to_disk('alpaca_pack')
else:
objects = [None]
dist.broadcast_object_list(objects, src=0)
train_dataset = objects[0]
else:
dataset_info = None
td0, tkwargs0 = template.encode(train_dataset[0])
Expand Down Expand Up @@ -236,6 +267,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
trainer_kwargs['check_model'] = False

trainer = Seq2SeqTrainer(
sequence_parallel_size=args.sequence_parallel_size,
model=model,
args=training_args,
data_collator=data_collator,
Expand Down
12 changes: 12 additions & 0 deletions swift/llm/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
from .utils import (SftArguments, find_all_linears, find_embedding, find_ln,
is_adapter)

SUPPORT_XTUNER = False

try:
from xtuner.model.modules.dispatch import dispatch_modules
from xtuner.parallel.sequence import *
SUPPORT_XTUNER = True
except ImportError:
pass

logger = get_logger()


Expand Down Expand Up @@ -199,6 +208,9 @@ def prepare_model(model, args: SftArguments):
model.load_state_dict(state_dict, False)
# release memory
del state_dict
if SUPPORT_XTUNER:
dispatch_modules(model)
logger.info('Dispatch modules for sequence parallel.')
else:
raise ValueError(f'args.sft_type: {args.sft_type}')

Expand Down
4 changes: 4 additions & 0 deletions swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,10 @@ class SftArguments(ArgumentsBase):
# fsdp config file
fsdp_config: Optional[str] = None

# xtuner config
sequence_parallel_size: int = 1
pack_to_max_length: bool = False

def handle_dataset_mixture(self, train_dataset: HfDataset) -> None:
if train_dataset is None:
return train_dataset
Expand Down
45 changes: 45 additions & 0 deletions swift/llm/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
from swift.torchacc_utils import pad_and_split_batch
from swift.utils import get_dist_setting, use_torchacc

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import (pad_for_sequence_parallel,
split_for_sequence_parallel,
get_sequence_parallel_group,
get_sequence_parallel_world_size)
SUPPORT_XTUNER = True
except ImportError:
pass

DEFAULT_SYSTEM = 'You are a helpful assistant.'
History = List[Union[Tuple[str, str], List[str]]]

Expand Down Expand Up @@ -421,6 +432,31 @@ def _concat_tokenizer_kwargs(
assert len(old_tokenizer_kwargs) == 0
return curr_tokenizer_kwargs

def _pad_and_split_for_sequence_parallel(self, tokenizer, input_ids,
labels, position_ids,
attention_mask, loss_scale):
input_ids = pad_for_sequence_parallel(
input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
position_ids = pad_for_sequence_parallel(
position_ids, padding_value=0, dim=-1)
attention_mask = pad_for_sequence_parallel(
attention_mask, padding_value=0, dim=-1)

sp_group = get_sequence_parallel_group()
input_ids = split_for_sequence_parallel(
input_ids, dim=1, sp_group=sp_group)
labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
position_ids = split_for_sequence_parallel(
position_ids, dim=1, sp_group=sp_group)
if loss_scale is not None:
loss_scale = pad_for_sequence_parallel(
loss_scale, padding_value=0., dim=-1)
loss_scale = split_for_sequence_parallel(
loss_scale, dim=1, sp_group=sp_group)

return input_ids, labels, position_ids, attention_mask, loss_scale

def data_collator(self,
batch: List[Dict[str, Any]],
padding_to: Optional[int] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -470,10 +506,19 @@ def data_collator(self,
padding_to, input_ids, attention_mask, labels, loss_scale,
self.max_length, self.tokenizer, rank, world_size)

bs, seq_len = input_ids.shape
position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

if get_sequence_parallel_world_size() > 1:
input_ids, labels, position_ids, attention_mask, loss_scale = \
self._pad_and_split_for_sequence_parallel(
tokenizer, input_ids, labels, position_ids, attention_mask, loss_scale)

res = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'position_ids': position_ids,
}
if loss_scale is not None:
res['loss_scale'] = loss_scale
Expand Down
118 changes: 114 additions & 4 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from peft import PeftModel
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
Expand All @@ -14,7 +15,8 @@
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import \
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import is_peft_available
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available, is_torch_xla_available

from swift.torchacc_utils import (ta_eval_dataloader, ta_test_dataloader,
ta_train_dataloader)
Expand All @@ -28,14 +30,30 @@
except ImportError:
from transformers.deepspeed import is_deepspeed_zero3_enabled

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import (init_sequence_parallel,
SequenceParallelSampler,
reduce_sequence_parallel_loss,
get_sequence_parallel_world_size,
get_sequence_parallel_group)
from mmengine.device.utils import get_max_cuda_memory
SUPPORT_XTUNER = True
except ImportError:
pass


class Trainer(PushToMsHubMixin, SwiftMixin, HfTrainer):
pass


class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer):

def __init__(self, *args, **kwargs):
def __init__(self, sequence_parallel_size=1, *args, **kwargs):
super().__init__(*args, **kwargs)
# performance
self.perf: Dict[str, Any] = {
Expand All @@ -49,6 +67,9 @@ def __init__(self, *args, **kwargs):
self.model, 'get_trainable_parameters') else None,
}
self._acc = torch.tensor(0.).to(self.args.device)
if SUPPORT_XTUNER:
self.sequence_parallel_size = sequence_parallel_size
init_sequence_parallel(sequence_parallel_size)

def train(self, *args, **kwargs) -> torch.Tensor:
res = super().train(*args, **kwargs)
Expand Down Expand Up @@ -205,6 +226,7 @@ def compute_scaled_loss(self, labels: torch.Tensor,
return loss.mean()

def compute_loss(self, model, inputs, return_outputs=None):
assert 'labels' in inputs
if not hasattr(self, '_custom_metrics'):
self._custom_metrics = {}

Expand Down Expand Up @@ -240,9 +262,17 @@ def compute_loss(self, model, inputs, return_outputs=None):
else:
loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]

preds = outputs.logits.argmax(dim=2)[..., :-1]
if labels is None:
labels = inputs['labels']

if SUPPORT_XTUNER:
# reduce loss for logging correctly
num_tokens = (labels != -100).sum()
loss = reduce_sequence_parallel_loss(loss, num_tokens,
get_sequence_parallel_group())

preds = outputs.logits.argmax(dim=2)[..., :-1]

labels = labels[..., 1:]
masks = labels != -100
acc_strategy = getattr(self.args, 'acc_strategy', 'token')
Expand All @@ -266,10 +296,90 @@ def compute_loss(self, model, inputs, return_outputs=None):
'acc'] + acc / self.args.gradient_accumulation_steps
return (loss, outputs) if return_outputs else loss

# Support logging cuda memory usage
# hacky: Override Trainer's private method
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch,
ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step()

logs: Dict[str, float] = {}

# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

# reset tr_loss to zero
tr_loss -= tr_loss

logs['loss'] = round(
tr_loss_scalar /
(self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs['grad_norm'] = grad_norm.detach().item() if isinstance(
grad_norm, torch.Tensor) else grad_norm
logs['learning_rate'] = self._get_learning_rate()
logs['memory'] = get_max_cuda_memory()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)

metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith('eval_'):
metric_to_check = f'eval_{metric_to_check}'
self.lr_scheduler.step(metrics[metric_to_check])

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(
self.args, self.state, self.control)

def get_train_dataloader(self):

if not use_torchacc():
return super().get_train_dataloader()
# modified from HFTrainer.get_train_dataloader
# RandomSampler -> SequenceParallelSampler
if trainer.is_datasets_available():
import datasets
if self.train_dataset is None:
raise ValueError('Trainer: training requires a train_dataset.')

train_dataset = self.train_dataset
data_collator = self.data_collator
if trainer.is_datasets_available() and isinstance(
train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description='training')
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description='training')

dataloader_params = {
'batch_size': self._train_batch_size,
'collate_fn': data_collator,
'num_workers': self.args.dataloader_num_workers,
'pin_memory': self.args.dataloader_pin_memory,
'persistent_workers': self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params['sampler'] = SequenceParallelSampler(
train_dataset, seed=1024)
dataloader_params['drop_last'] = self.args.dataloader_drop_last
dataloader_params['worker_init_fn'] = seed_worker

return DataLoader(train_dataset, **dataloader_params)

else:
if trainer.is_datasets_available():
Expand Down

0 comments on commit fbf37a4

Please sign in to comment.