Fix decoder-only generation with embeddings for gpt2
Co-authored-by: Taylor Jackle Spriggs <[email protected]>
Co-authored-by: Soila Kavulya <[email protected]>
tjs-intel and skavulya committed Sep 3, 2024
1 parent a1a92c9 commit 5c7f159
Showing 24 changed files with 206 additions and 50 deletions.
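For context, here is a minimal sketch of the scenario this commit fixes: decoder-only generation where the prompt is supplied as inputs_embeds rather than input_ids, which under Gaudi static shapes requires padding the embeddings to max_new_tokens and tracking an offset between the input_ids and inputs_embeds lengths. The snippet is illustrative only (plain transformers calls, no Habana-specific setup shown) and is not part of the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
# Build the prompt as embeddings; the patched generation utilities below must
# then pad inputs_embeds up to max_new_tokens and remember the offset between
# the input_ids length and the inputs_embeds length.
inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=16,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))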
137 changes: 111 additions & 26 deletions optimum/habana/transformers/generation/utils.py
@@ -140,7 +140,10 @@ def get_final_stopping_criteria(x):
if isinstance(x, bool):
return x
elif torch.is_tensor(x):
return all(x)
if x.dim() > 0:
return all(x)
else:
return x
else:
raise TypeError(f"The stopping criteria should be either a boolean or a torch.tensor but got {type(x)}.")
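
As a standalone illustration of why the 0-dim branch above is needed (not part of the commit): Python's all() iterates its argument, and iterating a 0-dim tensor raises a TypeError, so scalar stopping-criteria tensors must be returned directly.

import torch

def get_final_stopping_criteria(x):
    # Simplified mirror of the helper above: booleans pass through, tensors with
    # at least one dimension are reduced with all(), and 0-dim tensors are
    # returned as-is because iterating them raises "iteration over a 0-d tensor".
    if isinstance(x, bool):
        return x
    elif torch.is_tensor(x):
        return all(x) if x.dim() > 0 else x
    raise TypeError(f"The stopping criteria should be either a boolean or a torch.tensor but got {type(x)}.")

print(get_final_stopping_criteria(torch.tensor([True, False])))  # False
print(get_final_stopping_criteria(torch.tensor(True)))           # tensor(True)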

@@ -297,6 +300,7 @@ def _expand_dict_for_generation(dict_to_expand):
key != "token_idx"
and key != "decoder_input_ids"
and key != "cache_position"
and key != "inputs_embeds_offset"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
):
@@ -914,14 +918,44 @@ def generate(
# only pad if bucket_size < -1. If we are bucketing (bucket_size > 0), then that is taken care in greedy_search()
if not is_greedy_or_beam_and_bucket:
# token_idx is the current index in the generation process, it is incremented each time a new token is generated
token_idx = inputs_tensor.shape[-1]
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
model_kwargs["token_idx_cpu"] = token_idx
token_idx = inputs_tensor.shape[1]
if generation_config.max_new_tokens is None:
generation_config.max_new_tokens = generation_config.max_length - token_idx
inputs_tensor = torch.nn.functional.pad(
inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)
if "inputs_embeds" in model_kwargs:
if "input_ids" in model_kwargs:
inputs_embeds_offset = (
model_kwargs["input_ids"].shape[1] - model_kwargs["inputs_embeds"].shape[1]
)
else:
inputs_embeds_offset = -model_kwargs["inputs_embeds"].shape[1]

model_kwargs["inputs_embeds_offset"] = torch.tensor(
inputs_embeds_offset, device=inputs_tensor.device
)
model_kwargs["inputs_embeds"] = torch.nn.functional.pad(
model_kwargs["inputs_embeds"],
(0, 0, 0, generation_config.max_new_tokens),
value=generation_config.pad_token_id,
)

if model_input_name == "inputs_embeds":
inputs_tensor = torch.nn.functional.pad(
inputs_tensor,
(0, 0, 0, generation_config.max_new_tokens),
value=generation_config.pad_token_id,
)
model_kwargs["input_ids"] = torch.nn.functional.pad(
model_kwargs["input_ids"],
(0, generation_config.max_new_tokens),
value=generation_config.pad_token_id,
)
else:
inputs_tensor = torch.nn.functional.pad(
inputs_tensor, (0, generation_config.max_new_tokens), value=generation_config.pad_token_id
)
model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
model_kwargs["token_idx_cpu"] = token_idx

for other_inputs in ["attention_mask", "token_type_ids"]:
if model_kwargs.get(other_inputs) is not None:
model_kwargs[other_inputs] = torch.nn.functional.pad(
@@ -1612,8 +1646,6 @@ def _contrastive_search(

# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]

if not ignore_eos:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
@@ -1640,7 +1672,10 @@
top_k_ids = None
if token_idx is not None:
# Update cur_len in case of static shapes
cur_len = token_idx.item()
if "inputs_embeds_offset" in model_kwargs:
cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item()
else:
cur_len = token_idx.item()

time_to_first_token_done = False
model_kwargs["pad_done"] = False
@@ -1756,8 +1791,13 @@ def _contrastive_search(
pad_amount = input_ids.shape[-1] - top_k_ids.shape[-1]
top_k_ids = torch.nn.functional.pad(top_k_ids, (0, pad_amount), value=pad_token_id)

idx = (
token_idx + model_kwargs["inputs_embeds_offset"] - 1
if "inputs_embeds_offset" in model_kwargs
else token_idx - 1
)
top_k_probs, top_k_prob_ids = torch.topk(next_probs, dim=-1, k=top_k)
top_k_ids[:, :, token_idx - 1] = top_k_prob_ids
top_k_ids[:, :, idx] = top_k_prob_ids
else:
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)

@@ -1883,8 +1923,14 @@ def _contrastive_search(
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
# (model confidence minus degeneration penalty); (6) decoder hidden_states
top_k_indices = torch.arange(len(top_k_ids), device=input_ids.device)

if token_idx is not None:
next_tokens = top_k_ids[top_k_indices, selected_idx, token_idx - 1]
idx = (
token_idx + model_kwargs["inputs_embeds_offset"] - 1
if "inputs_embeds_offset" in model_kwargs
else token_idx - 1
)
next_tokens = top_k_ids[top_k_indices, selected_idx, idx]
else:
next_tokens = top_k_ids[top_k_indices, selected_idx]
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
@@ -1976,9 +2022,12 @@ def _contrastive_search(
# update generated ids, model inputs, and length for next step
if token_idx is not None:
# Use token_idx-1 since token index is incremented twice in first iteration
input_ids.index_copy_(
1, token_idx - 1, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens
idx = (
token_idx + model_kwargs["inputs_embeds_offset"] - 1
if "inputs_embeds_offset" in model_kwargs
else token_idx - 1
)
input_ids.index_copy_(1, idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens)
else:
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

@@ -2217,7 +2266,10 @@ def _sample(
token_idx = model_kwargs.get("token_idx", None)
if token_idx is not None:
# Update cur_len in case of static shapes
cur_len = token_idx.item()
if "inputs_embeds_offset" in model_kwargs:
cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item()
else:
cur_len = token_idx.item()

time_to_first_token_done = False
model_kwargs["pad_done"] = False
@@ -2319,9 +2371,12 @@ def _sample(
next_tokens = next_tokens.to(input_ids.dtype)

if token_idx is not None:
input_ids.index_copy_(
1, token_idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens
idx = (
token_idx + model_kwargs["inputs_embeds_offset"]
if "inputs_embeds_offset" in model_kwargs
else token_idx
)
input_ids.index_copy_(1, idx, next_tokens.unsqueeze(-1) if next_tokens.dim() == 1 else next_tokens)
else:
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

@@ -2514,7 +2569,11 @@ def _beam_search(
token_idx = model_kwargs.get("token_idx", None)
if token_idx is not None:
# Update cur_len in case of static shapes
cur_len = token_idx.item()
if "inputs_embeds_offset" in model_kwargs:
cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item()
else:
cur_len = token_idx.item()

model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
@@ -2729,7 +2788,12 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
) # (batch_size * num_beams, vocab_size)

if token_idx is not None:
next_token_scores_processed = logits_processor(input_ids[:, :token_idx], next_token_scores)
idx = (
token_idx + model_kwargs["inputs_embeds_offset"]
if "inputs_embeds_offset" in model_kwargs
else token_idx
)
next_token_scores_processed = logits_processor(input_ids[:, :idx], next_token_scores)
else:
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
if do_sample:
@@ -2836,8 +2900,13 @@ def expand_if_needed(tensor, new_size, value, dim=-1):

if token_idx is not None:
input_ids = torch.index_select(input_ids, 0, beam_idx)
idx = (
token_idx + model_kwargs["inputs_embeds_offset"]
if "inputs_embeds_offset" in model_kwargs
else token_idx
)
input_ids.index_copy_(
1, token_idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens
1, idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens
)
else:
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
@@ -3091,12 +3160,15 @@ def _constrained_beam_search(
num_beams = constrained_beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]

token_idx = model_kwargs.get("token_idx", None)
if token_idx is not None:
# Update cur_len in case of static shapes
cur_len = token_idx.item()
if "inputs_embeds_offset" in model_kwargs:
cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item()
else:
cur_len = token_idx.item()

model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

if num_beams * batch_size != batch_beam_size:
@@ -3129,7 +3201,11 @@ def _constrained_beam_search(

this_peer_finished = False

decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
if token_idx is not None:
# Update cur_len in case of static shapes
decoder_prompt_len = cur_len
else:
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder

hb_profer = HabanaProfile(
warmup=profiling_warmup_steps, active=profiling_steps, record_shapes=profiling_record_shapes
@@ -3234,8 +3310,14 @@ def _constrained_beam_search(

if token_idx is not None:
input_ids = input_ids[beam_idx, :]
idx = (
token_idx + model_kwargs["inputs_embeds_offset"]
if "inputs_embeds_offset" in model_kwargs
else token_idx
)

input_ids.index_copy_(
1, token_idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens
1, idx, beam_next_tokens.unsqueeze(-1) if beam_next_tokens.dim() == 1 else beam_next_tokens
)
else:
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
@@ -3437,7 +3519,10 @@ def _assisted_decoding(

if token_idx is not None:
# Update cur_len in case of static shapes
cur_len = token_idx.item()
if "inputs_embeds_offset" in model_kwargs:
cur_len = (token_idx + model_kwargs["inputs_embeds_offset"]).item()
else:
cur_len = token_idx.item()
else:
cur_len = input_ids.shape[-1]

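For reference, a standalone sketch of the inputs_embeds_offset bookkeeping that the changes above rely on (not part of the commit; shapes are illustrative): the offset is the difference between the input_ids length and the inputs_embeds length, and every token_idx-based index into input_ids is shifted by it under static shapes.

import torch

def compute_inputs_embeds_offset(input_ids, inputs_embeds):
    # Mirrors the logic added in generate(): if input_ids accompany the
    # embeddings, the offset is their length difference; otherwise it is
    # minus the embeddings length (input_ids then holds only new tokens).
    if input_ids is not None:
        return input_ids.shape[1] - inputs_embeds.shape[1]
    return -inputs_embeds.shape[1]

# Example: a 5-token prompt supplied only as embeddings.
inputs_embeds = torch.zeros(1, 5, 768)
offset = compute_inputs_embeds_offset(None, inputs_embeds)  # -5
token_idx = 5                    # static-shape position right after the prompt
write_idx = token_idx + offset   # 0 -> first slot of the generated-token buffer
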
@@ -505,7 +505,10 @@ def gaudi_BlipTextLMHead_prepare_inputs_for_generation(
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in model_kwargs:
idx = idx + model_kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]

5 changes: 4 additions & 1 deletion optimum/habana/transformers/models/bloom/modeling_bloom.py
@@ -489,7 +489,10 @@ def prepare_inputs_for_generation(
if token_idx is None:
input_ids = input_ids[:, -1].unsqueeze(-1)
else:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)

# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
@@ -320,7 +320,11 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)

if token_type_ids is not None:
token_type_ids = torch.index_select(token_type_ids, 1, token_idx - 1)
else:
5 changes: 4 additions & 1 deletion optimum/habana/transformers/models/decilm/modeling_decilm.py
@@ -373,7 +373,10 @@ def prepare_inputs_for_generation(
past_length = 0
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
5 changes: 4 additions & 1 deletion optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -983,7 +983,10 @@ def prepare_inputs_for_generation(
bucket_internal = kwargs.get("bucket_internal")
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]

5 changes: 4 additions & 1 deletion optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -419,7 +419,10 @@ def prepare_inputs_for_generation(
input_ids = input_ids[:, cache_position]
else:
# past_length += token_idx
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
10 changes: 8 additions & 2 deletions optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -476,7 +476,10 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]

@@ -609,7 +612,10 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]

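A minimal illustration of the index selection added to prepare_inputs_for_generation above (not part of the commit; values are hypothetical): with a KV cache and static shapes, only the newest real token is fed to the model, and its position in input_ids is token_idx - 1 shifted by inputs_embeds_offset when the prompt was passed as embeddings.

import torch

pad_token_id = 50256
input_ids = torch.full((1, 8), pad_token_id)  # buffer holding only generated tokens
token_idx = 5                                 # static-shape position after one generated token
inputs_embeds_offset = -4                     # a 4-token prompt was passed as embeddings

idx = token_idx - 1 + inputs_embeds_offset    # 0 -> the token generated on the previous step
step_input = input_ids[:, idx : idx + 1]      # shape (1, 1): single-token decode step
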
@@ -558,7 +558,10 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
if token_type_ids is not None:
token_type_ids = torch.index_select(token_type_ids, 1, token_idx - 1)
else:
@@ -379,7 +379,10 @@ def prepare_inputs_for_generation(
# cut decoder_input_ids if past is used
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]

5 changes: 4 additions & 1 deletion optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -563,7 +563,10 @@ def prepare_inputs_for_generation(
# Omit tokens covered by past_key_values
if past_key_values:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
idx = token_idx - 1
if "inputs_embeds_offset" in kwargs:
idx = idx + kwargs["inputs_embeds_offset"]
input_ids = torch.index_select(input_ids, 1, idx)
else:
past_length = past_key_values[0][0].shape[2]
