remove bloom caching
echarlaix committed Sep 18, 2023
1 parent 6d8acb4 commit 52c1745
Showing 1 changed file with 13 additions and 99 deletions.
112 changes: 13 additions & 99 deletions optimum/onnxruntime/modeling_decoder.py
@@ -687,8 +687,6 @@ def __init__(
"To export your model, simply set `export=True`."
)



@add_start_docstrings_to_model_forward(
CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
+ TEXT_GENERATION_EXAMPLE.format(
@@ -707,7 +705,6 @@ def forward(
use_cache_branch: None = None,
**kwargs,
) -> CausalLMOutputWithPast:

# adding use_cache_branch in the signature here is just a hack for IO Binding
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
@@ -917,8 +914,6 @@ def _from_pretrained(
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
) -> "ORTModelForCausalLM":


model_path = Path(model_id)

# We do not implement the logic for use_cache=False, use_merged=True
@@ -1041,10 +1036,16 @@ def _from_pretrained(
# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
onnx_model = onnx.load(model_cache_path)
input_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.input}
input_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input
}
if input_dims["input_ids"][1] == 1:
input_dims["input_ids"][1] = "sequence_length"
output_dims = {node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] for node in onnx_model.graph.output}
output_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output
}
output_dims["logits"][1] = "sequence_length"
static_model = onnx.load(model_cache_path)
updated_model = update_model_dims.update_inputs_outputs_dims(static_model, input_dims, output_dims)
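
The dimension rewrite shown in this hunk can be exercised on its own with `onnx.tools.update_model_dims`. A minimal standalone sketch follows; the file name "decoder_with_past_model.onnx" is only an illustrative assumption, while the input/output names match the ones used above:

# Sketch: relax a fixed sequence dimension to a named dynamic one.
import onnx
from onnx.tools import update_model_dims

model = onnx.load("decoder_with_past_model.onnx")  # assumed local export

# Collect current input/output shapes, keeping symbolic names where they exist.
input_dims = {
    node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
    for node in model.graph.input
}
output_dims = {
    node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
    for node in model.graph.output
}

# Replace the fixed length of 1 on the sequence axis by a dynamic dimension.
if input_dims["input_ids"][1] == 1:
    input_dims["input_ids"][1] = "sequence_length"
output_dims["logits"][1] = "sequence_length"

updated = update_model_dims.update_inputs_outputs_dims(model, input_dims, output_dims)
onnx.save(updated, "decoder_with_past_model_dynamic.onnx")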
@@ -1063,10 +1064,9 @@ def _from_pretrained(
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
use_cache=use_cache
use_cache=use_cache,
)


@classmethod
def _from_transformers(
cls,
Expand Down Expand Up @@ -1134,7 +1134,6 @@ def _from_transformers(
file_name=file_name,
)


# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1150,13 +1149,6 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)


# TODO : rm !!!!
# `past_key_values` may be in the standard format (e.g. in contrastive search), converts to bloom's format if needed
if past_key_values is not None and self.config.model_type == "bloom":
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
@@ -1165,90 +1157,12 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"attention_mask": attention_mask,
}
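
The collapsed lines of this hunk derive `position_ids` from the attention mask before the slicing shown above. The usual pattern in transformers' GPT-2-style `prepare_inputs_for_generation` looks roughly like the sketch below; treat it as an approximation of the hidden lines, not an exact copy:

import torch

attention_mask = torch.tensor([[0, 1, 1, 1]])        # left-padded toy batch of one sequence
position_ids = attention_mask.long().cumsum(-1) - 1  # running position for each real token
position_ids.masked_fill_(attention_mask == 0, 1)    # padding positions get a dummy value
print(position_ids)                                  # tensor([[1, 0, 1, 2]])

# With a populated past_key_values, only the position of the new token is kept:
last_position = position_ids[:, -1].unsqueeze(-1)    # tensor([[2]])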


# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called.
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
if self.config.model_type == "bloom":
return self._reorder_cache_bloom(past_key_values, beam_idx)

# from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)
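
To make the beam reordering concrete, here is a small self-contained sketch (toy shapes, not from the repository) of what the `index_select` along the batch axis does at each generation step:

import torch

# Toy cache: one layer, key/value of shape [batch_size=3, num_heads=2, seq_len=4, head_dim=8].
layer_past = tuple(torch.randn(3, 2, 4, 8) for _ in range(2))
past_key_values = (layer_past,)

# beam_idx says which batch entry each beam should continue from after a beam-search step.
beam_idx = torch.tensor([2, 0, 0])

reordered = tuple(
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
    for layer_past in past_key_values
)

# Beam 0 now carries the cache of former batch entry 2; beams 1 and 2 carry entry 0.
assert torch.equal(reordered[0][0][0], past_key_values[0][0][2])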

# TODO: remove
# Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
def _reorder_cache_bloom(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called for bloom architecture.
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device)
for layer_past in past_key_values
for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)
return self._convert_to_bloom_cache(reordered_past)

# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
@staticmethod
def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
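
The shape conversion above is easier to see on concrete tensors. A hedged sketch of the round trip between the standard layout and the Bloom layout (toy sizes, illustrative only; the reverse direction corresponds to `_convert_to_standard_cache` below):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Standard layout: key is [batch, heads, head_dim, seq], value is [batch, heads, seq, head_dim].
key = torch.randn(batch_size, num_heads, head_dim, seq_length)
value = torch.randn(batch_size, num_heads, seq_length, head_dim)

# Bloom layout folds batch and heads into a single leading dimension.
bloom_key = key.view(batch_size * num_heads, head_dim, seq_length)
bloom_value = value.view(batch_size * num_heads, seq_length, head_dim)

# Converting back simply undoes the view.
assert torch.equal(bloom_key.view(batch_size, num_heads, head_dim, seq_length), key)
assert torch.equal(bloom_value.view(batch_size, num_heads, seq_length, head_dim), value)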

# Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
def _convert_to_standard_cache(
self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
"""
if self.config.model_type != "bloom":
return past_key_value

batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)

def can_generate(self):
