Skip to content

Commit

Permalink
Add OpenVINO Tokenizers (#513)
Browse files Browse the repository at this point in the history
* Convert tokenizers with openvino_tokenizers

* Update optimum/exporters/openvino/__main__.py

* Refactor and Add Tests

* Fix t5 Test

* Add Warning

* Return Tests

* Move export_tokenizer to convert.py

Reuse existing preprocessors

* Avoid Double Tokenizer Save

* Fix Style

* Refactor After Review

* Skip Tokenizers Tests If No Package Installed

Check logs from tokenizers test

* Style Fix

* Fix OV Tokenizers Check

* Fix Tests

* Add Missing return

* Turn off tokenizer message if not installed

* Move tokenizers to OV dependencies

* Check OV Compatibility

* Bump OV Version

* Move OpenVINO Tokenizers To Optional Dependencies

* Add --convert-tokenizer Option to CLI

* Fix SD Tokenizer

---------

Co-authored-by: Sergey Lyalin <[email protected]>
  • Loading branch information
apaniukov and slyalin authored Feb 8, 2024
1 parent 1c14957 commit 2be2e75
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_openvino.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install .[openvino,nncf,tests,diffusers]
pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
- name: Test with Pytest
run: |
pytest tests/openvino/ --ignore test_modeling_basic
6 changes: 6 additions & 0 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def parse_args_openvino(parser: "ArgumentParser"):
"OpenVINO native inference code that expects kv-cache inputs and outputs in the model."
),
)
optional_group.add_argument(
"--convert-tokenizer",
action="store_true",
help="Add converted tokenizer and detokenizer with OpenVINO Tokenizers",
)


class OVExportCommand(BaseOptimumCLICommand):
Expand Down Expand Up @@ -151,5 +156,6 @@ def run(self):
compression_option=self.args.weight_format,
compression_ratio=self.args.ratio,
stateful=not self.args.disable_stateful,
convert_tokenizer=self.args.convert_tokenizer,
# **input_shapes,
)
36 changes: 29 additions & 7 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@
from typing import Any, Callable, Dict, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoConfig, PreTrainedTokenizerBase

from optimum.exporters import TasksManager
from optimum.exporters.onnx import __main__ as optimum_main
from optimum.exporters.onnx.base import OnnxConfig, OnnxConfigWithPast
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors

from ...intel.utils.import_utils import is_nncf_available, is_optimum_version, is_transformers_version
from .convert import export_models
from ...intel.utils.import_utils import (
is_nncf_available,
is_openvino_tokenizers_available,
is_optimum_version,
is_transformers_version,
)
from .convert import export_models, export_tokenizer
from .stateful import ensure_export_task_support_stateful


Expand All @@ -41,7 +46,6 @@
]

OV_XML_FILE_NAME = "openvino_model.xml"

_MAX_UNCOMPRESSED_SIZE = 1e9

logger = logging.getLogger(__name__)
Expand All @@ -67,6 +71,7 @@ def main_export(
compression_option: Optional[str] = None,
compression_ratio: Optional[float] = None,
stateful: bool = True,
convert_tokenizer: bool = False,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -318,13 +323,17 @@ class StoreAttr(object):
and getattr(model.config, "pad_token_id", None) is None
and task in ["text-classification"]
)

tokenizer = next(
(preprocessor for preprocessor in preprocessors if isinstance(preprocessor, PreTrainedTokenizerBase)), None
)

if needs_pad_token_id:
if pad_token_id is not None:
model.config.pad_token_id = pad_token_id
else:
elif tokenizer is not None:
try:
tok = AutoTokenizer.from_pretrained(model_name_or_path)
model.config.pad_token_id = tok.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id
except Exception:
raise ValueError(
"Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
Expand All @@ -336,6 +345,15 @@ class StoreAttr(object):
generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)

if convert_tokenizer and tokenizer is not None and is_openvino_tokenizers_available():
try:
export_tokenizer(tokenizer, output)
except Exception as exception:
logger.warning(
"Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
f"models won't be generated. Exception: {exception}"
)

if model.config.is_encoder_decoder and task.startswith("text-generation"):
raise ValueError(
f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report"
Expand Down Expand Up @@ -365,10 +383,14 @@ class StoreAttr(object):
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
tokenizer.save_pretrained(output.joinpath("tokenizer"))
if convert_tokenizer and is_openvino_tokenizers_available():
export_tokenizer(tokenizer, output)

# Some pipelines expose a second tokenizer; it is saved in its own "tokenizer_2" sub-directory.
tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
    tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
    if convert_tokenizer and is_openvino_tokenizers_available():
        # Convert the second tokenizer itself (not the first one again); the "_2" suffix is
        # inserted into the exported file names so they do not overwrite the first tokenizer's files.
        export_tokenizer(tokenizer_2, output, suffix="_2")

model.save_config(output)

Expand Down
52 changes: 52 additions & 0 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from transformers import T5Tokenizer, T5TokenizerFast
from transformers.utils import is_tf_available, is_torch_available

from openvino.runtime import PartialShape, save_model
from openvino.runtime.exceptions import OVTypeError
from openvino.runtime.utils.types import get_element_type
from openvino.tools.ovc import convert_model
from optimum.exporters.onnx.base import OnnxConfig
Expand Down Expand Up @@ -536,3 +538,53 @@ def export_models(

outputs = list(map(list, zip(*outputs)))
return outputs


# Tokenizer classes that `export_tokenizer` does not attempt to convert: for instances of these
# classes it logs an info message and returns without writing any file.
UNSUPPORTED_TOKENIZER_CLASSES = (
T5Tokenizer,
T5TokenizerFast,
)


def export_tokenizer(
    tokenizer,
    output: Union[str, Path],
    suffix: Optional[str] = "",
):
    """Convert a tokenizer with OpenVINO Tokenizers and save the resulting model(s).

    The tokenizer is converted together with a detokenizer when possible; if the
    detokenizer conversion raises ``NotImplementedError`` only the tokenizer is
    converted. Conversion failures are logged and never propagated, so a failed
    tokenizer export does not abort the surrounding model export.

    Args:
        tokenizer: Tokenizer object handed to ``openvino_tokenizers.convert_tokenizer``.
            Instances of ``UNSUPPORTED_TOKENIZER_CLASSES`` are skipped.
        output: Directory in which the converted models are saved.
        suffix: Text inserted into the ``OV_TOKENIZER_NAME`` / ``OV_DETOKENIZER_NAME``
            file name templates (e.g. ``"_2"`` for a second tokenizer). ``None`` is
            treated as an empty suffix.

    Returns:
        None. Nothing is written when the tokenizer class is unsupported, when
        ``openvino_tokenizers`` is not installed, or when the conversion fails.
    """
    from optimum.intel.openvino import OV_DETOKENIZER_NAME, OV_TOKENIZER_NAME  # avoid circular imports

    if isinstance(tokenizer, UNSUPPORTED_TOKENIZER_CLASSES):
        logger.info(f"OpenVINO Tokenizer export for {type(tokenizer).__name__} is not supported.")
        return

    try:
        from openvino_tokenizers import convert_tokenizer
    except ModuleNotFoundError:
        # Deliberately silent: until tokenizers are part of the openvino dependencies, users should
        # not be nagged to run `pip install openvino-tokenizers[transformers]`.
        return

    output = Path(output)
    # A None suffix would otherwise be formatted into the file name as the text "None".
    suffix = "" if suffix is None else suffix

    try:
        try:
            converted = convert_tokenizer(tokenizer, with_detokenizer=True)
        except NotImplementedError:
            logger.warning("Detokenizer is not supported, convert tokenizer only.")
            # The fallback stays inside the outer `try` so that a failure of the tokenizer-only
            # conversion is logged below instead of escaping from the exception handler.
            converted = convert_tokenizer(tokenizer, with_detokenizer=False)
    except OVTypeError:
        logger.warning(f"OpenVINO Tokenizer export for {type(tokenizer).__name__} is not supported.")
        return
    except Exception as exception:
        logger.warning(
            f"OpenVINO Tokenizer export for {type(tokenizer).__name__} is not supported. Exception: {exception}"
        )
        return

    # A tokenizer-only conversion yields a single model rather than a (tokenizer, detokenizer) pair.
    if not isinstance(converted, tuple):
        converted = (converted,)

    # zip() stops at the shorter sequence, so only the tokenizer name is used for a single model.
    for model, file_name in zip(converted, (OV_TOKENIZER_NAME, OV_DETOKENIZER_NAME)):
        save_model(model, output / file_name.format(suffix))
9 changes: 8 additions & 1 deletion optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
import logging

from ..utils.import_utils import is_diffusers_available, is_nncf_available
from .utils import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME
from .utils import (
OV_DECODER_NAME,
OV_DECODER_WITH_PAST_NAME,
OV_DETOKENIZER_NAME,
OV_ENCODER_NAME,
OV_TOKENIZER_NAME,
OV_XML_FILE_NAME,
)


if is_nncf_available():
Expand Down
3 changes: 3 additions & 0 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
OV_DECODER_NAME = "openvino_decoder_model.xml"
OV_DECODER_WITH_PAST_NAME = "openvino_decoder_with_past_model.xml"

OV_TOKENIZER_NAME = "openvino_tokenizer{}.xml"
OV_DETOKENIZER_NAME = "openvino_detokenizer{}.xml"

ONNX_WEIGHTS_NAME = "model.onnx"
ONNX_ENCODER_NAME = "encoder_model.onnx"
ONNX_DECODER_NAME = "decoder_model.onnx"
Expand Down
37 changes: 34 additions & 3 deletions optimum/intel/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import importlib.util
import logging
import operator as op
import sys
from collections import OrderedDict
Expand All @@ -27,6 +28,8 @@
import importlib.metadata as importlib_metadata


logger = logging.getLogger(__name__)

STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

_optimum_version = importlib_metadata.version("optimum")
Expand Down Expand Up @@ -75,13 +78,38 @@
version = get_version()
# avoid invalid format
if "-" in version:
major_version, dev_info = version.split("-", 1)
ov_major_version, dev_info = version.split("-", 1)
commit_id = dev_info.split("-")[0]
version = f"{major_version}-{commit_id}"
version = f"{ov_major_version}-{commit_id}"
_openvino_version = version
except ImportError:
_openvino_available = False

# OpenVINO Tokenizers is only usable when the package can be found AND OpenVINO itself is available.
_openvino_tokenizers_available = importlib.util.find_spec("openvino_tokenizers") is not None and _openvino_available
_openvino_tokenizers_version = "N/A"
if _openvino_tokenizers_available:
    try:
        _openvino_tokenizers_version = importlib_metadata.version("openvino_tokenizers")
    except importlib_metadata.PackageNotFoundError:
        _openvino_tokenizers_available = False

if _openvino_tokenizers_available and _openvino_tokenizers_version != "N/A":
    # Derive the installed OpenVINO release from `_openvino_version` rather than from a name that
    # is only bound when the raw version string contains a "-" separator.
    _installed_openvino_version = _openvino_version.split("-", 1)[0]
    _openvino_requirement_prefix = "openvino=="
    # `requires()` returns None when the distribution declares no requirements.
    _compatible_openvino_version = next(
        (
            # Keep only the pinned version: drop the prefix and any environment marker after ";".
            requirement[len(_openvino_requirement_prefix) :].split(";", 1)[0].strip()
            for requirement in importlib_metadata.requires("openvino-tokenizers") or []
            if requirement.startswith(_openvino_requirement_prefix)
        ),
        "",
    )
    _openvino_tokenizers_available = _compatible_openvino_version == _installed_openvino_version
    if not _openvino_tokenizers_available:
        logger.warning(
            "OpenVINO Tokenizer version is not compatible with OpenVINO version. "
            f"Installed OpenVINO version: {_installed_openvino_version}, "
            f"OpenVINO Tokenizers requires {_compatible_openvino_version}. "
            "OpenVINO Tokenizers models will not be added during export."
        )

_nncf_available = importlib.util.find_spec("nncf") is not None
_nncf_version = "N/A"
Expand All @@ -91,7 +119,6 @@
except importlib_metadata.PackageNotFoundError:
_nncf_available = False


_diffusers_available = importlib.util.find_spec("diffusers") is not None
_diffusers_version = "N/A"
if _diffusers_available:
Expand Down Expand Up @@ -135,6 +162,10 @@ def is_openvino_available():
return _openvino_available


def is_openvino_tokenizers_available():
    """Return the module-level flag telling whether OpenVINO Tokenizers can be used.

    The flag is computed once at import time: the package must be found, OpenVINO
    must be available, and the OpenVINO version pinned by the tokenizers package
    must match the installed one.
    """
    return _openvino_tokenizers_available


def is_nncf_available():
    """Return the module-level flag recording whether ``nncf`` was detected at import time."""
    return _nncf_available

Expand Down
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,14 @@
"onnxruntime<1.15.0",
"transformers>=4.34.0",
],
"openvino": ["openvino>=2023.2", "onnx", "onnxruntime", "transformers>=4.36.0", "optimum>=1.16.1"],
"openvino": [
"openvino>=2023.3",
"onnx",
"onnxruntime",
"transformers>=4.36.0",
"optimum>=1.16.1",
],
"openvino-tokenizers": ["openvino-tokenizers[transformers]"],
"nncf": ["nncf @ git+https://github.com/openvinotoolkit/nncf.git"],
"ipex": ["intel-extension-for-pytorch", "onnx"],
"diffusers": ["diffusers"],
Expand Down
41 changes: 41 additions & 0 deletions tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import subprocess
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from parameterized import parameterized
Expand All @@ -38,6 +39,7 @@
OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS
from optimum.intel.utils.import_utils import is_openvino_tokenizers_available


class OVCLIExportTestCase(unittest.TestCase):
Expand All @@ -61,6 +63,19 @@ class OVCLIExportTestCase(unittest.TestCase):
("stable-diffusion-xl", "stable-diffusion-xl"),
("stable-diffusion-xl", "stable-diffusion-xl-refiner"),
)
# Per model type, the number of exported ".xml" files whose path contains "tokenizer" that
# `test_exporters_cli_tokenizers` expects after an export with --convert-tokenizer:
# 2 = tokenizer and detokenizer, 1 = tokenizer only, 0 = nothing converted.
EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
"gpt2": 2,
"t5": 0, # failed internal sentencepiece check - no <s> token in the vocab
"albert": 0, # not supported yet
"distilbert": 1, # no detokenizer
"roberta": 2,
"vit": 0, # no tokenizer for image model
"wav2vec2": 0, # no tokenizer
"bert": 1, # no detokenizer
"blenderbot": 2,
"stable-diffusion": 0, # not supported
"stable-diffusion-xl": 0, # not supported
}

SUPPORTED_4BIT_ARCHITECTURES = (("text-generation-with-past", "opt125m"),)

Expand Down Expand Up @@ -98,6 +113,32 @@ def test_exporters_cli(self, task: str, model_type: str):
model_kwargs = {"use_cache": task.endswith("with-past")} if "generation" in task else {}
eval(_HEAD_TO_AUTOMODELS[task.replace("-with-past", "")]).from_pretrained(tmpdir, **model_kwargs)

@parameterized.expand(
    arch
    for arch in SUPPORTED_ARCHITECTURES
    if not (arch[0].endswith("-with-past") or arch[1].endswith("-refiner"))
)
@unittest.skipIf(not is_openvino_tokenizers_available(), reason="OpenVINO Tokenizers not available")
def test_exporters_cli_tokenizers(self, task: str, model_type: str):
    """Export a model through the CLI with ``--convert-tokenizer`` and check the tokenizer files.

    The number of exported ``.xml`` files whose path contains "tokenizer" must match
    ``EXPECTED_NUMBER_OF_TOKENIZER_MODELS``; when fewer than two were produced, the
    CLI output must contain the corresponding warning.
    """
    with TemporaryDirectory() as tmpdir:
        # stderr is merged into stdout so that logged warnings can be inspected below.
        cli_log = subprocess.check_output(
            f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} --convert-tokenizer --task {task} {tmpdir}",
            shell=True,
            stderr=subprocess.STDOUT,
        ).decode()

        xml_paths = [str(xml_path) for xml_path in Path(tmpdir).rglob("*.xml")]
        tokenizer_count = sum(1 for xml_path in xml_paths if "tokenizer" in xml_path)
        self.assertEqual(
            self.EXPECTED_NUMBER_OF_TOKENIZER_MODELS[model_type],
            tokenizer_count,
            f"OVT: {is_openvino_tokenizers_available() }",
        )

        if tokenizer_count == 1:
            # Only the tokenizer was exported: the detokenizer fallback warning must be present.
            self.assertTrue("Detokenizer is not supported, convert tokenizer only." in cli_log, cli_log)
        elif tokenizer_count == 0 and task not in ("image-classification", "audio-classification"):
            # Text models without any exported tokenizer must have logged the "not supported" message.
            unsupported_logged = "OpenVINO Tokenizer export for" in cli_log and "is not supported." in cli_log
            self.assertTrue(unsupported_logged, cli_log)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_exporters_cli_fp16(self, task: str, model_type: str):
with TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit 2be2e75

Please sign in to comment.