Skip to content

Commit

Permalink
mvp
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed May 13, 2024
1 parent 9430635 commit c681fee
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 200 deletions.
81 changes: 47 additions & 34 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,20 +1276,13 @@ def _prepare_generation_config(
generation_config = self.generation_config

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled.
if is_torchdynamo_compiling():
model_kwargs = kwargs
generate_attributes_in_kwargs = [
key for key, value in kwargs.items() if getattr(generation_config, key, None) != value
]
if len(generate_attributes_in_kwargs) > 0:
raise ValueError(
"`torch.compile` exception: all generation configuration attributes must be passed within a "
f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})."
)
else:
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
else:
model_kwargs = kwargs

return generation_config, model_kwargs

Expand All @@ -1305,11 +1298,14 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
past_length = model_kwargs["past_key_values"].get_seq_length()
else:
past_length = model_kwargs["past_key_values"][0][0].shape[2]

# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from
# `past_length` to `past_length + input_length`
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0)
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0)
model_kwargs["cache_position"] = cache_position + past_length - 1
return model_kwargs

def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache:
Expand Down Expand Up @@ -1389,11 +1385,12 @@ def _tensor_or_none(token, device=None):
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will "
"not stop until the maximum length is reached. Depending on other flags, it may even crash."
)

# Update generation config with the updated special tokens tensors
generation_config.bos_token_id = bos_token_id
Expand Down Expand Up @@ -1863,23 +1860,35 @@ def typeerror():

return result

def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
def _has_unfinished_sequences(
self,
this_peer_finished: bool,
synced_gpus: bool,
device: torch.device,
cur_len: Optional[int] = None,
max_length: Optional[int] = None,
) -> bool:
"""
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
fed through `this_peer_finished`. ZeRO stage 3-friendly.
"""
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
# torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
# although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
if is_torchdynamo_compiling():
return cur_len < max_length
else:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
return False
elif this_peer_finished:
return False
return True
return True

@torch.no_grad()
def _contrastive_search(
Expand Down Expand Up @@ -2319,6 +2328,7 @@ def _sample(
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
Expand All @@ -2342,12 +2352,14 @@ def _sample(
)

# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

Expand Down Expand Up @@ -2412,6 +2424,7 @@ def _sample(

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1

if streamer is not None:
streamer.end()
Expand Down
74 changes: 41 additions & 33 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
Expand Down Expand Up @@ -1164,39 +1165,46 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
# with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens)
if is_torchdynamo_compiling():
input_ids = input_ids[:, cache_position]

else:
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = (
past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
Expand Down
74 changes: 41 additions & 33 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
Expand Down Expand Up @@ -1429,39 +1430,46 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
# with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens)
if is_torchdynamo_compiling():
input_ids = input_ids[:, cache_position]

else:
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = (
past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None

# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
Expand Down
Loading

0 comments on commit c681fee

Please sign in to comment.