Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hjh0119 committed Aug 5, 2024
1 parent 274b4d3 commit 6ba584e
Show file tree
Hide file tree
Showing 6 changed files with 289 additions and 166 deletions.
2 changes: 2 additions & 0 deletions swift/llm/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
trainer_kwargs['is_vision'] = args.is_vision
model.config.model_type += '_' # add suffix to avoid checks in hfDPOTrainer

trainer_kwargs['streaming'] = streaming

trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
Expand Down
73 changes: 38 additions & 35 deletions swift/llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
train_idxs = random_state.permutation(train_dataset_sample)
train_dataset = train_dataset.select(train_idxs)
else:
# train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
train_dataset = train_dataset.take(train_dataset_sample)

if val_dataset_sample is None:
Expand All @@ -400,7 +400,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_
val_idxs = random_state.permutation(val_dataset_sample)
val_dataset = val_dataset.select(val_idxs)
elif streaming:
# val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
val_dataset = val_dataset.take(val_dataset_sample)

if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
Expand Down Expand Up @@ -552,36 +552,6 @@ def load_from_ckpt_dir(self, is_sft: bool = False) -> None:
if self.val_dataset is None:
self.val_dataset = []

def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> None:
if not self.streaming:
return
if hasattr(self, 'packing') and self.packing:
self.packing = False
logger.warning('Packing is not supported for streaming dataset, set to False')

if hasattr(self, 'test_oom_error') and self.test_oom_error:
self.test_oom_error = False
logger.warning('test_oom_error is not supported for streaming dataset, set to False')

if hasattr(self, 'lazy_tokenize') and self.lazy_tokenize:
self.lazy_tokenize = False
logger.info('lazy_tokenize set to False in streaming dataset')

if hasattr(self, 'train_dataset_mix_ratio') and self.train_dataset_mix_ratio > 0:
logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
self.train_dataset_mix_ratio = 0

if self.dataset_test_ratio > 0:
logger.warning('Since the length of streaming data cannot be estimated,'
'set dataset_test_ratio to 0. You can manually set val_dataset_sample.')
self.dataset_test_ratio = 0

if self.train_dataset_sample > 0 or self.val_dataset_sample:
logger.warning('The final data size in streaming data may be smaller than train_dataset_sample')

if self.max_steps == -1:
raise ValueError('Please specify `max_steps` in streaming mode.')


@dataclass
class SftArguments(ArgumentsBase):
Expand Down Expand Up @@ -1068,10 +1038,9 @@ def __post_init__(self) -> None:
if self.lazy_tokenize is None:
self.lazy_tokenize = template_info.get('lazy_tokenize', False)
logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}')
self._handle_streaming_args()
if self.dataloader_num_workers is None:
if self.streaming:
self.dataloader_num_workers = 0
elif 'dataloader_num_workers' in template_info:
if 'dataloader_num_workers' in template_info:
self.dataloader_num_workers = template_info['dataloader_num_workers']
elif platform.system() == 'Windows':
self.dataloader_num_workers = 0
Expand Down Expand Up @@ -1224,6 +1193,40 @@ def _handle_pai_compat(self) -> None:
self.add_output_dir_suffix = False
logger.info(f'Setting args.add_output_dir_suffix: {self.add_output_dir_suffix}')

def _handle_streaming_args(self) -> None:
if not self.streaming:
return
if self.packing:
self.packing = False
logger.warning('Packing is not supported for streaming dataset, set to False')

if self.test_oom_error:
self.test_oom_error = False
logger.warning('test_oom_error is not supported for streaming dataset, set to False')

if self.lazy_tokenize:
self.lazy_tokenize = False
logger.info('lazy_tokenize set to False in streaming dataset')

if self.train_dataset_mix_ratio > 0:
logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0')
self.train_dataset_mix_ratio = 0

if self.dataset_test_ratio > 0:
logger.warning('Since the length of streaming data cannot be estimated,'
'set dataset_test_ratio to 0. You can manually set val_dataset_sample.')
self.dataset_test_ratio = 0

if self.train_dataset_sample > 0 or self.val_dataset_sample:
logger.warning('The final data size in streaming data may be smaller than train_dataset_sample')

if self.max_steps == -1:
raise ValueError('Please specify `max_steps` in streaming mode.')

if self.dataloader_num_workers is None or self.dataloader_num_workers > 0:
logger.info('dataloader_num_workers is not supported in streaming mode, set to 0')
self.dataloader_num_workers = 0


@dataclass
class InferArguments(ArgumentsBase):
Expand Down
Loading

0 comments on commit 6ba584e

Please sign in to comment.