Skip to content

Commit

Permalink
Merge branch 'add-Dinov2-with-registers-attempt-2' of https://github.…
Browse files Browse the repository at this point in the history
…com/BernardZach/transformers into add-Dinov2-with-registers-attempt-2
  • Loading branch information
BernardZach committed Dec 5, 2024
2 parents 07121a8 + 5a62171 commit 437a371
Show file tree
Hide file tree
Showing 9 changed files with 636 additions and 21 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/generation_strategies.md
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
['Alice and Bob, a couple of friends of mine, who are both in the same office as']
```

We recommend to install `scikit-learn` library to enhance the candidate generation strategy and achieve additional speedup.

#### Universal Assisted Decoding

Universal Assisted Decoding (UAD) adds support for main and assistant models with different tokenizers.
Expand Down
57 changes: 57 additions & 0 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
import numpy as np
import torch

from ..utils import is_sklearn_available


if is_sklearn_available():
from sklearn.metrics import roc_curve

from ..cache_utils import DynamicCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
Expand Down Expand Up @@ -180,6 +186,14 @@ def __init__(
# We need to roll back the cache in assisted generation, only DynamicCache is supported
self.generation_config.cache_implementation = None

if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
self.probs = []
self.matches = []

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Expand Down Expand Up @@ -230,6 +244,17 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
scores_tensor = torch.cat(assistant_output.scores, dim=0)
scores_softmax = torch.softmax(scores_tensor, dim=-1)
ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
p = scores_softmax[range(len(ids)), ids]
self.probs.extend(p.tolist())

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
candidate_ids = assistant_output.sequences
Expand Down Expand Up @@ -261,6 +286,38 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
else:
self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

# The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
# This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft vocabulary, which is unavailable in UAG.
if (
is_sklearn_available()
and self.assistant_model.generation_config.assistant_confidence_threshold
and type(self) is AssistedCandidateGenerator
):
# update self.matches
self.matches.extend([1] * num_matches)
if len(self.probs) > len(self.matches):
self.matches.append(0)

# update self.probs
excess_length = len(self.probs) - len(self.matches)
if excess_length > 0:
del self.probs[-excess_length:]

if (
len(self.probs) > 5 and {0, 1}.issubset(self.matches)
): # require at least 5 samples to calculate the ROC curve and at least one positive and one negative sample
fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
fnr = 1 - tpr

# Calculate the cost for each threshold
costs = fpr + 3 * fnr

# Find the threshold that minimizes the cost
optimal_threshold_index = np.argmin(costs)
best_threshold = thresholds[optimal_threshold_index]

self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold


class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,9 @@ class GenerationConfig(PushToHubMixin):
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
(defined by `num_assistant_tokens`) is not yet reached. The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes, biased towards avoiding false negatives.
`assistant_confidence_threshold` value is persistent over multiple generation calls with the same assistant model.
It is an unsupervised version of the dynamic speculation lookahead
from Dynamic Speculation Lookahead Accelerates Speculative Decoding of Large Language Models <https://arxiv.org/abs/2405.04304>.
prompt_lookup_num_tokens (`int`, *optional*):
The number of tokens to be output as candidate tokens.
Expand Down
Loading

0 comments on commit 437a371

Please sign in to comment.