Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Windows and onnx dtype compatibility #1886

Merged
merged 12 commits into from
Jun 24, 2024
124 changes: 30 additions & 94 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""

from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union
from typing import Dict, Optional, Set, Tuple, Union

import numpy as np
import torch
Expand All @@ -24,22 +24,22 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
from .modeling_ort import ORTModel


class ORTModelPart:
"""
For multi-file ONNX models, such as encoder-decoder models, represents a part of the model.
It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
"""

_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

def __init__(
self,
session: InferenceSession,
Expand All @@ -53,6 +53,8 @@ def __init__(
self.main_input_name = self.parent_model.main_input_name
self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

Expand Down Expand Up @@ -98,25 +100,13 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
else:
onnx_inputs = {"input_ids": input_ids}
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# Run inference
outputs = self.session.run(None, onnx_inputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

Expand Down Expand Up @@ -350,83 +340,29 @@ def forward(
else:
raise ValueError("Unsupported num_pkv")
else:
if use_torch:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy()

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels.cpu().detach().numpy()

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
else:
onnx_inputs = {
"input_ids": input_ids,
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))

# Run inference
outputs = self.session.run(None, onnx_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# TODO: using two loops here is probably unefficient
# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(
torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
for key in self.key_value_output_names
)

logits = outputs[self.output_names["logits"]]
if use_torch:
logits = torch.from_numpy(logits).to(self.device)
out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

loss = None
if "loss" in self.output_names:
loss = outputs[self.output_names["loss"]]
if use_torch:
loss = torch.from_numpy(loss).to(self.device)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]

# TODO: this is extremely ugly and unreadable. What if cross-attention k/v change?
# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
Expand Down
73 changes: 30 additions & 43 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
if check_if_transformers_greater("4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin
from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -139,15 +139,16 @@ def __init__(

self.num_pkv = 2
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)]
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
self.use_cache = len(self.key_value_input_names) > 0

if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)

self.generation_config = generation_config
self.onnx_paths = [self.model_path]
self.use_merged = "use_cache_branch" in self.inputs_names
self.use_merged = "use_cache_branch" in self.input_names
self.model_type = self.config.model_type

self.use_fp16 = False
Expand All @@ -160,7 +161,7 @@ def __init__(

# Reference: https://github.com/huggingface/optimum/pull/1381
model_type = config.model_type.replace("_", "-")
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names:
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
logger.warning(
f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. "
"We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support."
Expand Down Expand Up @@ -202,7 +203,6 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

inputs = {}
known_output_shapes = {}
use_cache_branch = None
loss = None
Expand All @@ -226,10 +226,10 @@ def forward(
# I suspect the reason is the contiguous python list that messes something up?
model_inputs = [input_ids.contiguous()]

if "attention_mask" in self.inputs_names:
if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)

if "position_ids" in self.inputs_names:
if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
model_inputs.append(position_ids.contiguous())
Expand All @@ -240,12 +240,11 @@ def forward(
if use_cache_branch is not None:
model_inputs.append(use_cache_branch)

if "labels" in self.inputs_names:
if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.model,
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
*model_inputs,
known_output_shapes=known_output_shapes,
ordered_input_names=self._ordered_input_names,
Expand All @@ -259,53 +258,41 @@ def forward(
io_binding.synchronize_outputs()

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2)
past_key_values = ()
for name in self.key_value_output_names:
past_key_values += (output_buffers[name].view(output_shapes[name]),)
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention)
past_key_values = tuple(
output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names
)

logits = output_buffers["logits"].view(output_shapes["logits"])

if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids

if "attention_mask" in self.inputs_names:
inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask

if "labels" in self.inputs_names:
inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels

if "position_ids" in self.inputs_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids

# Add the past_key_values to the decoder inputs
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
"labels": labels,
}
if past_key_values is not None:
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)

if use_cache_branch is not None:
inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

outputs = self.model.run(None, inputs)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
past_key_values = tuple(
torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
for key in self.key_value_output_names
)

logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)
past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

if self.use_cache and self.model_type != "gpt_bigcode":
# Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
# per decoder layer
# Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
Expand Down
Loading
Loading