Commit 274b4d3: update

hjh0119 committed Aug 5, 2024 · 1 parent 458168f
Showing 7 changed files with 56 additions and 37 deletions.
15 changes: 9 additions & 6 deletions swift/llm/rlhf.py
@@ -24,6 +24,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     logger.info(f'args: {args}')
     seed_everything(args.seed)
     training_args = args.training_args
+    streaming = args.streaming
     if is_torch_npu_available():
         print(f'device_count: {torch.npu.device_count()}')
     else:
@@ -183,12 +184,14 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:

     template: Template = get_template(
         args.template_type, tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
-    if not template.support_multi_round and 'history' in train_dataset[0]:
-        logger.info(
-            'The current template does not support multi-turn dialogue. The chatml template is used by default. \
-You can also use the --model_type parameter to specify the template.')
-        template: Template = get_template(
-            'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
+    if not template.support_multi_round:
+        first_data = train_dataset[0] if not streaming else next(iter(train_dataset))
+        if 'history' in first_data:
+            logger.info(
+                'The current template does not support multi-turn dialogue. The chatml template is used by default. \
+You can also use the --model_type parameter to specify the template.')
+            template: Template = get_template(
+                'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
     args.system = template.default_system
     logger.info(f'system: {args.system}')

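A note on the rlhf.py change: a streaming (iterable) dataset cannot be indexed with `train_dataset[0]`, so the first sample must be pulled through an iterator before checking for a `history` column. A minimal sketch of the pattern, assuming Hugging Face `datasets` (the toy dataset is illustrative):

```python
from datasets import Dataset, IterableDataset

def peek_first_sample(dataset):
    # Map-style datasets support integer indexing; a streaming
    # IterableDataset only supports iteration.
    if isinstance(dataset, IterableDataset):
        return next(iter(dataset))
    return dataset[0]

ds = Dataset.from_dict({'query': ['hi'], 'history': [[]]})
assert 'history' in peek_first_sample(ds)
```
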
10 changes: 7 additions & 3 deletions swift/llm/sft.py
@@ -292,6 +292,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         args.truncation_strategy,
         model=model,
         **template_kwargs)
+    if streaming:
+        template.encode = partial(template.encode, streaming=streaming)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
     logger.info(f'args.lazy_tokenize: {args.lazy_tokenize}')
@@ -326,7 +328,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         raise AttributeError('Failed to access dataset attributes,train_dataset is None. This might be because:\n'
                              '(1) The dataset contains None for input or labels;\n'
                              "(2) The 'max_length' setting is too short causing data truncation.")
-    td0, tkwargs0 = train_dataset.data[0] if not streaming else next(iter(train_dataset.data))
+    td0, tkwargs0 = train_dataset.data[0] if not streaming else next(iter(train_dataset)), {}
     print_example(td0, tokenizer, tkwargs0)
     dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
     if val_dataset is not None:
@@ -400,7 +402,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
     last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
     logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
     logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
-    train_time = get_time_info(trainer.state.log_history, len(train_dataset))
+    if not streaming:
+        train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
     if is_master() and not use_torchacc():
         if 'tensorboard' in args.training_args.report_to:
@@ -412,7 +415,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         trainer.push_to_hub()
     run_info = {
         'memory': trainer.perf['memory'],
-        'train_time': train_time,
         'last_model_checkpoint': last_model_checkpoint,
         'best_model_checkpoint': trainer.state.best_model_checkpoint,
         'best_metric': trainer.state.best_metric,
@@ -421,6 +423,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         'model_info': model_info,
         'dataset_info': dataset_info,
     }
+    if not streaming:
+        run_info.update({'train_time': train_time})
     for key in ['gen_time', 'gen_len']:
         if trainer.perf[key] != 0:
             run_info[key] = trainer.perf[key]
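
The first sft.py hunk rebinds `template.encode` with `functools.partial` so every later call implicitly passes `streaming=True`, leaving the call sites untouched. A standalone sketch of the rebinding pattern (the Template class here is a stand-in, not swift's):

```python
from functools import partial

class Template:
    def encode(self, example, streaming=False):
        # Stand-in encoder: (inputs, tokenizer_kwargs) normally,
        # just the inputs dict in streaming mode.
        res = ({'input_ids': [1, 2, 3]}, {})
        return res[0] if streaming else res

template = Template()
template.encode = partial(template.encode, streaming=True)
# Existing call sites are unchanged but now get streaming behavior.
print(template.encode({'query': 'hi'}))  # -> {'input_ids': [1, 2, 3]}
```
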
2 changes: 1 addition & 1 deletion swift/llm/utils/__init__.py
@@ -43,9 +43,9 @@
 try:
     if is_lmdeploy_available():
         from .lmdeploy_utils import (
             prepare_lmdeploy_engine_template,
             lmdeploy_context,
-            LmdeployGenerationConfig,
+            LmdeployGenerationConfig,
             get_lmdeploy_engine,
             inference_stream_lmdeploy,
             inference_lmdeploy,
50 changes: 28 additions & 22 deletions swift/llm/utils/argument.py
@@ -389,7 +389,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
             train_idxs = random_state.permutation(train_dataset_sample)
             train_dataset = train_dataset.select(train_idxs)
         else:
-            train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
+            # train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
             train_dataset = train_dataset.take(train_dataset_sample)

         if val_dataset_sample is None:
@@ -400,33 +400,30 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
                 val_idxs = random_state.permutation(val_dataset_sample)
                 val_dataset = val_dataset.select(val_idxs)
             elif streaming:
-                val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
+                # val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
                 val_dataset = val_dataset.take(val_dataset_sample)

         if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
                 or len(self.train_dataset_mix_ds) == 0):
             return train_dataset, val_dataset

+        if streaming:
+            logger.warning('`train_dataset_mix_ds` is not supported in streaming mode.')
-        mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
-        logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
-        logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
-        mixed_dataset = get_dataset(
-            self.train_dataset_mix_ds,
-            0.0,
-            random_state,
-            check_dataset_strategy=self.check_dataset_strategy,
-            streaming=streaming)[0]
-        if len(mixed_dataset) < mix_dataset_sample:
-            logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
-                        'lesser than the ratio required by the `train_dataset_mix_ratio` '
-                        f'argument: {self.train_dataset_mix_ratio}. '
-                        f'the actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
+        else:
+            mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
+            logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
+            logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
+            mixed_dataset = get_dataset(
+                self.train_dataset_mix_ds,
+                0.0,
+                random_state,
+                check_dataset_strategy=self.check_dataset_strategy,
+                streaming=streaming)[0]
+            if len(mixed_dataset) < mix_dataset_sample:
+                logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
+                            'lesser than the ratio required by the `train_dataset_mix_ratio` '
+                            f'argument: {self.train_dataset_mix_ratio}. '
+                            f'the actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
-        else:
-            mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
-        train_dataset = concatenate_datasets([train_dataset, mixed_dataset])
+            else:
+                mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
+            train_dataset = concatenate_datasets([train_dataset, mixed_dataset])
         return train_dataset, val_dataset

def prepare_template(self: Union['SftArguments', 'InferArguments']):
@@ -570,6 +567,10 @@ def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> Non
             self.lazy_tokenize = False
             logger.info('lazy_tokenize set to False in streaming dataset')

+        if hasattr(self, 'train_dataset_mix_ratio') and self.train_dataset_mix_ratio > 0:
+            logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
+            self.train_dataset_mix_ratio = 0
+
         if self.dataset_test_ratio > 0:
             logger.warning('Since the length of streaming data cannot be estimated,'
                            'set dataset_test_ratio to 0. You can manually set val_dataset_sample.')
@@ -578,6 +579,9 @@ def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> Non
         if self.train_dataset_sample > 0 or self.val_dataset_sample:
             logger.warning('The final data size in streaming data may be smaller than train_dataset_sample')

+        if self.max_steps == -1:
+            raise ValueError('Please specify `max_steps` in streaming mode.')
+

@dataclass
class SftArguments(ArgumentsBase):
@@ -1065,7 +1069,9 @@ def __post_init__(self) -> None:
             self.lazy_tokenize = template_info.get('lazy_tokenize', False)
             logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
         if self.dataloader_num_workers is None:
-            if 'dataloader_num_workers' in template_info:
+            if self.streaming:
+                self.dataloader_num_workers = 0
+            elif 'dataloader_num_workers' in template_info:
                 self.dataloader_num_workers = template_info['dataloader_num_workers']
             elif platform.system() == 'Windows':
                 self.dataloader_num_workers = 0
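
Taken together, the argument.py changes gate streaming runs up front: dataset mixing is turned off, and `max_steps` becomes mandatory because an iterable dataset has no length from which to derive the number of training steps. A standalone sketch of that validation logic, assuming a simplified arguments class (names mirror the diff, the class itself is illustrative):

```python
from dataclasses import dataclass

@dataclass
class StreamingArgs:
    streaming: bool = True
    train_dataset_mix_ratio: float = 0.5
    max_steps: int = -1

    def handle_streaming(self) -> None:
        if not self.streaming:
            return
        if self.train_dataset_mix_ratio > 0:
            # Mixing needs len(train_dataset), which streams lack.
            self.train_dataset_mix_ratio = 0
        if self.max_steps == -1:
            # Epoch-based scheduling also needs a dataset length.
            raise ValueError('Please specify `max_steps` in streaming mode.')

args = StreamingArgs()
try:
    args.handle_streaming()
except ValueError as e:
    print(e)  # raised until max_steps is set explicitly
```
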
4 changes: 2 additions & 2 deletions swift/llm/utils/dataset.py
@@ -422,8 +422,8 @@ def _post_preprocess(
     if dataset_test_ratio == 1:
         train_dataset, val_dataset = None, train_dataset
     if dataset_sample > 0:
-        train_dataset = train_dataset.shuffle(
-            seed=get_seed(random_state), buffer_size=16384)  # TODO: set buffer_size
+        # train_dataset = train_dataset.shuffle(
+        #     seed=get_seed(random_state), buffer_size=16384)  # TODO: set buffer_size
         train_dataset = train_dataset.take(dataset_sample)

     res = []
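
The dataset.py change (like the two in argument.py) comments out buffered shuffling of streamed data and keeps only `take()`. For reference, this is how both operations behave on a Hugging Face IterableDataset; the buffer-based shuffle shown is what the commented-out call was doing (toy data, illustrative buffer size):

```python
from datasets import Dataset

# Simulate a streamed source from a small map-style dataset.
ds = Dataset.from_dict({'idx': list(range(100))}).to_iterable_dataset()

# Buffered shuffle: fills a reservoir of buffer_size samples and
# draws from it at random as the stream is consumed.
shuffled = ds.shuffle(seed=42, buffer_size=16)

# take(n) lazily yields only the first n samples of the stream.
for example in shuffled.take(5):
    print(example['idx'])
```
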
7 changes: 5 additions & 2 deletions swift/llm/utils/template.py
@@ -439,13 +439,16 @@ def preprocess(self, example):
         self._preprocess_media(example)
         return example

-    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    def encode(self, example: Dict[str, Any], streaming: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         example = self.preprocess(example)
         _encode = self._encode
         if self._is_lmdeploy:
             assert self.is_multimodal is not None, 'Please use the get_model_tokenizer function.'
             _encode = MethodType(Template._encode, self)
-        return _encode(example)
+        res = _encode(example)
+        if streaming:
+            res = res[0]
+        return res

     async def prepare_lmdeploy_inputs(self, inputs: Dict[str, Any]) -> None:
         images = inputs.pop('images', None) or []
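
The `res[0]` branch exists because `Template.encode` normally returns an `(inputs, tokenizer_kwargs)` tuple, while a function mapped over a streamed dataset must return a plain feature dict, which is presumably why the streaming path keeps only the first element. A rough illustration, assuming Hugging Face `datasets` (the `encode` stand-in is illustrative):

```python
from datasets import Dataset

def encode(example, streaming=False):
    # Stand-in for Template.encode.
    res = ({'input_ids': [len(example['query'])]}, {})
    return res[0] if streaming else res

stream = Dataset.from_dict({'query': ['hello', 'world']}).to_iterable_dataset()
# IterableDataset.map requires the callback to return a dict of
# features, so the tuple form would fail here.
encoded = stream.map(lambda ex: encode(ex, streaming=True))
print(next(iter(encoded)))  # {'query': 'hello', 'input_ids': [5]}
```
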
5 changes: 4 additions & 1 deletion swift/llm/utils/utils.py
@@ -5,6 +5,7 @@
 import os
 import shutil
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial, wraps
 from queue import Empty, Queue
@@ -81,7 +82,9 @@ def download_dataset(model_id: str, files: List[str], force_download: bool = Fal

 @wraps(_old_msdataset_load)
 def _msdataset_ddp_load(*args, **kwargs):
-    with safe_ddp_context():
+    streaming = kwargs.get('streaming', False)
+    context = nullcontext() if streaming else safe_ddp_context()
+    with context:
         dataset = _old_msdataset_load(*args, **kwargs)
     return dataset

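The utils.py change skips the DDP-safe lock when `streaming=True`, presumably because a streamed load returns a lazy iterable rather than writing the shared cache the lock protects. The conditional-context-manager pattern in isolation, with a dummy lock standing in for `safe_ddp_context`:

```python
from contextlib import contextmanager, nullcontext

@contextmanager
def fake_ddp_lock():
    # Stand-in for safe_ddp_context(): serialize access across ranks.
    print('lock acquired')
    try:
        yield
    finally:
        print('lock released')

def load(streaming: bool = False):
    # nullcontext() is a no-op context manager, so streaming loads
    # proceed without taking the lock.
    context = nullcontext() if streaming else fake_ddp_lock()
    with context:
        return 'dataset'

load(streaming=False)  # prints the lock messages
load(streaming=True)   # prints nothing
```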
