Skip to content

Commit

Permalink
DataCollatorForLanguageModeling class was updated with new parameters…
Browse files Browse the repository at this point in the history
… that provides more control over the token masking and relacing
  • Loading branch information
mahdibaghbanzadeh committed Dec 12, 2024
1 parent a691ccb commit f26418c
Showing 1 changed file with 82 additions and 20 deletions.
102 changes: 82 additions & 20 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
tokens and the value to predict for the masked token.
mlm_probability (`float`, *optional*, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
mask_replace_prob (`float`, *optional*, defaults to 0.8):
The probability with which masked tokens are replaced by the tokenizer's mask token (e.g., `[MASK]`).
Defaults to 0.8, meaning 80% of the masked tokens will be replaced with `[MASK]`.
random_replace_prob (`float`, *optional*, defaults to 0.1):
The probability with which masked tokens are replaced by random tokens from the tokenizer's vocabulary.
Defaults to 0.1, meaning 10% of the masked tokens will be replaced with random tokens. The remaining
masked tokens (1 - mask_replace_prob - random_replace_prob) are left unchanged.
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.0 (Volta).
If set, will pad the sequence to a multiple of the provided value.
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
Expand All @@ -705,11 +709,37 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
</Tip>"""
<Example Options and Expectations>
1. Default Behavior:
- `mask_replace_prob=0.8`, `random_replace_prob=0.1`.
- Expect 80% of masked tokens replaced with `[MASK]`, 10% replaced with random tokens, and 10% left unchanged.
2. All masked tokens replaced by `[MASK]`:
- `mask_replace_prob=1.0`, `random_replace_prob=0.0`.
- Expect all masked tokens to be replaced with `[MASK]`. No tokens are left unchanged or replaced with random tokens.
3. No `[MASK]` replacement, only random tokens:
- `mask_replace_prob=0.0`, `random_replace_prob=1.0`.
- Expect all masked tokens to be replaced with random tokens. No `[MASK]` replacements or unchanged tokens.
4. Balanced replacement:
- `mask_replace_prob=0.5`, `random_replace_prob=0.4`.
- Expect 50% of masked tokens replaced with `[MASK]`, 40% replaced with random tokens, and 10% left unchanged.
Note:
The sum of `mask_replace_prob` and `random_replace_prob` must not exceed 1. If their sum is less than 1, the
remaining proportion will consist of masked tokens left unchanged.
</Tip>
"""


tokenizer: PreTrainedTokenizerBase
mlm: bool = True
mlm_probability: float = 0.15
mask_replace_prob: float = 0.8
random_replace_prob: float = 0.1
pad_to_multiple_of: Optional[int] = None
tf_experimental_compile: bool = False
return_tensors: str = "pt"
Expand All @@ -720,6 +750,15 @@ def __post_init__(self):
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
if self.mlm_probability < 0 or self.mlm_probability > 1:
raise ValueError("mlm_probability should be between 0 and 1.")
if self.mask_replace_prob + self.random_replace_prob > 1:
raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
raise ValueError("mask_replace_prob should be between 0 and 1.")
if self.random_replace_prob < 0 or self.random_replace_prob > 1:
raise ValueError("random_replace_prob should be between 0 and 1.")

if self.tf_experimental_compile:
import tensorflow as tf

Expand Down Expand Up @@ -749,18 +788,26 @@ def tf_mask_tokens(
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
labels = tf.where(masked_indices, inputs, -100)

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = self.tf_bernoulli(input_shape, self.mask_replace_prob) & masked_indices

inputs = tf.where(indices_replaced, mask_token_id, inputs)

# 10% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# scaling the random_replace_prob to the remaining probability for example if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
# random_replace_prob% of the time, we replace masked input tokens with random word
indices_random = self.tf_bernoulli(input_shape, random_replace_prob_scaled) & masked_indices & ~indices_replaced
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)

inputs = tf.where(indices_random, random_words, inputs)

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
return inputs, labels

def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
Expand Down Expand Up @@ -849,16 +896,25 @@ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# scaling the random_replace_prob to the remaining probability for example if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob

# random_replace_prob% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
# The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
return inputs, labels

def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
Expand Down Expand Up @@ -905,14 +961,20 @@ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = No
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
# mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
inputs[indices_replaced] = self.tokenizer.mask_token_id

# 10% of the time, we replace masked input tokens with random word
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
return inputs, labels

remaining_prob = 1 - self.mask_replace_prob
# scaling the random_replace_prob to the remaining probability for example if
# mask_replace_prob = 0.8 and random_replace_prob = 0.1,
# then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
random_replace_prob_scaled = self.random_replace_prob / remaining_prob
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
Expand Down

0 comments on commit f26418c

Please sign in to comment.