
Commit

without io binding refacto
IlyasMoutawwakil committed Jun 3, 2024
1 parent 7a0757a commit 949743e
Showing 3 changed files with 284 additions and 380 deletions.
72 changes: 28 additions & 44 deletions optimum/onnxruntime/modeling_decoder.py
@@ -46,7 +46,7 @@
 if check_if_transformers_greater("4.25.0"):
     from transformers.generation import GenerationMixin
 else:
-    from transformers.generation_utils import GenerationMixin
+    from transformers.generation_utils import GenerationMixin  # type: ignore # noqa: F401


logger = logging.getLogger(__name__)
@@ -139,15 +139,16 @@ def __init__(

         self.num_pkv = 2
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
-        self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)]
+        self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
         self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
         self.use_cache = len(self.key_value_input_names) > 0

         if generation_config is None:
             generation_config = GenerationConfig.from_model_config(config)

         self.generation_config = generation_config
         self.onnx_paths = [self.model_path]
-        self.use_merged = "use_cache_branch" in self.inputs_names
+        self.use_merged = "use_cache_branch" in self.input_names
         self.model_type = self.config.model_type

         self.use_fp16 = False
@@ -160,7 +161,7 @@

         # Reference: https://github.com/huggingface/optimum/pull/1381
         model_type = config.model_type.replace("_", "-")
-        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names:
+        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
            logger.warning(
                f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. "
                "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support."
@@ -202,7 +203,6 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)

-        inputs = {}
         known_output_shapes = {}
         use_cache_branch = None
         loss = None
@@ -226,10 +226,10 @@
             # I suspect the reason is the contiguous python list that messes something up?
             model_inputs = [input_ids.contiguous()]

-            if "attention_mask" in self.inputs_names:
+            if "attention_mask" in self.input_names:
                 model_inputs.append(attention_mask)

-            if "position_ids" in self.inputs_names:
+            if "position_ids" in self.input_names:
                 if position_ids is None:
                     raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
                 model_inputs.append(position_ids.contiguous())
@@ -240,12 +240,11 @@
             if use_cache_branch is not None:
                 model_inputs.append(use_cache_branch)

-            if "labels" in self.inputs_names:
+            if "labels" in self.input_names:
                 model_inputs.append(labels)
                 known_output_shapes.update({"loss": []})

-            io_binding, output_shapes, output_buffers = self._prepare_io_binding(
-                self.model,
+            io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                 *model_inputs,
                 known_output_shapes=known_output_shapes,
                 ordered_input_names=self._ordered_input_names,
@@ -259,53 +258,38 @@
                 io_binding.synchronize_outputs()

             if self.use_cache:
-                # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2)
-                past_key_values = ()
-                for name in self.key_value_output_names:
-                    past_key_values += (output_buffers[name].view(output_shapes[name]),)
+                # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention)
+                past_key_values = tuple(
+                    output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names
+                )

             logits = output_buffers["logits"].view(output_shapes["logits"])

             if "loss" in self.output_names:
                 loss = output_buffers["loss"].view(output_shapes["loss"])
         else:
-            inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids

-            if "attention_mask" in self.inputs_names:
-                inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask

-            if "labels" in self.inputs_names:
-                inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels

-            if "position_ids" in self.inputs_names:
-                if position_ids is None:
-                    raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
-                inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids

-            # Add the past_key_values to the decoder inputs
-            if past_key_values is not None:
-                for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
-                    inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value
+            model_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache_branch": use_cache_branch,
+                "labels": labels,
+            }

-            if use_cache_branch is not None:
-                inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch
+            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
+            onnx_outputs = self.model.run(None, onnx_inputs)
+            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

-            outputs = self.model.run(None, inputs)
+            logits = model_outputs.get("logits")
+            loss = model_outputs.get("loss", None)

             if self.use_cache:
-                # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
-                past_key_values = tuple(
-                    torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
-                    for key in self.key_value_output_names
-                )

-            logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
-            if "loss" in self.output_names:
-                loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)
+                past_key_values = tuple(model_outputs[self.output_names[key]] for key in self.key_value_output_names)

         if self.use_cache and self.model_type != "gpt_bigcode":
-            # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
-            # per decoder layer
+            # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer
             past_key_values = tuple(
                 past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
             )
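(Aside, not part of the diff: the refactor routes the non-IO-binding path through generic _prepare_onnx_inputs/_prepare_onnx_outputs helpers instead of hand-written numpy conversions, then regroups the flat key/value tuple per decoder layer. A minimal, self-contained sketch of that flow under the assumption of name-keyed outputs and torch tensors; the helper names below are illustrative, not Optimum's actual implementation:)

import torch

def prepare_onnx_inputs(use_torch: bool, inputs: dict) -> dict:
    # Convert torch tensors to numpy arrays (ONNX Runtime consumes numpy on CPU)
    # and drop inputs that were not provided.
    return {
        name: (value.cpu().detach().numpy() if use_torch else value)
        for name, value in inputs.items()
        if value is not None
    }

def prepare_onnx_outputs(use_torch: bool, names, arrays) -> dict:
    # Map positional session outputs back to their names, converting to torch if needed.
    return {
        name: (torch.from_numpy(array) if use_torch else array)
        for name, array in zip(names, arrays)
    }

def regroup_past_key_values(flat: tuple, num_pkv: int = 2) -> tuple:
    # Flat tuple of length n_layers * num_pkv -> tuple of n_layers tuples,
    # each holding the key and value tensors of one decoder layer.
    return tuple(flat[i : i + num_pkv] for i in range(0, len(flat), num_pkv))

# Toy example: two layers, key/value tensors of shape (batch, heads, seq, head_dim).
flat = tuple(torch.zeros(1, 4, 3, 8) for _ in range(4))
per_layer = regroup_past_key_values(flat)
assert len(per_layer) == 2 and len(per_layer[0]) == 2

The regrouping mirrors the last hunk above: the ONNX model emits one flat output per key/value tensor, while transformers' GenerationMixin expects them grouped per decoder layer.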
