From 5c84682f16402bfa184a14b821cb324eab4e756f Mon Sep 17 00:00:00 2001 From: beep-bebop <41529995+beep-bebop@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:22:07 +0800 Subject: [PATCH] Customise the separator used for splicing in DataCollatorWithFlattening (#33114) * Customising the separator used for splicing in DataCollatorWithFlattening * update DataCollatorWithFlattening docs --------- Co-authored-by: weifangyuan --- src/transformers/data/data_collator.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index fdf8d1e7a96f19..7f982c49cf13ea 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -1619,20 +1619,24 @@ class DataCollatorWithFlattening(DefaultDataCollator): Data collator used for padding free approach. Does the following: - concatate the entire mini batch into single long sequence [1, total_tokens] + - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100 - no padding will be added, returns `input_ids`, `labels` and `position_ids` """ - def __init__(self, *args, return_position_ids=True, **kwargs): + def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs): super().__init__(*args, **kwargs) self.return_position_ids = return_position_ids + self.separator_id = separator_id warnings.warn( "Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence." "Make sure your attention computation is able to handle it!" ) - def __call__(self, features, return_tensors=None): + def __call__(self, features, return_tensors=None, separator_id=None): if return_tensors is None: return_tensors = self.return_tensors + if separator_id is None: + separator_id = self.separator_id is_labels_provided = "labels" in features[0] ret = {"input_ids": [], "labels": []} if self.return_position_ids: @@ -1640,9 +1644,9 @@ def __call__(self, features, return_tensors=None): for idx in range(0, len(features)): ret["input_ids"] += features[idx]["input_ids"] if is_labels_provided: - ret["labels"] += [-100] + features[idx]["labels"][1:] + ret["labels"] += [separator_id] + features[idx]["labels"][1:] else: - ret["labels"] += [-100] + features[idx]["input_ids"][1:] + ret["labels"] += [separator_id] + features[idx]["input_ids"][1:] if self.return_position_ids: ret["position_ids"] += list(range(len(features[idx]["input_ids"]))) return default_data_collator([ret], return_tensors)