diff --git a/swift/llm/rlhf.py b/swift/llm/rlhf.py
index a2a0c03d7..f6ad596f0 100644
--- a/swift/llm/rlhf.py
+++ b/swift/llm/rlhf.py
@@ -24,6 +24,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
     logger.info(f'args: {args}')
     seed_everything(args.seed)
     training_args = args.training_args
+    streaming = args.streaming
     if is_torch_npu_available():
         print(f'device_count: {torch.npu.device_count()}')
     else:
@@ -183,12 +184,14 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:

     template: Template = get_template(
         args.template_type, tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
-    if not template.support_multi_round and 'history' in train_dataset[0]:
-        logger.info(
-            'The current template does not support multi-turn dialogue. The chatml template is used by default. \
-You can also use the --model_type parameter to specify the template.')
-        template: Template = get_template(
-            'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
+    if not template.support_multi_round:
+        first_data = train_dataset[0] if not streaming else next(iter(train_dataset))
+        if 'history' in first_data:
+            logger.info(
+                'The current template does not support multi-turn dialogue. The chatml template is used by default. \
+You can also use the --model_type parameter to specify the template.')
+            template: Template = get_template(
+                'chatml', tokenizer, args.system, args.max_length, args.truncation_strategy, model=model)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 0d4dd3a4d..38b7769ce 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -292,6 +292,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         args.truncation_strategy,
         model=model,
         **template_kwargs)
+    if streaming:
+        template.encode = partial(template.encode, streaming=streaming)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
     logger.info(f'args.lazy_tokenize: {args.lazy_tokenize}')
@@ -326,7 +328,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
             raise AttributeError('Failed to access dataset attributes,train_dataset is None. This might be because:\n'
                                  '(1) The dataset contains None for input or labels;\n'
                                  "(2) The 'max_length' setting is too short causing data truncation.")
-        td0, tkwargs0 = train_dataset.data[0] if not streaming else next(iter(train_dataset.data))
+        td0, tkwargs0 = train_dataset.data[0] if not streaming else (next(iter(train_dataset)), {})
         print_example(td0, tokenizer, tkwargs0)
         dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
     if val_dataset is not None:
@@ -400,7 +402,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
     last_model_checkpoint = getattr(trainer.state, 'last_model_checkpoint', None)
     logger.info(f'last_model_checkpoint: {last_model_checkpoint}')
     logger.info(f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
-    train_time = get_time_info(trainer.state.log_history, len(train_dataset))
+    if not streaming:
+        train_time = get_time_info(trainer.state.log_history, len(train_dataset))
     # Visualization
     if is_master() and not use_torchacc():
         if 'tensorboard' in args.training_args.report_to:
@@ -412,7 +415,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         trainer.push_to_hub()
     run_info = {
         'memory': trainer.perf['memory'],
-        'train_time': train_time,
         'last_model_checkpoint': last_model_checkpoint,
         'best_model_checkpoint': trainer.state.best_model_checkpoint,
         'best_metric': trainer.state.best_metric,
@@ -421,6 +423,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         'model_info': model_info,
         'dataset_info': dataset_info,
     }
+    if not streaming:
+        run_info.update({'train_time': train_time})
     for key in ['gen_time', 'gen_len']:
         if trainer.perf[key] != 0:
             run_info[key] = trainer.perf[key]
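The rlhf.py and sft.py hunks above share one pattern: a streaming dataset is an iterable, so it supports neither ds[0] nor len(ds), and its first sample has to be peeked with next(iter(ds)); meanwhile functools.partial bakes streaming=True into template.encode once so downstream callers stay unchanged. A minimal sketch of both ideas, assuming a Hugging Face IterableDataset (iter() restarts from the source, so peeking does not consume the stream); toy_encode is illustrative and not part of the patch:

    from functools import partial

    from datasets import IterableDataset

    def first_sample(ds):
        # ds[0] raises a TypeError on an IterableDataset; a fresh iterator
        # peeks at the first sample without advancing any shared cursor.
        return next(iter(ds)) if isinstance(ds, IterableDataset) else ds[0]

    def toy_encode(example, streaming=False):
        inputs, kwargs = {'input_ids': [1, 2, 3]}, {}
        # Streaming pipelines keep only the encoded inputs, mirroring the
        # template.py change further down in this patch.
        return inputs if streaming else (inputs, kwargs)

    # Bind the flag once; every later call behaves as if streaming=True.
    toy_encode = partial(toy_encode, streaming=True)
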
diff --git a/swift/llm/utils/__init__.py b/swift/llm/utils/__init__.py
index 887f7a1b8..a70425175 100644
--- a/swift/llm/utils/__init__.py
+++ b/swift/llm/utils/__init__.py
@@ -43,9 +43,9 @@ try:
     if is_lmdeploy_available():
         from .lmdeploy_utils import (
+            prepare_lmdeploy_engine_template,
             lmdeploy_context,
             LmdeployGenerationConfig,
-            LmdeployGenerationConfig,
             get_lmdeploy_engine,
             inference_stream_lmdeploy,
             inference_lmdeploy,
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 85f3e9398..74e38bcb8 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -389,7 +389,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
                 train_idxs = random_state.permutation(train_dataset_sample)
                 train_dataset = train_dataset.select(train_idxs)
             else:
-                train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
+                # train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
                 train_dataset = train_dataset.take(train_dataset_sample)

         if val_dataset_sample is None:
@@ -400,33 +400,30 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
                 val_idxs = random_state.permutation(val_dataset_sample)
                 val_dataset = val_dataset.select(val_idxs)
             elif streaming:
-                val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
+                # val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
                 val_dataset = val_dataset.take(val_dataset_sample)

         if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
                 or len(self.train_dataset_mix_ds) == 0):
             return train_dataset, val_dataset

-        if streaming:
-            logger.warning('`train_dataset_mix_ds` is not supported in streaming mode.')
+        mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
+        logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
+        logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
+        mixed_dataset = get_dataset(
+            self.train_dataset_mix_ds,
+            0.0,
+            random_state,
+            check_dataset_strategy=self.check_dataset_strategy,
+            streaming=streaming)[0]
+        if len(mixed_dataset) < mix_dataset_sample:
+            logger.warn(f'The length of the mix-in dataset: {self.train_dataset_mix_ds} is '
+                        'less than required by the `train_dataset_mix_ratio` '
+                        f'argument: {self.train_dataset_mix_ratio}. '
+                        f'The actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
         else:
-            mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
-            logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
-            logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
-            mixed_dataset = get_dataset(
-                self.train_dataset_mix_ds,
-                0.0,
-                random_state,
-                check_dataset_strategy=self.check_dataset_strategy,
-                streaming=streaming)[0]
-            if len(mixed_dataset) < mix_dataset_sample:
-                logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
-                            'lesser than the ratio required by the `train_dataset_mix_ratio` '
-                            f'argument: {self.train_dataset_mix_ratio}. '
-                            f'the actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
-            else:
-                mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
-            train_dataset = concatenate_datasets([train_dataset, mixed_dataset])
+            mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
+        train_dataset = concatenate_datasets([train_dataset, mixed_dataset])

         return train_dataset, val_dataset

     def prepare_template(self: Union['SftArguments', 'InferArguments']):
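For context on the commented-out shuffle calls: an IterableDataset can only shuffle approximately, by sampling from a fixed-size buffer, and take(n) truncates the stream without knowing its total length. A self-contained sketch, assuming a datasets release >= 2.11 for to_iterable_dataset; the synthetic stream is illustrative only:

    from datasets import Dataset

    # Synthetic stand-in for a real streaming load.
    stream = Dataset.from_dict({'x': list(range(100))}).to_iterable_dataset()

    # Buffer-based shuffle: holds buffer_size items in memory and samples
    # from that reservoir, so the result is only approximately random.
    shuffled = stream.shuffle(seed=42, buffer_size=16)

    # take(n) yields at most n items; the stream may end earlier, which is
    # why a streaming sample can come up short of the requested size.
    print([row['x'] for row in shuffled.take(10)])
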
@@ -570,6 +567,10 @@ def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> Non
             self.lazy_tokenize = False
             logger.info('lazy_tokenize set to False in streaming dataset')

+        if hasattr(self, 'train_dataset_mix_ratio') and self.train_dataset_mix_ratio > 0:
+            logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
+            self.train_dataset_mix_ratio = 0
+
         if self.dataset_test_ratio > 0:
             logger.warning('Since the length of streaming data cannot be estimated,'
                            'set dataset_test_ratio to 0. You can manually set val_dataset_sample.')
@@ -578,6 +579,9 @@ def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> Non
         if self.train_dataset_sample > 0 or self.val_dataset_sample:
             logger.warning('The final data size in streaming data may be smaller than train_dataset_sample')

+        if self.max_steps == -1:
+            raise ValueError('Please specify `max_steps` in streaming mode.')
+

 @dataclass
 class SftArguments(ArgumentsBase):
@@ -1065,7 +1069,9 @@ def __post_init__(self) -> None:
             self.lazy_tokenize = template_info.get('lazy_tokenize', False)
             logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
         if self.dataloader_num_workers is None:
-            if 'dataloader_num_workers' in template_info:
+            if self.streaming:
+                self.dataloader_num_workers = 0
+            elif 'dataloader_num_workers' in template_info:
                 self.dataloader_num_workers = template_info['dataloader_num_workers']
             elif platform.system() == 'Windows':
                 self.dataloader_num_workers = 0
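The new max_steps check and the dataloader_num_workers = 0 default follow from the same constraint: an iterable dataset has no len(), so the Hugging Face Trainer cannot turn num_train_epochs into a step count, and an unsharded stream risks being read once per dataloader worker. A sketch of training arguments consistent with these defaults (values are placeholders, not from the patch):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir='output',
        max_steps=1000,                 # required: fixes the schedule up front
        per_device_train_batch_size=8,
        dataloader_num_workers=0,       # avoid re-reading an unsharded stream
    )
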
diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py
index 1ed6eba1c..cceefba25 100644
--- a/swift/llm/utils/dataset.py
+++ b/swift/llm/utils/dataset.py
@@ -422,8 +422,8 @@ def _post_preprocess(
         if dataset_test_ratio == 1:
             train_dataset, val_dataset = None, train_dataset
         if dataset_sample > 0:
-            train_dataset = train_dataset.shuffle(
-                seed=get_seed(random_state), buffer_size=16384)  # TODO: set buffer_size
+            # train_dataset = train_dataset.shuffle(
+            #     seed=get_seed(random_state), buffer_size=16384)  # TODO: set buffer_size
             train_dataset = train_dataset.take(dataset_sample)

     res = []
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index 115e00716..1f9d47a1f 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -439,13 +439,16 @@ def preprocess(self, example):
             self._preprocess_media(example)
         return example

-    def encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    def encode(self, example: Dict[str, Any], streaming: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         example = self.preprocess(example)
         _encode = self._encode
         if self._is_lmdeploy:
             assert self.is_multimodal is not None, 'Please use the get_model_tokenizer function.'
             _encode = MethodType(Template._encode, self)
-        return _encode(example)
+        res = _encode(example)
+        if streaming:
+            res = res[0]
+        return res

     async def prepare_lmdeploy_inputs(self, inputs: Dict[str, Any]) -> None:
         images = inputs.pop('images', None) or []
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index b4cc7d1b3..6b4ee7fe9 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -5,6 +5,7 @@ import os
 import shutil
 import time
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial, wraps
 from queue import Empty, Queue
@@ -81,7 +82,9 @@ def download_dataset(model_id: str, files: List[str], force_download: bool = Fal

 @wraps(_old_msdataset_load)
 def _msdataset_ddp_load(*args, **kwargs):
-    with safe_ddp_context():
+    streaming = kwargs.get('streaming', False)
+    context = nullcontext() if streaming else safe_ddp_context()
+    with context:
         dataset = _old_msdataset_load(*args, **kwargs)
     return dataset
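The final hunk uses a small generic trick: contextlib.nullcontext lets one with statement conditionally skip a lock, since a streaming load returns immediately and gains nothing from being serialized. A self-contained sketch in which a threading.Lock stands in for safe_ddp_context, which this patch leaves unchanged:

    import threading
    from contextlib import nullcontext

    _load_lock = threading.Lock()  # illustrative stand-in for safe_ddp_context()

    def ddp_safe_load(load_fn, *args, streaming=False, **kwargs):
        # A streaming load returns a lazy iterable and downloads nothing up
        # front, so taking the lock would only add contention.
        context = nullcontext() if streaming else _load_lock
        with context:
            return load_fn(*args, streaming=streaming, **kwargs)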