Skip to content

Commit

Permalink
CI / generate: batch size computation compatible with all models (#29671)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored Mar 18, 2024
1 parent 00c1d87 commit bf3dfd1
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,11 +1949,9 @@ def _contrastive_search(
)

# keep track of which sequences are already finished
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

Expand Down Expand Up @@ -2398,12 +2396,10 @@ def _greedy_search(
)

# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
this_peer_finished = False
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

Expand Down Expand Up @@ -2686,12 +2682,10 @@ def _sample(
)

# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
this_peer_finished = False
batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

Expand Down Expand Up @@ -4461,11 +4455,9 @@ def _assisted_decoding(
)

# keep track of which sequences are already finished
batch_size, cur_len = batch_size, cur_len = (
model_kwargs["attention_mask"].shape
if model_kwargs.get("attention_mask", None) is not None
else input_ids.shape
)
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

Expand Down

0 comments on commit bf3dfd1

Please sign in to comment.