diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index e84c9d0ef3ce2e..01a5896a2907e6 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -691,11 +691,17 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             tokens and the value to predict for the masked token.
         mlm_probability (`float`, *optional*, defaults to 0.15):
             The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
+        mask_replace_prob (`float`, *optional*, defaults to 0.8):
+            The probability with which masked tokens are replaced by the tokenizer's mask token (e.g., `[MASK]`).
+            Defaults to 0.8, meaning 80% of the masked tokens will be replaced with `[MASK]`.
+            Only works when `mlm` is set to `True`.
+        random_replace_prob (`float`, *optional*, defaults to 0.1):
+            The probability with which masked tokens are replaced by random tokens from the tokenizer's vocabulary.
+            Defaults to 0.1, meaning 10% of the masked tokens will be replaced with random tokens. The remaining
+            masked tokens (1 - mask_replace_prob - random_replace_prob) are left unchanged.
+            Only works when `mlm` is set to `True`.
         pad_to_multiple_of (`int`, *optional*):
-            If set will pad the sequence to a multiple of the provided value.
-
-            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
-            7.0 (Volta).
+            If set, will pad the sequence to a multiple of the provided value.
         return_tensors (`str`):
             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
@@ -705,11 +711,36 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
     [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
-    """
+
+    <Example Options and Expectations>
+
+    1. Default behavior:
+        - `mask_replace_prob=0.8`, `random_replace_prob=0.1`.
+        - Expect 80% of masked tokens replaced with `[MASK]`, 10% replaced with random tokens, and 10% left unchanged.
+
+    2. All masked tokens replaced by `[MASK]`:
+        - `mask_replace_prob=1.0`, `random_replace_prob=0.0`.
+        - Expect all masked tokens to be replaced with `[MASK]`. No tokens are left unchanged or replaced with random tokens.
+
+    3. No `[MASK]` replacement, only random tokens:
+        - `mask_replace_prob=0.0`, `random_replace_prob=1.0`.
+        - Expect all masked tokens to be replaced with random tokens. No `[MASK]` replacements or unchanged tokens.
+
+    4. Balanced replacement:
+        - `mask_replace_prob=0.5`, `random_replace_prob=0.4`.
+        - Expect 50% of masked tokens replaced with `[MASK]`, 40% replaced with random tokens, and 10% left unchanged.
+
+    Note:
+        The sum of `mask_replace_prob` and `random_replace_prob` must not exceed 1. If their sum is less than 1, the
+        remaining proportion will consist of masked tokens left unchanged.
+
+
+    """

     tokenizer: PreTrainedTokenizerBase
     mlm: bool = True
     mlm_probability: float = 0.15
+    mask_replace_prob: float = 0.8
+    random_replace_prob: float = 0.1
     pad_to_multiple_of: Optional[int] = None
     tf_experimental_compile: bool = False
     return_tensors: str = "pt"
@@ -720,6 +751,15 @@ def __post_init__(self):
                 "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                 "You should pass `mlm=False` to train on causal language modeling instead."
             )
+        if self.mlm_probability < 0 or self.mlm_probability > 1:
+            raise ValueError("mlm_probability should be between 0 and 1.")
+        if self.mask_replace_prob + self.random_replace_prob > 1:
+            raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
+        if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
+            raise ValueError("mask_replace_prob should be between 0 and 1.")
+        if self.random_replace_prob < 0 or self.random_replace_prob > 1:
+            raise ValueError("random_replace_prob should be between 0 and 1.")
+
         if self.tf_experimental_compile:
             import tensorflow as tf
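The three outcomes described in the new docstring follow directly from the two parameters, subject to the `__post_init__` checks above. As a quick illustration, here is a standalone sketch (not part of the patch; the helper name is made up):

```python
def expected_replacement_split(mask_replace_prob: float = 0.8, random_replace_prob: float = 0.1) -> dict:
    """Hypothetical helper: expected fractions of *masked* tokens per outcome."""
    if not 0 <= mask_replace_prob <= 1 or not 0 <= random_replace_prob <= 1:
        raise ValueError("both probabilities should be between 0 and 1")
    if mask_replace_prob + random_replace_prob > 1:
        raise ValueError("the sum of the two probabilities should not exceed 1")
    return {
        "mask_token": mask_replace_prob,
        "random_token": random_replace_prob,
        "unchanged": 1 - mask_replace_prob - random_replace_prob,
    }

expected_replacement_split()          # {'mask_token': 0.8, 'random_token': 0.1, 'unchanged': ~0.1}
expected_replacement_split(0.5, 0.4)  # {'mask_token': 0.5, 'random_token': 0.4, 'unchanged': ~0.1}
```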
@@ -749,18 +789,28 @@ def tf_mask_tokens(
         # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
         labels = tf.where(masked_indices, inputs, -100)

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices
         inputs = tf.where(indices_replaced, mask_token_id, inputs)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scale random_replace_prob to the remaining probability; for example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+        # random_replace_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
+        )
         random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
         inputs = tf.where(indices_random, random_words, inputs)

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1 - mask_replace_prob - random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
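One subtlety shared by all three backends: the random-replacement draw is applied only to masked positions that were *not* already turned into `[MASK]`, so the draw probability must be rescaled by the leftover mass to hit `random_replace_prob` overall. A standalone NumPy check of that arithmetic (illustrative only, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
p_mask, p_rand = 0.8, 0.1
n = 1_000_000  # pretend n token positions were all selected for masking

replaced = rng.random(n) < p_mask                  # [MASK] replacements
scaled = p_rand / (1 - p_mask)                     # 0.1 / 0.2 = 0.5
randomized = (rng.random(n) < scaled) & ~replaced  # random-token replacements

print(replaced.mean())    # ~0.8 of masked positions become [MASK]
print(randomized.mean())  # ~0.1 overall: a 0.5 draw applied to the remaining 0.2
```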
@@ -849,16 +899,29 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         masked_indices = torch.bernoulli(probability_matrix).bool()
         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
         inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

-        # 10% of the time, we replace masked input tokens with random word
-        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scale random_replace_prob to the remaining probability; for example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
+
+        # random_replace_prob% of the time, we replace masked input tokens with random word
+        indices_random = (
+            torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
+            & masked_indices
+            & ~indices_replaced
+        )
         random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
         inputs[indices_random] = random_words[indices_random]

-        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
+        # The rest of the time ((1 - mask_replace_prob - random_replace_prob)% of the time) we keep the masked input tokens unchanged
         return inputs, labels

     def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
@@ -905,14 +968,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
         masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
         labels[~masked_indices] = -100  # We only compute loss on masked tokens

-        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
+        # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
+        indices_replaced = (
+            np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
+        )
         inputs[indices_replaced] = self.tokenizer.mask_token_id

-        # 10% of the time, we replace masked input tokens with random word
-        # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
+        if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
+            return inputs, labels
+
+        remaining_prob = 1 - self.mask_replace_prob
+        # scale random_replace_prob to the remaining probability; for example, if
+        # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
+        # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
+        random_replace_prob_scaled = self.random_replace_prob / remaining_prob
         indices_random = (
-            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
+            np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
+            & masked_indices
+            & ~indices_replaced
         )
         random_words = np.random.randint(
             low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
         )
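The early `return inputs, labels` shared by all three implementations is not just an optimization: with `mask_replace_prob == 1`, `remaining_prob` would be zero and the scaling division would fail. A minimal sketch of that guard (hypothetical helper name, not from the patch):

```python
import math

def scaled_random_prob(mask_replace_prob: float, random_replace_prob: float):
    """Mirror of the collator's scaling logic: None means 'skip the random pass'."""
    if mask_replace_prob == 1 or random_replace_prob == 0:
        return None  # early return: avoids dividing by remaining_prob == 0
    remaining_prob = 1 - mask_replace_prob
    return random_replace_prob / remaining_prob

assert scaled_random_prob(1.0, 0.0) is None             # no ZeroDivisionError
assert math.isclose(scaled_random_prob(0.8, 0.1), 0.5)  # the 0.1 / 0.2 example from the comments
assert math.isclose(scaled_random_prob(0.5, 0.4), 0.8)  # 0.4 / 0.5
```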
diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py
index 70870be7718bee..c3e9b5a3badf21 100644
--- a/tests/trainer/test_data_collator.py
+++ b/tests/trainer/test_data_collator.py
@@ -1020,6 +1020,52 @@ def _test_no_pad_and_pad(self, no_pad_features, pad_features):
         self.assertTrue(tf.reduce_any(masked_tokens))
         # self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))

+    def test_probability_sum_error(self):
+        """Test that the sum of mask_replace_prob and random_replace_prob exceeding 1 raises an error."""
+        tokenizer = BertTokenizer(self.vocab_file)
+        with self.assertRaises(ValueError):
+            DataCollatorForLanguageModeling(tokenizer=tokenizer, mask_replace_prob=0.9, random_replace_prob=0.2)
+
+    def test_all_mask_replacement(self):
+        """Test behavior when mask_replace_prob=1."""
+        tokenizer = BertTokenizer(self.vocab_file)
+
+        # pytorch call
+        collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="pt"
+        )
+
+        inputs = torch.tensor([0, 1, 2, 3, 4, 5])
+        features = [{"input_ids": inputs} for _ in range(8)]
+        batch = collator(features)
+
+        # confirm that every token is either the original token or [MASK]
+        self.assertTrue(torch.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
+
+        # tf call
+        collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="tf"
+        )
+        inputs = tf.constant([0, 1, 2, 3, 4, 5])
+        features = [{"input_ids": inputs} for _ in range(8)]
+        batch = collator(features)
+
+        # confirm that every token is either the original token or [MASK]
+        self.assertTrue(
+            tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))
+        )
+
+        # numpy call
+        collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="np"
+        )
+        inputs = np.array([0, 1, 2, 3, 4, 5])
+        features = [{"input_ids": inputs} for _ in range(8)]
+        batch = collator(features)
+
+        # confirm that every token is either the original token or [MASK]
+        self.assertTrue(np.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
+
     def test_data_collator_for_language_modeling(self):
         no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
         pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
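For reference, a minimal usage sketch of the new parameters (assumes a `transformers` build that includes this patch; the checkpoint name is only an example, and any tokenizer with a mask token works):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# BERT-style default: of the selected tokens, 80% -> [MASK], 10% -> random, 10% unchanged.
default_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# The configuration exercised by test_all_mask_replacement: every selected token becomes [MASK].
all_mask_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mask_replace_prob=1.0, random_replace_prob=0.0
)

features = [tokenizer("hello world"), tokenizer("masked language modeling")]
batch = all_mask_collator(features)
print(batch["input_ids"])  # original ids with some positions replaced by the mask token id
print(batch["labels"])     # -100 everywhere except the masked positions
```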