diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 179c5f508fe877..aec91c75559828 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -585,51 +585,84 @@ class DataCollatorForSeq2Seq: def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors - labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None - # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the - # same length to return tensors. - no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD - if labels is not None and not no_padding: - max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None - max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length - if self.pad_to_multiple_of is not None: - max_label_length = ( - (max_label_length + self.pad_to_multiple_of - 1) - // self.pad_to_multiple_of - * self.pad_to_multiple_of - ) - - padding_side = self.tokenizer.padding_side - for feature in features: - remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) - if isinstance(feature["labels"], list): - feature["labels"] = ( - feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] - ) - elif padding_side == "right": - feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) - else: - feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) - features = pad_without_fast_tokenizer_warning( + label_name = "label" if "label" in features[0].keys() else "labels" + labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None + # reconvert list[None] to None if necessary + # this might occur when we pass {..., "labels": None} + if labels is not None and all(label is None for label in labels): + labels = None + non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features] + + # run through tokenizer without labels to ensure no side effects + batch = pad_without_fast_tokenizer_warning( self.tokenizer, - features, + non_labels_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=return_tensors, ) + # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors + no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD + if labels is not None: + if no_padding: + if isinstance(features[0][label_name], list): + batch["labels"] = list(labels) + else: + batch["labels"] = [np.concatenate([label, []]) for label in labels] + else: + max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None + max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length + if self.pad_to_multiple_of is not None: + max_label_length = ( + (max_label_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + if isinstance(features[0][label_name], list): + batch["labels"] = [ + label + [self.label_pad_token_id] * (max_label_length - len(label)) + if padding_side == "right" + else [self.label_pad_token_id] * (max_label_length - len(label)) + label + for label in labels + ] + else: + batch["labels"] = [ + np.concatenate([label, [self.label_pad_token_id] * (max_label_length - len(label))]) + if padding_side == "right" + else np.concatenate([[self.label_pad_token_id] * (max_label_length - len(label)), label]) + for label in labels + ] + + # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument + if batch.get("labels", None) is not None: + if return_tensors == "pt": + import torch + + batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64) + elif return_tensors == "tf": + import tensorflow as tf + + batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64) + else: + batch["labels"] = np.array(batch["labels"], dtype=np.int64) + else: + batch["labels"] = None + # prepare decoder_input_ids if ( labels is not None and self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels") ): - decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) - features["decoder_input_ids"] = decoder_input_ids + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"]) + batch["decoder_input_ids"] = decoder_input_ids - return features + return batch @dataclass diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index 8a121a9f83d073..36e1813258d1a3 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -439,6 +439,330 @@ def test_sop(self): self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,))) +@require_torch +class DataCollatorImmutabilityTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt") + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def _turn_to_none(self, item): + """used to convert `item` to `None` type""" + return None + + def _validate_original_data_against_collated_data(self, collator, original_data, batch_data): + # we only care about side effects, the results are tested elsewhere + collator(batch_data) + + # we go through every item and convert to `primitive` datatypes if necessary + # then compares for equivalence for the original data and the data that has been passed through the collator + for original, batch in zip(original_data, batch_data): + for original_val, batch_val in zip(original.values(), batch.values()): + if isinstance(original_val, (np.ndarray, torch.Tensor)): + self.assertEqual(original_val.tolist(), batch_val.tolist()) + else: + self.assertEqual(original_val, batch_val) + + def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False + ): + # using the arguments to recreate the features with their respective (potentially new) datatypes + features_original = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + features_batch = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + + # some collators do not use labels, or sometimes we want to check if the collator with labels can handle such cases + if ignore_label: + for original, batch in zip(features_original, features_batch): + original.pop(label_key) + batch.pop(label_key) + + self._validate_original_data_against_collated_data( + collator=collator, original_data=features_original, batch_data=features_batch + ) + + def test_default_collator_immutability(self): + features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + + for datatype_input, datatype_label in [ + (list, int), + (list, float), + (np.array, int), + (np.array, torch.tensor), + (list, self._turn_to_none), + ]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=default_data_collator, + base_data=features_base_single_label, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=default_data_collator, + base_data=features_base_multiple_labels, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)] + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=default_data_collator, + base_data=features_base_single_label_alt, + input_key="input_ids", + input_datatype=list, + label_key="label", + label_datatype=float, + ) + + def test_with_padding_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + + data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10) + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + def test_for_token_classification_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": (0, 1, 2), "labels": (0, 1, 2)}, + {"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)}, + ] + token_classification_collators = [ + DataCollatorForTokenClassification(tokenizer), + DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10), + DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8), + DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1), + ] + + for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]: + for collator in token_classification_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=token_classification_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_seq2seq_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(6)), "labels": list(range(6))}, + ] + seq2seq_collators = [ + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST), + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7), + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8), + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1), + ] + + for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]: + for collator in seq2seq_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + features_base_no_pad = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(3)), "labels": list(range(3))}, + ] + seq2seq_no_padding_collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.DO_NOT_PAD) + for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_no_padding_collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + def test_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base_no_pad = [ + {"input_ids": tuple(range(10)), "labels": (1,)}, + {"input_ids": tuple(range(10)), "labels": (1,)}, + ] + features_base_pad = [ + {"input_ids": tuple(range(5)), "labels": (1,)}, + {"input_ids": tuple(range(5)), "labels": (1,)}, + ] + lm_collators = [ + DataCollatorForLanguageModeling(tokenizer, mlm=False), + DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8), + DataCollatorForLanguageModeling(tokenizer), + DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8), + ] + + for datatype_input, datatype_label in [(list, list), (torch.tensor, torch.tensor)]: + for collator in lm_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_whole_world_masking_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(10)), "labels": (1,)}, + {"input_ids": list(range(10)), "labels": (1,)}, + ] + whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt") + + for datatype_input, datatype_label in [(list, list), (np.array, np.array)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=whole_word_masking_collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_permutation_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer) + + no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch + ) + + pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch + ) + + def test_next_sentence_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + features_batch = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + + nsp_collator = DataCollatorForLanguageModeling(tokenizer) + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + def test_sentence_order_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + { + "input_ids": torch.tensor([0, 1, 2, 3, 4]), + "token_type_ids": torch.tensor([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + features_batch = [ + { + "input_ids": torch.tensor([0, 1, 2, 3, 4]), + "token_type_ids": torch.tensor([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + + sop_collator = DataCollatorForLanguageModeling(tokenizer) + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + ) + + sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8) + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + ) + + @require_tf class TFDataCollatorIntegrationTest(unittest.TestCase): def setUp(self): @@ -794,6 +1118,338 @@ def test_sop(self): self.assertEqual(batch["sentence_order_label"].shape.as_list(), [2]) +@require_tf +class TFDataCollatorImmutabilityTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt") + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def _turn_to_none(self, item): + """used to convert `item` to `None` type""" + return None + + def _validate_original_data_against_collated_data(self, collator, original_data, batch_data): + # we only care about side effects, the results are tested elsewhere + collator(batch_data) + + # we go through every item and convert to `primitive` datatypes if necessary + # then compares for equivalence for the original data and the data that has been passed through the collator + for original, batch in zip(original_data, batch_data): + for original_val, batch_val in zip(original.values(), batch.values()): + if isinstance(original_val, np.ndarray): + self.assertEqual(original_val.tolist(), batch_val.tolist()) + elif isinstance(original_val, tf.Tensor): + self.assertEqual(original_val.numpy().tolist(), batch_val.numpy().tolist()) + else: + self.assertEqual(original_val, batch_val) + + def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False + ): + # using the arguments to recreate the features with their respective (potentially new) datatypes + features_original = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + features_batch = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + + # some collators do not use labels, or sometimes we want to check if the collator with labels can handle such cases + if ignore_label: + for original, batch in zip(features_original, features_batch): + original.pop(label_key) + batch.pop(label_key) + + self._validate_original_data_against_collated_data( + collator=collator, original_data=features_original, batch_data=features_batch + ) + + def test_default_collator_immutability(self): + features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + + for datatype_input, datatype_label in [ + (list, int), + (list, float), + (np.array, int), + (np.array, tf.constant), + (list, self._turn_to_none), + ]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="tf"), + base_data=features_base_single_label, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="tf"), + base_data=features_base_multiple_labels, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)] + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="tf"), + base_data=features_base_single_label_alt, + input_key="input_ids", + input_datatype=list, + label_key="label", + label_datatype=float, + ) + + def test_with_padding_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + + data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + def test_for_token_classification_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": (0, 1, 2), "labels": (0, 1, 2)}, + {"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)}, + ] + token_classification_collators = [ + DataCollatorForTokenClassification(tokenizer, return_tensors="tf"), + DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10, return_tensors="tf"), + DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="tf"), + DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="tf"), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in token_classification_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=token_classification_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_seq2seq_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(6)), "labels": list(range(6))}, + ] + seq2seq_collators = [ + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, return_tensors="tf"), + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7, return_tensors="tf"), + DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8, return_tensors="tf" + ), + DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1, return_tensors="tf" + ), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in seq2seq_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + features_base_no_pad = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(3)), "labels": list(range(3))}, + ] + seq2seq_no_padding_collator = DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.DO_NOT_PAD, return_tensors="tf" + ) + for datatype_input, datatype_label in [(list, list)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_no_padding_collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + def test_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base_no_pad = [ + {"input_ids": tuple(range(10)), "labels": (1,)}, + {"input_ids": tuple(range(10)), "labels": (1,)}, + ] + features_base_pad = [ + {"input_ids": tuple(range(5)), "labels": (1,)}, + {"input_ids": tuple(range(5)), "labels": (1,)}, + ] + lm_collators = [ + DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="tf"), + DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="tf"), + DataCollatorForLanguageModeling(tokenizer, return_tensors="tf"), + DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf"), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in lm_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_whole_world_masking_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(10)), "labels": (1,)}, + {"input_ids": list(range(10)), "labels": (1,)}, + ] + whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf") + + for datatype_input, datatype_label in [(list, list), (np.array, np.array)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=whole_word_masking_collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_permutation_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="tf") + + no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch + ) + + pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch + ) + + def test_next_sentence_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + features_batch = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + + nsp_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + def test_sentence_order_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + { + "input_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]), + "token_type_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + features_batch = [ + { + "input_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]), + "token_type_ids": tf.convert_to_tensor([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + + sop_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + ) + + sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="tf") + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + ) + + class NumpyDataCollatorIntegrationTest(unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() @@ -1137,3 +1793,332 @@ def test_sop(self): self.assertEqual(batch["token_type_ids"].shape, (2, 8)) self.assertEqual(batch["labels"].shape, (2, 8)) self.assertEqual(batch["sentence_order_label"].shape, (2,)) + + +class NumpyDataCollatorImmutabilityTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] + self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt") + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def _turn_to_none(self, item): + """used to convert `item` to `None` type""" + return None + + def _validate_original_data_against_collated_data(self, collator, original_data, batch_data): + # we only care about side effects, the results are tested elsewhere + collator(batch_data) + + # we go through every item and convert to `primitive` datatypes if necessary + # then compares for equivalence for the original data and the data that has been passed through the collator + for original, batch in zip(original_data, batch_data): + for original_val, batch_val in zip(original.values(), batch.values()): + if isinstance(original_val, np.ndarray): + self.assertEqual(original_val.tolist(), batch_val.tolist()) + else: + self.assertEqual(original_val, batch_val) + + def _validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + self, collator, base_data, input_key, input_datatype, label_key, label_datatype, ignore_label=False + ): + # using the arguments to recreate the features with their respective (potentially new) datatypes + features_original = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + features_batch = [ + {label_key: label_datatype(sample[label_key]), input_key: input_datatype(sample[input_key])} + for sample in base_data + ] + + # some collators do not use labels, or sometimes we want to check if the collator with labels can handle such cases + if ignore_label: + for original, batch in zip(features_original, features_batch): + original.pop(label_key) + batch.pop(label_key) + + self._validate_original_data_against_collated_data( + collator=collator, original_data=features_original, batch_data=features_batch + ) + + def test_default_collator_immutability(self): + features_base_single_label = [{"label": i, "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + features_base_multiple_labels = [{"label": (0, 1, 2), "inputs": (0, 1, 2, 3, 4, 5)} for i in range(4)] + + for datatype_input, datatype_label in [ + (list, int), + (list, float), + (np.array, int), + (np.array, np.array), + (list, self._turn_to_none), + ]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="np"), + base_data=features_base_single_label, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + for datatype_input, datatype_label in [(list, list), (list, self._turn_to_none)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="np"), + base_data=features_base_multiple_labels, + input_key="inputs", + input_datatype=datatype_input, + label_key="label", + label_datatype=datatype_label, + ) + + features_base_single_label_alt = [{"input_ids": (0, 1, 2, 3, 4), "label": float(i)} for i in range(4)] + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=lambda x: default_data_collator(x, return_tensors="np"), + base_data=features_base_single_label_alt, + input_key="input_ids", + input_datatype=list, + label_key="label", + label_datatype=float, + ) + + def test_with_padding_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + features_batch = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}] + + data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=data_collator, original_data=features_original, batch_data=features_batch + ) + + def test_for_token_classification_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": (0, 1, 2), "labels": (0, 1, 2)}, + {"input_ids": (0, 1, 2, 3, 4, 5), "labels": (0, 1, 2, 3, 4, 5)}, + ] + token_classification_collators = [ + DataCollatorForTokenClassification(tokenizer, return_tensors="np"), + DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10, return_tensors="np"), + DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8, return_tensors="np"), + DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1, return_tensors="np"), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in token_classification_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=token_classification_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_seq2seq_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(6)), "labels": list(range(6))}, + ] + seq2seq_collators = [ + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, return_tensors="np"), + DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=7, return_tensors="np"), + DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8, return_tensors="np" + ), + DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1, return_tensors="np" + ), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in seq2seq_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_collators[-1], + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + features_base_no_pad = [ + {"input_ids": list(range(3)), "labels": list(range(3))}, + {"input_ids": list(range(3)), "labels": list(range(3))}, + ] + seq2seq_no_padding_collator = DataCollatorForSeq2Seq( + tokenizer, padding=PaddingStrategy.DO_NOT_PAD, return_tensors="np" + ) + for datatype_input, datatype_label in [(list, list)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=seq2seq_no_padding_collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ) + + def test_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base_no_pad = [ + {"input_ids": tuple(range(10)), "labels": (1,)}, + {"input_ids": tuple(range(10)), "labels": (1,)}, + ] + features_base_pad = [ + {"input_ids": tuple(range(5)), "labels": (1,)}, + {"input_ids": tuple(range(5)), "labels": (1,)}, + ] + lm_collators = [ + DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors="np"), + DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8, return_tensors="np"), + DataCollatorForLanguageModeling(tokenizer, return_tensors="np"), + DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np"), + ] + + for datatype_input, datatype_label in [(list, list)]: + for collator in lm_collators: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_no_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=collator, + base_data=features_base_pad, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_whole_world_masking_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_base = [ + {"input_ids": list(range(10)), "labels": (1,)}, + {"input_ids": list(range(10)), "labels": (1,)}, + ] + whole_word_masking_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np") + + for datatype_input, datatype_label in [(list, list), (np.array, np.array)]: + self._validate_original_data_against_collated_data_on_specified_keys_and_datatypes( + collator=whole_word_masking_collator, + base_data=features_base, + input_key="input_ids", + input_datatype=datatype_input, + label_key="labels", + label_datatype=datatype_label, + ignore_label=True, + ) + + def test_permutation_language_modelling_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + plm_collator = DataCollatorForPermutationLanguageModeling(tokenizer, return_tensors="np") + + no_pad_features_original = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + no_pad_features_batch = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=no_pad_features_original, batch_data=no_pad_features_batch + ) + + pad_features_original = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + pad_features_batch = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}] + self._validate_original_data_against_collated_data( + collator=plm_collator, original_data=pad_features_original, batch_data=pad_features_batch + ) + + def test_next_sentence_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + features_batch = [ + {"input_ids": [0, 1, 2, 3, 4], "token_type_ids": [0, 1, 2, 3, 4], "next_sentence_label": i} + for i in range(2) + ] + + nsp_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + nsp_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=nsp_collator, original_data=features_original, batch_data=features_batch + ) + + def test_sentence_order_prediction_collator_immutability(self): + tokenizer = BertTokenizer(self.vocab_file) + + features_original = [ + { + "input_ids": np.array([0, 1, 2, 3, 4]), + "token_type_ids": np.array([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + features_batch = [ + { + "input_ids": np.array([0, 1, 2, 3, 4]), + "token_type_ids": np.array([0, 1, 2, 3, 4]), + "sentence_order_label": i, + } + for i in range(2) + ] + + sop_collator = DataCollatorForLanguageModeling(tokenizer, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + ) + + sop_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors="np") + self._validate_original_data_against_collated_data( + collator=sop_collator, original_data=features_original, batch_data=features_batch + )