Skip to content

Commit

Permalink
Fix import error from deprecation in transformers (#1415)
Browse files Browse the repository at this point in the history
* Fix import error from  deprecation in transformers

* Fix import path
  • Loading branch information
lewtun authored Mar 11, 2024
1 parent 4d862da commit 7630f87
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion trl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import top_k_top_p_filtering
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

from .import_utils import is_npu_available, is_xpu_available

Expand All @@ -36,6 +36,42 @@
WANDB_PADDING = -1


def top_k_top_p_filtering(
logits: torch.FloatTensor,
top_k: int = 0,
top_p: float = 1.0,
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
Args:
logits: logits distribution shape (batch size, vocabulary size)
top_k (`int`, *optional*, defaults to 0):
If > 0, only keep the top k tokens with highest probability (top-k filtering)
top_p (`float`, *optional*, defaults to 1.0):
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
min_tokens_to_keep (`int`, *optional*, defaults to 1):
Minimumber of tokens we keep per batch example in the output.
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""

if top_k > 0:
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)

if 0 <= top_p <= 1.0:
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
None, logits
)

return logits


def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
"""Flatten dictionary and concatenate nested keys with separator."""

Expand Down

0 comments on commit 7630f87

Please sign in to comment.