
Commit

deprecate and create alternative
IlyasMoutawwakil committed Jul 3, 2024
1 parent f755a58 commit 7d649c0
Showing 2 changed files with 66 additions and 1 deletion.
5 changes: 5 additions & 0 deletions optimum/onnxruntime/model.py
@@ -49,6 +49,11 @@ def __init__(
            label_names (`List[str]`, `optional`):
                The list of keys in your dictionary of inputs that correspond to the labels.
        """

        logger.warning(
            "The class `optimum.onnxruntime.model.ORTModel` is deprecated and will be removed in the next release."
        )

        self.compute_metrics = compute_metrics
        self.label_names = ["labels"] if label_names is None else label_names
        self.session = InferenceSession(str(model_path), providers=[execution_provider])
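The warning fires at construction time, so any pipeline that still builds `optimum.onnxruntime.model.ORTModel` directly will see it once per instantiation; the standalone `evaluation_loop` helper added below in `optimum/onnxruntime/utils.py` appears to be the intended replacement. A minimal sketch of the legacy usage that now emits the warning, assuming the constructor keeps the `model_path`, `execution_provider`, `compute_metrics` and `label_names` parameters visible above and that the class's existing `evaluation_loop(dataset)` method is otherwise unchanged (the path and metric function are placeholders):

    from optimum.onnxruntime.model import ORTModel

    # Instantiating the legacy helper now logs the deprecation warning shown above.
    ort_model = ORTModel(
        model_path="path/to/model.onnx",            # placeholder path
        execution_provider="CPUExecutionProvider",
        compute_metrics=None,                       # or a Callable[[EvalPrediction], Dict]
        label_names=["labels"],
    )
    # The class stays usable until the announced removal, e.g. via its evaluation_loop
    # method (assumed unchanged by this commit):
    # outputs = ort_model.evaluation_loop(dataset)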
62 changes: 61 additions & 1 deletion optimum/onnxruntime/utils.py
@@ -17,11 +17,15 @@
import re
from enum import Enum
from inspect import signature
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from packaging import version
from tqdm import tqdm
from transformers import EvalPrediction
from transformers.trainer_pt_utils import nested_concat
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import logging

import onnxruntime as ort
@@ -30,6 +34,12 @@
from ..utils.import_utils import _is_package_available


if TYPE_CHECKING:
    from datasets import Dataset

    from .modeling_ort import ORTModel


logger = logging.get_logger(__name__)

ONNX_WEIGHTS_NAME = "model.onnx"
@@ -341,3 +351,53 @@ class ORTQuantizableOperator(Enum):
    Resize = "Resize"
    AveragePool = "AveragePool"
    Concat = "Concat"


def evaluation_loop(
    model: "ORTModel",
    dataset: "Dataset",
    label_names: Optional[List[str]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
):
    """
    Run the evaluation loop and return metrics and predictions.

    Args:
        model (`ORTModel`):
            The ONNX Runtime model to use for the evaluation step.
        dataset (`datasets.Dataset`):
            Dataset to use for the evaluation step.
        label_names (`List[str]`, `optional`):
            The list of keys in your dictionary of inputs that correspond to the labels.
        compute_metrics (`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take an `EvalPrediction` and
            return a dictionary mapping metric names to metric values.
    """

    all_preds = None
    all_labels = None

    for inputs in tqdm(dataset, desc="Evaluation"):
        has_labels = all(inputs.get(k) is not None for k in label_names)
        if has_labels:
            labels = tuple(np.array([inputs.get(name)]) for name in label_names)
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        inputs = {key: np.array([inputs[key]]) for key in model.input_names if key in inputs}
        preds = model(**inputs)

        if len(preds) == 1:
            preds = preds[0]

        all_preds = preds if all_preds is None else nested_concat(all_preds, preds, padding_index=-100)
        all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

    if compute_metrics is not None and all_preds is not None and all_labels is not None:
        metrics = compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
    else:
        metrics = {}

    return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=len(dataset))
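
The new helper works on the exported inference models from `optimum.onnxruntime.modeling_ort` rather than on a raw ONNX file path. A hedged usage sketch, assuming a sequence-classification checkpoint exported through `ORTModelForSequenceClassification` and a tokenized `datasets.Dataset` whose columns include the model's `input_names`; the model id, dataset and metric are illustrative, and `label_names` is passed explicitly since the function does not fall back to a default when it is left as `None`:

    import numpy as np
    from datasets import load_dataset
    from evaluate import load as load_metric
    from transformers import AutoTokenizer, EvalPrediction

    from optimum.onnxruntime import ORTModelForSequenceClassification
    from optimum.onnxruntime.utils import evaluation_loop

    model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
    model = ORTModelForSequenceClassification.from_pretrained(model_id, export=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Tokenize a small validation split; evaluation_loop feeds one example at a time,
    # keeping only the columns listed in model.input_names.
    dataset = load_dataset("glue", "sst2", split="validation[:100]")
    dataset = dataset.map(lambda example: tokenizer(example["sentence"], truncation=True))
    dataset = dataset.rename_column("label", "labels")

    accuracy = load_metric("accuracy")

    def compute_metrics(pred: EvalPrediction):
        return accuracy.compute(predictions=np.argmax(pred.predictions, axis=-1), references=pred.label_ids)

    output = evaluation_loop(model, dataset, label_names=["labels"], compute_metrics=compute_metrics)
    print(output.metrics)

The returned `EvalLoopOutput` carries the concatenated predictions, label ids, computed metrics, and the number of evaluated samples, mirroring what the deprecated class previously produced.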
