Enhance DataCollatorForLanguageModeling with Configurable Token Replacement Probabilities #35251

Open · wants to merge 8 commits into `main`
113 changes: 93 additions & 20 deletions src/transformers/data/data_collator.py
@@ -691,11 +691,17 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
tokens and the value to predict for the masked token.
mlm_probability (`float`, *optional*, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
mask_replace_prob (`float`, *optional*, defaults to 0.8):
The probability with which masked tokens are replaced by the tokenizer's mask token (e.g., `[MASK]`).
Defaults to 0.8, meaning 80% of the masked tokens will be replaced with `[MASK]`.
Only works when `mlm` is set to `True`.
random_replace_prob (`float`, *optional*, defaults to 0.1):
The probability with which masked tokens are replaced by random tokens from the tokenizer's vocabulary.
Defaults to 0.1, meaning 10% of the masked tokens will be replaced with random tokens. The remaining
masked tokens (1 - mask_replace_prob - random_replace_prob) are left unchanged.
Only works when `mlm` is set to `True`.
pad_to_multiple_of (`int`, *optional*):
If set, will pad the sequence to a multiple of the provided value.
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".

@@ -705,11 +711,36 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.

</Tip>"""
<Example Options and Expectations>

1. Default Behavior:
- `mask_replace_prob=0.8`, `random_replace_prob=0.1`.
- Expect 80% of masked tokens replaced with `[MASK]`, 10% replaced with random tokens, and 10% left unchanged.

2. All masked tokens replaced by `[MASK]`:
- `mask_replace_prob=1.0`, `random_replace_prob=0.0`.
- Expect all masked tokens to be replaced with `[MASK]`. No tokens are left unchanged or replaced with random tokens.

3. No `[MASK]` replacement, only random tokens:
- `mask_replace_prob=0.0`, `random_replace_prob=1.0`.
- Expect all masked tokens to be replaced with random tokens. No `[MASK]` replacements or unchanged tokens.

4. Balanced replacement:
- `mask_replace_prob=0.5`, `random_replace_prob=0.4`.
- Expect 50% of masked tokens replaced with `[MASK]`, 40% replaced with random tokens, and 10% left unchanged.

Note:
The sum of `mask_replace_prob` and `random_replace_prob` must not exceed 1. If their sum is less than 1, the
remaining proportion will consist of masked tokens left unchanged.

</Tip>
"""

tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
mask_replace_prob: float = 0.8
random_replace_prob: float = 0.1
pad_to_multiple_of: Optional[int] = None
tf_experimental_compile: bool = False
return_tensors: str = "pt"
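
For quick orientation, a minimal usage sketch of the new knobs (my own illustration, not part of the diff; the checkpoint name is arbitrary, any tokenizer that defines a mask token works):

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# "Balanced replacement" from the docstring: of the ~15% of tokens selected for
# masking, 50% become [MASK], 40% become random tokens, 10% are left unchanged.
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
    mask_replace_prob=0.5,
    random_replace_prob=0.4,
    return_tensors="pt",
)

batch = collator([tokenizer("The quick brown fox jumps over the lazy dog")])
print(batch["input_ids"])  # some ids rewritten to [MASK] or random tokens
print(batch["labels"])     # original ids at masked positions, -100 elsewhere
```
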
Expand All @@ -720,6 +751,15 @@ def __post_init__(self):
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
if self.mlm_probability < 0 or self.mlm_probability > 1:
raise ValueError("mlm_probability should be between 0 and 1.")
if self.mask_replace_prob + self.random_replace_prob > 1:
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
raise ValueError("mask_replace_prob should be between 0 and 1.")
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
raise ValueError("random_replace_prob should be between 0 and 1.")

if self.tf_experimental_compile:
import tensorflow as tf

@@ -749,18 +789,28 @@ def tf_mask_tokens(
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

# With probability mask_replace_prob, replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices

inputs = tf.where(indices_replaced, mask_token_id, inputs)

if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# Scale random_replace_prob to the remaining probability: for example, if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
# With probability random_replace_prob, replace masked input tokens with a random word
indices_random = (
self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
)
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)

inputs = tf.where(indices_random, random_words, inputs)

# The rest of the time (with probability 1 - mask_replace_prob - random_replace_prob) we keep the masked input tokens unchanged
return inputs, labels
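
A note on the rescaling above: the random-word draw is made only on masked positions that were not already replaced by `[MASK]`, so its Bernoulli parameter must be divided by the remaining mass for the overall fractions to come out right. A small self-contained check (my own sketch, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
mask_replace_prob, random_replace_prob = 0.8, 0.1
n = 1_000_000

# First draw: replacement with [MASK].
replaced = rng.random(n) < mask_replace_prob
# Second draw, on the survivors only, with the rescaled probability.
scaled = random_replace_prob / (1 - mask_replace_prob)  # 0.1 / 0.2 = 0.5
randomized = ~replaced & (rng.random(n) < scaled)

print(replaced.mean())                   # ~0.80: [MASK]
print(randomized.mean())                 # ~0.10: random token
print((~replaced & ~randomized).mean())  # ~0.10: unchanged
```
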

def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
@@ -849,16 +899,29 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# With probability mask_replace_prob, replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# Scale random_replace_prob to the remaining probability: for example, if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob

# With probability random_replace_prob, replace masked input tokens with a random word
indices_random = (
torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (with probability 1 - mask_replace_prob - random_replace_prob) we keep the masked input tokens unchanged
return inputs, labels
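
The PyTorch branch mirrors the TensorFlow logic. As an end-to-end smoke check (again my own sketch; the checkpoint and token-id range are arbitrary), the realized `[MASK]` fraction among masked positions should hover near `mask_replace_prob`:

```python
import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mask_replace_prob=0.8, random_replace_prob=0.1, return_tensors="pt"
)

# Arbitrary non-special token ids, batched so the estimate is stable.
features = [{"input_ids": torch.arange(1000, 1512)} for _ in range(32)]
batch = collator(features)

masked = batch["labels"] != -100                         # positions chosen by mlm_probability
is_mask = batch["input_ids"] == tokenizer.mask_token_id  # positions rewritten to [MASK]
print((is_mask & masked).sum() / masked.sum())           # expected to be close to 0.8
```
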

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
@@ -905,14 +968,24 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# With probability mask_replace_prob, replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
)
inputs[indices_replaced] = self.tokenizer.mask_token_id

if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# Scale random_replace_prob to the remaining probability: for example, if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
indices_random = (
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
& masked_indices
& ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
46 changes: 46 additions & 0 deletions tests/trainer/test_data_collator.py
@@ -1020,6 +1020,52 @@ def _test_no_pad_and_pad(self, no_pad_features, pad_features):
self.assertTrue(tf.reduce_any(masked_tokens))
# self.assertTrue(all(x == -100 for x in batch["labels"].numpy()[~masked_tokens.numpy()].tolist()))

def test_probability_sum_error(self):
"""Test that the sum of mask_replace_prob and random_replace_prob exceeding 1 raises an error."""
tokenizer = BertTokenizer(self.vocab_file)
with self.assertRaises(ValueError):
DataCollatorForLanguageModeling(tokenizer=tokenizer, mask_replace_prob=0.9, random_replace_prob=0.2)

def test_all_mask_replacement(self):
"""Test behavior when mask_replace_prob=1."""
tokenizer = BertTokenizer(self.vocab_file)

# pytorch call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="pt"
)

inputs = torch.tensor([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)

# confirm that every token is either the original token or [MASK]
self.assertTrue(torch.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))

# tf call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="tf"
)
inputs = tf.constant([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)

# confirm that every token is either the original token or [MASK]
self.assertTrue(
tf.reduce_all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id))
)

# numpy call
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer, mask_replace_prob=1, random_replace_prob=0, return_tensors="np"
)
inputs = np.array([0, 1, 2, 3, 4, 5])
features = [{"input_ids": inputs} for _ in range(8)]
batch = collator(features)

# confirm that every token is either the original token or [MASK]
self.assertTrue(np.all((batch["input_ids"] == inputs) | (batch["input_ids"] == tokenizer.mask_token_id)))
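
One deterministic edge case the suite does not exercise: with both probabilities at zero, masking still selects positions for the loss but never rewrites the inputs. A possible additional test (a sketch, not part of this diff, PyTorch only):

```python
def test_no_replacement(self):
    """With mask_replace_prob=0 and random_replace_prob=0, inputs pass through unchanged."""
    tokenizer = BertTokenizer(self.vocab_file)
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mask_replace_prob=0, random_replace_prob=0, return_tensors="pt"
    )

    inputs = torch.tensor([0, 1, 2, 3, 4, 5])
    features = [{"input_ids": inputs} for _ in range(8)]
    batch = collator(features)

    # labels still mark masked positions with the original ids, but no token is rewritten
    self.assertTrue(torch.all(batch["input_ids"] == inputs))
```
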

def test_data_collator_for_language_modeling(self):
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]