Skip to content

Commit

Permalink
Customise the separator used for splicing in DataCollatorWithFlatteni…
Browse files Browse the repository at this point in the history
…ng (#33114)

* Customising the separator used for splicing in DataCollatorWithFlattening

* update DataCollatorWithFlattening docs

---------

Co-authored-by: weifangyuan <[email protected]>
  • Loading branch information
beep-bebop and weifangyuan authored Aug 28, 2024
1 parent f4c86d0 commit 5c84682
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1619,30 +1619,34 @@ class DataCollatorWithFlattening(DefaultDataCollator):
Data collator used for padding free approach. Does the following:
- concatate the entire mini batch into single long sequence [1, total_tokens]
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
"""

def __init__(self, *args, return_position_ids=True, **kwargs):
def __init__(self, *args, return_position_ids=True, separator_id=-100, **kwargs):
super().__init__(*args, **kwargs)
self.return_position_ids = return_position_ids
self.separator_id = separator_id
warnings.warn(
"Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence."
"Make sure your attention computation is able to handle it!"
)

def __call__(self, features, return_tensors=None):
def __call__(self, features, return_tensors=None, separator_id=None):
if return_tensors is None:
return_tensors = self.return_tensors
if separator_id is None:
separator_id = self.separator_id
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
if self.return_position_ids:
ret.update({"position_ids": []})
for idx in range(0, len(features)):
ret["input_ids"] += features[idx]["input_ids"]
if is_labels_provided:
ret["labels"] += [-100] + features[idx]["labels"][1:]
ret["labels"] += [separator_id] + features[idx]["labels"][1:]
else:
ret["labels"] += [-100] + features[idx]["input_ids"][1:]
ret["labels"] += [separator_id] + features[idx]["input_ids"][1:]
if self.return_position_ids:
ret["position_ids"] += list(range(len(features[idx]["input_ids"])))
return default_data_collator([ret], return_tensors)

0 comments on commit 5c84682

Please sign in to comment.