Skip to content

Commit

Permalink
FIX [Generation] Fix some issues when running the MaxLength criteri…
Browse files Browse the repository at this point in the history
…a on CPU (huggingface#29317)

fix the bitwise or issue
  • Loading branch information
younesbelkada authored Mar 5, 2024
1 parent e947683 commit 81c8191
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


class MaxNewTokensCriteria(StoppingCriteria):
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, start_length: int, max_new_tokens: int):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = input_ids.shape[-1] >= self.max_length
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


class MaxTimeCriteria(StoppingCriteria):
Expand All @@ -126,7 +126,7 @@ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = time.time() - self.initial_timestamp > self.max_time
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


class StoppingCriteriaList(list):
Expand Down

0 comments on commit 81c8191

Please sign in to comment.