From 9b3fb47c72ac8fae9c5549e68026155a7aaab762 Mon Sep 17 00:00:00 2001 From: Sofya Balandina Date: Wed, 30 Aug 2023 15:25:55 +0100 Subject: [PATCH] Move model conversion from MO to openVINO API (#1270) * Move model conversion from MO to openVINO API notebooks\211-speech-to-text\211-speech-to-text.ipynb notebooks\215-image-inpainting\215-image-inpainting.ipynb notebooks\216-attention-center\216-attention-center.ipynb notebooks\217-vision-deblur\217-vision-deblur.ipynb notebooks\223-text-prediction\223-text-prediction.ipynb notebooks\226-yolov7-optimization\226-yolov7-optimization.ipynb notebooks\228-clip-zero-shot-image-classification\228-clip-zero-shot-convert.ipynb notebooks\229-distilbert-sequence-classification\229-distilbert-sequence-classification.ipynb notebooks\231-instruct-pix2pix-image-editing\231-instruct-pix2pix-image-editing.ipynb * fix comments --- .../211-speech-to-text.ipynb | 28 ++++++------ .../215-image-inpainting.ipynb | 11 +++-- .../216-attention-center.ipynb | 33 +++++++++++--- .../217-vision-deblur/217-vision-deblur.ipynb | 20 ++++----- .../223-text-prediction.ipynb | 21 ++++----- .../226-yolov7-optimization.ipynb | 20 ++++----- .../228-clip-zero-shot-convert.ipynb | 16 +++---- ...9-distilbert-sequence-classification.ipynb | 44 ++++++++++++------- .../231-instruct-pix2pix-image-editing.ipynb | 38 ++++++++-------- 9 files changed, 123 insertions(+), 108 deletions(-) diff --git a/notebooks/211-speech-to-text/211-speech-to-text.ipynb b/notebooks/211-speech-to-text/211-speech-to-text.ipynb index cfbc8688e4b..9bf65701843 100644 --- a/notebooks/211-speech-to-text/211-speech-to-text.ipynb +++ b/notebooks/211-speech-to-text/211-speech-to-text.ipynb @@ -49,7 +49,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q \"librosa>=0.8.1\"" + "!pip install -q \"librosa>=0.8.1\" \"openvino-dev==2023.1.0.dev20230811\" \"onnx\"" ] }, { @@ -76,8 +76,7 @@ "import librosa.display\n", "import numpy as np\n", "import scipy\n", - "from openvino.runtime import Core, serialize, Tensor\n", - "from openvino.tools import mo" + "import openvino as ov" ] }, { @@ -255,9 +254,9 @@ " dynamic_axes={\"audio_signal\": {0: \"batch_size\", 2: \"wave_len\"}, \"output\": {0: \"batch_size\", 2: \"wave_len\"}}\n", " )\n", " # convert model to OpenVINO Model using model conversion API\n", - " ov_model = mo.convert_model(str(onnx_model_path))\n", - " # serialize model to IR for next usage\n", - " serialize(ov_model, str(converted_model_path))" + " ov_model = ov.convert_model(str(onnx_model_path))\n", + " # save model in IR format for next usage\n", + " ov.save_model(ov_model, str(converted_model_path))" ] }, { @@ -636,17 +635,18 @@ }, "outputs": [], "source": [ - "ie = Core()" + "core = ov.Core()" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "ce3fc33e", "metadata": {}, "source": [ "You may run the model on multiple devices. By default, it will load the model on CPU (you can choose manually CPU, GPU etc.) or let the engine choose the best available device (AUTO).\n", "\n", - "To list all available devices that can be used, run `print(ie.available_devices)` command." + "To list all available devices that can be used, run `print(core.available_devices)` command." 
] }, { @@ -666,7 +666,7 @@ } ], "source": [ - "print(ie.available_devices)" + "print(core.available_devices)" ] }, { @@ -686,8 +686,6 @@ "source": [ "import ipywidgets as widgets\n", "\n", - "core = Core()\n", - "\n", "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", " value='AUTO',\n", @@ -707,14 +705,14 @@ }, "outputs": [], "source": [ - "model = ie.read_model(\n", + "model = core.read_model(\n", " model=f\"{model_folder}/public/{model_name}/{precision}/{model_name}.xml\"\n", ")\n", "model_input_layer = model.input(0)\n", "shape = model_input_layer.partial_shape\n", "shape[2] = -1\n", "model.reshape({model_input_layer: shape})\n", - "compiled_model = ie.compile_model(model=model, device_name=device.value)" + "compiled_model = core.compile_model(model=model, device_name=device.value)" ] }, { @@ -738,7 +736,7 @@ "source": [ "output_layer_ir = compiled_model.output(0)\n", "\n", - "character_probabilities = compiled_model([Tensor(audio)])[output_layer_ir]" + "character_probabilities = compiled_model([ov.Tensor(audio)])[output_layer_ir]" ] }, { @@ -854,4 +852,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/notebooks/215-image-inpainting/215-image-inpainting.ipynb b/notebooks/215-image-inpainting/215-image-inpainting.ipynb index 39eed7059db..13e368b6a9b 100644 --- a/notebooks/215-image-inpainting/215-image-inpainting.ipynb +++ b/notebooks/215-image-inpainting/215-image-inpainting.ipynb @@ -37,8 +37,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from zipfile import ZipFile\n", - "from openvino.tools import mo\n", - "from openvino.runtime import Core, Tensor, serialize\n", + "import openvino as ov\n", "\n", "sys.path.append(\"../utils\")\n", "import notebook_utils as utils" @@ -114,8 +113,8 @@ "\n", "# Run model conversion API to convert model to OpenVINO IR FP32 format, if the IR file does not exist.\n", "if not ir_path.exists():\n", - " ov_model = mo.convert_model(model_path, input_shape=[[1,512,680,3],[1,512,680,1]])\n", - " serialize(ov_model, str(ir_path))\n", + " ov_model = ov.convert_model(model_path, input=[[1,512,680,3],[1,512,680,1]])\n", + " ov.save_model(ov_model, str(ir_path))\n", "else:\n", " print(f\"{ir_path} already exists.\")" ] @@ -144,7 +143,7 @@ "metadata": {}, "outputs": [], "source": [ - "core = Core()\n", + "core = ov.Core()\n", "\n", "# Read the model.xml and weights file\n", "model = core.read_model(model=ir_path)" @@ -415,7 +414,7 @@ } ], "source": [ - "result = compiled_model([Tensor(masked_image.astype(np.float32)), Tensor(mask.astype(np.float32))])[output_layer]\n", + "result = compiled_model([ov.Tensor(masked_image.astype(np.float32)), ov.Tensor(mask.astype(np.float32))])[output_layer]\n", "result = result.squeeze().astype(np.uint8)\n", "plt.figure(figsize=(16, 12))\n", "plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB));" diff --git a/notebooks/216-attention-center/216-attention-center.ipynb b/notebooks/216-attention-center/216-attention-center.ipynb index a667a66f819..1c529d330a1 100644 --- a/notebooks/216-attention-center/216-attention-center.ipynb +++ b/notebooks/216-attention-center/216-attention-center.ipynb @@ -28,6 +28,15 @@ "- [Get result with OpenVINO IR model](#Get-result-with-OpenVINO-IR-model-Uparrow)\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "! 
pip install \"openvino==2023.1.0.dev20230811\"" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -59,8 +68,7 @@ "from pathlib import Path\n", "import matplotlib.pyplot as plt\n", "\n", - "from openvino.tools import mo\n", - "from openvino.runtime import serialize, Core" + "import openvino as ov" ] }, { @@ -85,13 +93,14 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Convert Tensorflow Lite model to OpenVINO IR format [$\\Uparrow$](#Table-of-content:)\n", "\n", "The attention-center model is pre-trained model in TensorFlow Lite format. In this Notebook the model will be converted to \n", - "OpenVINO IR format with Model Optimizer. This step will be skipped if the model have already been converted. For more information about Model Optimizer, please, see the [Model Optimizer Developer Guide]( https://docs.openvino.ai/2023.0/openvino_docs_MO_DG_Deep_Learning_Model_Optimizer_DevGuide.html). \n", + "OpenVINO IR format with model conversion API. For more information about model conversion, see this [page](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html). This step is also skipped if the model is already converted.\n", "\n", "Also TFLite models format is supported in OpenVINO by TFLite frontend, so the model can be passed directly to `core.read_model()`. You can find example in [002-openvino-api](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/002-openvino-api)." ] @@ -114,11 +123,11 @@ "\n", "ir_model_path = Path(\"./model/ir_center_model.xml\")\n", "\n", - "core = Core()\n", + "core = ov.Core()\n", "\n", "if not ir_model_path.exists():\n", - " model = mo.convert_model(tflite_model_path)\n", - " serialize(model, ir_model_path.as_posix())\n", + " model = ov.convert_model(tflite_model_path, input=[('image:0', [1,480,640,3], ov.Type.f32)])\n", + " ov.save_model(model, ir_model_path)\n", " print(\"IR model saved to {}\".format(ir_model_path))\n", "else:\n", " print(\"Read IR model from {}\".format(ir_model_path))\n", @@ -126,6 +135,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -307,8 +317,17 @@ "source": [ "import io\n", "import PIL\n", + "from urllib.request import urlretrieve\n", + "\n", + "img_path = Path(\"data/coco.jpg\")\n", + "img_path.parent.mkdir(parents=True, exist_ok=True)\n", + "urlretrieve(\n", + " \"https://storage.openvinotoolkit.org/repositories/openvino_notebooks/data/data/image/coco.jpg\",\n", + " img_path,\n", + ")\n", + "\n", "# read uploaded image\n", - "image = PIL.Image.open(io.BytesIO(load_file_widget.value[-1]['content'])) if load_file_widget.value else PIL.Image.open(\"../data/image/coco.jpg\")\n", + "image = PIL.Image.open(io.BytesIO(list(load_file_widget.value.values())[-1]['content'])) if load_file_widget.value else PIL.Image.open(img_path)\n", "image.convert(\"RGB\")\n", "\n", "input_image = Image((480, 640), image=(np.ascontiguousarray(image)[:, :, ::-1]).astype(np.uint8))\n", diff --git a/notebooks/217-vision-deblur/217-vision-deblur.ipynb b/notebooks/217-vision-deblur/217-vision-deblur.ipynb index e97f5ffdf84..107e71cbafa 100644 --- a/notebooks/217-vision-deblur/217-vision-deblur.ipynb +++ b/notebooks/217-vision-deblur/217-vision-deblur.ipynb @@ -79,7 +79,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from IPython.display import Markdown, display\n", - "from openvino.runtime import Core\n", + "import openvino as ov\n", "\n", "sys.path.append(\"../utils\")\n", "from notebook_utils import load_image" @@ 
-149,7 +149,7 @@ "source": [ "import ipywidgets as widgets\n", "\n", - "core = Core()\n", + "core = ov.Core()\n", "\n", "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", @@ -286,7 +286,7 @@ "source": [ "### Convert DeblurGAN-v2 Model to OpenVINO IR format [$\Uparrow$](#Table-of-content:)\n", "\n", - "For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. To convert the PyTorch model, we will use model conversion Python API. The `mo.convert_model` Python function returns an OpenVINO model ready to load on a device and start making predictions. We can save it on a disk for next usage with `openvino.runtime.serialize`. For more information about model conversion Python API, see this [page](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html).\n", + "For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. To convert the PyTorch model, we will use model conversion Python API. The `ov.convert_model` Python function returns an OpenVINO model ready to load on a device and start making predictions. We can save the model on disk for later use with `ov.save_model`. For more information about model conversion Python API, see this [page](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html).\n", "\n", "Model conversion may take a while." ] @@ -298,15 +298,12 @@ "metadata": {}, "outputs": [], "source": [ - "from openvino.tools import mo\n", - "from openvino.runtime import serialize\n", - "\n", "deblur_gan_model = DeblurV2(\"model/public/deblurgan-v2/ckpt/fpn_mobilenet.h5\", \"fpn_mobilenet\")\n", "\n", "with torch.no_grad():\n", " deblur_gan_model.eval()\n", - " ov_model = mo.convert_model(deblur_gan_model, input_shape=[[1,3,736,1312]], compress_to_fp16=(precision == \"FP16\"))\n", - " serialize(ov_model, model_xml_path)" + " ov_model = ov.convert_model(deblur_gan_model, example_input=torch.ones((1,3,736,1312), dtype=torch.float32), input=[[1,3,736,1312]])\n", + " ov.save_model(ov_model, model_xml_path, compress_to_fp16=(precision == \"FP16\"))" ] }, { @@ -317,7 +314,7 @@ "source": [ "## Load the Model [$\Uparrow$](#Table-of-content:)\n", "\n", - "Load and compile the DeblurGAN-v2 model in the OpenVINO Runtime with `ie.read_model` and compile it for the specified device with `ie.compile_model`. Get input and output keys and the expected input shape for the model." + "Load the DeblurGAN-v2 model in the OpenVINO Runtime with `core.read_model` and compile it for the specified device with `core.compile_model`. Get input and output keys and the expected input shape for the model."
] }, { @@ -327,9 +324,8 @@ "metadata": {}, "outputs": [], "source": [ - "ie = Core()\n", - "model = ie.read_model(model=model_xml_path)\n", - "compiled_model = ie.compile_model(model=model, device_name=device.value)" + "model = core.read_model(model=model_xml_path)\n", + "compiled_model = core.compile_model(model=model, device_name=device.value)" ] }, { diff --git a/notebooks/223-text-prediction/223-text-prediction.ipynb b/notebooks/223-text-prediction/223-text-prediction.ipynb index e560426b0e9..d7b8bbb160e 100644 --- a/notebooks/223-text-prediction/223-text-prediction.ipynb +++ b/notebooks/223-text-prediction/223-text-prediction.ipynb @@ -102,7 +102,7 @@ ], "source": [ "# Install Gradio for Interactive Inference and other requirements\n", - "!pip install -q \"openvino-dev>=2023.0.0\"\n", + "!pip install -q \"openvino==2023.1.0.dev20230811\"\n", "!pip install -q gradio\n", "!pip install -q transformers[torch] onnx" ] }, { @@ -171,6 +171,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -181,7 +182,7 @@ "For starting work with GPT-Neo model using OpenVINO, a model should be converted to OpenVINO Intermediate Representation (IR) format. HuggingFace provides a GPT-Neo model in PyTorch format, which is supported in OpenVINO via conversion to ONNX. We use the HuggingFace transformers library's onnx module to export the model to ONNX. `transformers.onnx.export` accepts the preprocessing function for input sample generation (the tokenizer in our case), an instance of the model, ONNX export configuration, the ONNX opset version for export and output path. More information about transformers export to ONNX can be found in HuggingFace [documentation](https://huggingface.co/docs/transformers/serialization).\n", "\n", "While ONNX models are directly supported by OpenVINO runtime, it can be useful to convert them to IR format to take advantage of OpenVINO optimization tools and features.\n", - "The `mo.convert_model` Python function of [model conversion API](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html) can be used for converting the model. The function returns instance of OpenVINO Model class, which is ready to use in Python interface but can also be serialized to OpenVINO IR format for future execution using `openvino.runtime.serialize`. In our case, the `compress_to_fp16` parameter is enabled for compression model weights to FP16 precision and also specified dynamic input shapes with a possible shape range (from 1 token to a maximum length defined in our processing function) for optimization of memory consumption." + "The `ov.convert_model` Python function of [model conversion API](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html) can be used for converting the model. The function returns an instance of the OpenVINO Model class, which is ready to use in the Python interface. The model can also be saved on disk in OpenVINO IR format for future execution using `ov.save_model`. In our case, dynamic input shapes with a possible shape range (from 1 token to a maximum length defined in our processing function) are specified to optimize memory consumption."
] }, { @@ -200,10 +201,9 @@ ], "source": [ "from pathlib import Path\n", - "from openvino.runtime import serialize\n", - "from openvino.tools import mo\n", "from transformers.onnx import export, FeaturesManager\n", "\n", + "import openvino as ov\n", "\n", "# define path for saving onnx model\n", "onnx_path = Path(\"model/text_generator.onnx\")\n", @@ -223,12 +223,12 @@ "\n", "# convert model to openvino\n", "if model_name.value == \"PersonaGPT (Converastional)\":\n", - " ov_model = mo.convert_model(onnx_path, compress_to_fp16=True, input=\"input_ids[1,-1],attention_mask[1,-1]\")\n", + " ov_model = ov.convert_model(onnx_path, input=[('input_ids', [1, -1], ov.Type.i64), ('attention_mask', [1,-1], ov.Type.i64)])\n", "else:\n", - " ov_model = mo.convert_model(onnx_path, compress_to_fp16=True, input=\"input_ids[1,1..128],attention_mask[1,1..128]\")\n", + " ov_model = ov.convert_model(onnx_path, input=[('input_ids', [1, ov.Dimension(1,128)], ov.Type.i64), ('attention_mask', [1, ov.Dimension(1,128)], ov.Type.i64)])\n", "\n", "# serialize openvino model\n", - "serialize(ov_model, str(model_path))" + "ov.save_model(ov_model, str(model_path))" ] }, { @@ -271,10 +271,10 @@ } ], "source": [ - "from openvino.runtime import Core\n", "import ipywidgets as widgets\n", "\n", - "core = Core()\n", + "# initialize openvino core\n", + "core = ov.Core()\n", "\n", "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", @@ -292,9 +292,6 @@ "metadata": {}, "outputs": [], "source": [ - "# initialize openvino core\n", - "core = Core()\n", - "\n", "# read the model and corresponding weights from file\n", "model = core.read_model(model_path)" ] diff --git a/notebooks/226-yolov7-optimization/226-yolov7-optimization.ipynb b/notebooks/226-yolov7-optimization/226-yolov7-optimization.ipynb index 95b9bf5f142..7eb25292d94 100644 --- a/notebooks/226-yolov7-optimization/226-yolov7-optimization.ipynb +++ b/notebooks/226-yolov7-optimization/226-yolov7-optimization.ipynb @@ -342,13 +342,14 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Convert ONNX Model to OpenVINO Intermediate Representation (IR) [$\Uparrow$](#Table-of-content:)\n", - "While ONNX models are directly supported by OpenVINO runtime, it can be useful to convert them to IR format to take the advantage of OpenVINO optimization tools and features.\n", - "The `mo.convert_model` python function in OpenVINO Model Optimizer can be used for converting the model.\n", - "The function returns instance of OpenVINO Model class, which is ready to use in Python interface. However, it can also be serialized to OpenVINO IR format for future execution." + "While ONNX models are directly supported by OpenVINO runtime, it can be useful to convert them to IR format to take advantage of advanced OpenVINO optimization features.\n", + "The `ov.convert_model` Python function of [model conversion API](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html) can be used for converting the model.\n", + "The function returns an instance of the OpenVINO Model class, which is ready to use in the Python interface. However, it can also be saved on disk in OpenVINO IR format using `ov.save_model` for future execution."
] }, { @@ -357,12 +358,11 @@ "metadata": {}, "outputs": [], "source": [ - "from openvino.tools import mo\n", - "from openvino.runtime import serialize\n", + "import openvino as ov\n", "\n", - "model = mo.convert_model('model/yolov7-tiny.onnx')\n", + "model = ov.convert_model('model/yolov7-tiny.onnx')\n", "# serialize model for saving IR\n", - "serialize(model, 'model/yolov7-tiny.xml')" + "ov.save_model(model, 'model/yolov7-tiny.xml')" ] }, { @@ -498,10 +498,9 @@ "source": [ "from typing import List, Tuple, Dict\n", "from utils.general import scale_coords, non_max_suppression\n", - "from openvino.runtime import Model\n", "\n", "\n", - "def detect(model: Model, image_path: Path, conf_thres: float = 0.25, iou_thres: float = 0.45, classes: List[int] = None, agnostic_nms: bool = False):\n", + "def detect(model: ov.Model, image_path: Path, conf_thres: float = 0.25, iou_thres: float = 0.45, classes: List[int] = None, agnostic_nms: bool = False):\n", " \"\"\"\n", " OpenVINO YOLOv7 model inference function. Reads image, preprocess it, runs model inference and postprocess results using NMS.\n", " Parameters:\n", @@ -554,8 +553,7 @@ "metadata": {}, "outputs": [], "source": [ - "from openvino.runtime import Core\n", - "core = Core()\n", + "core = ov.Core()\n", "# read converted model\n", "model = core.read_model('model/yolov7-tiny.xml')" ] diff --git a/notebooks/228-clip-zero-shot-image-classification/228-clip-zero-shot-convert.ipynb b/notebooks/228-clip-zero-shot-image-classification/228-clip-zero-shot-convert.ipynb index 86ae3ef1baa..c7313f6c966 100644 --- a/notebooks/228-clip-zero-shot-image-classification/228-clip-zero-shot-convert.ipynb +++ b/notebooks/228-clip-zero-shot-image-classification/228-clip-zero-shot-convert.ipynb @@ -251,6 +251,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -258,7 +259,8 @@ "\n", "![conversion_path](https://user-images.githubusercontent.com/29454499/208048580-8264e54c-151c-43ef-9e25-1302cd0dd7a2.png)\n", "\n", - "For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. OpenVINO supports PyTorch via ONNX conversion. The `torch.onnx.export` function enables conversion of PyTorch models to ONNX format. It requires to provide initialized model object, example of inputs for tracing and path for saving result. The model contains operations which supported for ONNX tracing starting with opset 14, it is recommended to use it as `opset_version` parameter. Besides that, we need to have opportunity to provide descriptions various of length and images with different sizes, for preserving this capability after ONNX conversion, `dynamic_axes` parameter can be used. More information about PyTorch to ONNX exporting can be found in this [tutorial](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html) and [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html). We will use `mo.convert_model` functionality to convert the ONNX model. The `mo.convert_model` Python function returns an OpenVINO model ready to load on the device and start making predictions. We can save it on disk for the next usage with `openvino.runtime.serialize`.\n" + "For best results with OpenVINO, it is recommended to convert the model to OpenVINO IR format. OpenVINO supports PyTorch via ONNX conversion. The `torch.onnx.export` function enables conversion of PyTorch models to ONNX format. It requires to provide initialized model object, example of inputs for tracing and path for saving result. 
The model contains operations which supported for ONNX tracing starting with opset 14, it is recommended to use it as `opset_version` parameter. Besides that, we need to have opportunity to provide descriptions various of length and images with different sizes, for preserving this capability after ONNX conversion, `dynamic_axes` parameter can be used. More information about PyTorch to ONNX exporting can be found in this [tutorial](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html) and [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html). \n", + "To convert the ONNX model to OpenVINO IR format we will use `ov.convert_model` of [model conversion API](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html). The `ov.convert_model` Python function returns an OpenVINO Model object ready to load on the device and start making predictions. We can save it on disk for the next usage with `ov.save_model`.\n" ] }, { @@ -324,11 +326,10 @@ }, "outputs": [], "source": [ - "from openvino.runtime import serialize\n", - "from openvino.tools import mo\n", + "import openvino as ov\n", "\n", - "ov_model = mo.convert_model('clip-vit-base-patch16.onnx', compress_to_fp16=True)\n", - "serialize(ov_model, 'clip-vit-base-patch16.xml')" + "ov_model = ov.convert_model('clip-vit-base-patch16.onnx')\n", + "ov.save_model(ov_model, 'clip-vit-base-patch16.xml')" ] }, { @@ -352,10 +353,9 @@ "outputs": [], "source": [ "from scipy.special import softmax\n", - "from openvino.runtime import Core\n", "\n", "# create OpenVINO core object instance\n", - "core = Core()" + "core = ov.Core()" ] }, { @@ -656,4 +656,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +} diff --git a/notebooks/229-distilbert-sequence-classification/229-distilbert-sequence-classification.ipynb b/notebooks/229-distilbert-sequence-classification/229-distilbert-sequence-classification.ipynb index 8a25aa1709b..d89b55e98ed 100644 --- a/notebooks/229-distilbert-sequence-classification/229-distilbert-sequence-classification.ipynb +++ b/notebooks/229-distilbert-sequence-classification/229-distilbert-sequence-classification.ipynb @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "id": "fe80a355", "metadata": { "tags": [] @@ -42,8 +42,7 @@ "import time\n", "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", "import numpy as np\n", - "from openvino.tools import mo\n", - "from openvino.runtime import PartialShape, Type, serialize, Core" + "import openvino as ov" ] }, { @@ -57,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "id": "5db803ea", "metadata": { "tags": [] @@ -82,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "id": "782bbebf", "metadata": { "tags": [] @@ -106,18 +105,29 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "id": "4794f066", "metadata": { "tags": [] }, "outputs": [], "source": [ + "import torch\n", + "\n", "ir_xml_name = checkpoint + \".xml\"\n", "MODEL_DIR = \"model/\"\n", "ir_xml_path = Path(MODEL_DIR) / ir_xml_name\n", - "ov_model = mo.convert_model(model, input=[mo.InputCutInfo(shape=PartialShape([1, -1]), type=Type.i64), mo.InputCutInfo(shape=PartialShape([1, -1]), type=Type.i64)])\n", - "serialize(ov_model, ir_xml_path)" + "\n", + "MAX_SEQ_LENGTH = 128\n", + "input_info = [(ov.PartialShape([1, -1]), ov.Type.i64), (ov.PartialShape([1, -1]), ov.Type.i64)]\n", + "default_input = 
torch.ones(1, MAX_SEQ_LENGTH, dtype=torch.int64)\n", + "inputs = {\n", + " \"input_ids\": default_input,\n", + " \"attention_mask\": default_input,\n", + "}\n", + "\n", + "ov_model = ov.convert_model(model, input=input_info, example_input=inputs)\n", + "ov.save_model(ov_model, ir_xml_path)" ] }, { @@ -130,12 +140,12 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "id": "39248a56-11b3-42cc-bf5f-de05e1732c77", "metadata": {}, "outputs": [], "source": [ - "core = Core()" + "core = ov.Core()" ] }, { @@ -150,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "id": "1e27ef1d-e91e-4cbe-8a86-457ddeb0a1c7", "metadata": {}, "outputs": [ @@ -165,7 +175,7 @@ "Dropdown(description='Device:', index=2, options=('CPU', 'GPU', 'AUTO'), value='AUTO')" ] }, - "execution_count": 17, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -185,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 8, "id": "e31a2644", "metadata": { "tags": [] @@ -199,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "id": "de01fccc", "metadata": { "tags": [] @@ -228,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 10, "id": "cc0c91a6", "metadata": { "tags": [] @@ -267,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "id": "cf976f71", "metadata": { "tags": [] @@ -302,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 12, "id": "63f57d28", "metadata": { "tags": [] diff --git a/notebooks/231-instruct-pix2pix-image-editing/231-instruct-pix2pix-image-editing.ipynb b/notebooks/231-instruct-pix2pix-image-editing/231-instruct-pix2pix-image-editing.ipynb index abe86e5ca86..aafb9cac607 100644 --- a/notebooks/231-instruct-pix2pix-image-editing/231-instruct-pix2pix-image-editing.ipynb +++ b/notebooks/231-instruct-pix2pix-image-editing/231-instruct-pix2pix-image-editing.ipynb @@ -123,6 +123,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "6d8d5bdc-7ced-411d-b385-c4b1331e8888", "metadata": {}, @@ -131,7 +132,7 @@ "\n", "OpenVINO supports PyTorch through export to the ONNX format. We will use `torch.onnx.export` function for obtaining an ONNX model. For more information, refer to the [PyTorch documentation](https://pytorch.org/docs/stable/onnx.html). We need to provide a model object, input data for model tracing and a path for saving the model. Optionally, we can provide target onnx opset for conversion and other parameters specified in the documentation (for example, input and output names or dynamic shapes).\n", "\n", - "While ONNX models are directly supported by OpenVINO™ runtime, it can be useful to convert them to OpenVINO Intermediate Representation (IR) format to take the advantage of advanced OpenVINO optimization tools and features. We will use OpenVINO Model Optimizer to convert the model to IR format and compress weights to the `FP16` format.\n", + "While ONNX models are directly supported by OpenVINO™ runtime, it can be useful to convert them to OpenVINO Intermediate Representation (IR) format to take the advantage of advanced OpenVINO optimization features. We will use OpenVINO [model conversion API](https://docs.openvino.ai/2023.0/openvino_docs_model_processing_introduction.html) to convert the model to IR format.\n", "\n", "The InstructPix2Pix model is based on Stable Diffusion, a large-scale text-to-image latent diffusion model. 
You can find more details about how to run Stable Diffusion for text-to-image generation with OpenVINO in a separate [tutorial](../225-stable-diffusion-text-to-image/225-stable-diffusion-text-to-image.ipynb).\n", "\n", @@ -189,10 +190,9 @@ ], "source": [ "from pathlib import Path\n", - "from openvino.tools import mo\n", - "from openvino.runtime import serialize, Core\n", + "import openvino as ov\n", "\n", - "core = Core()\n", + "core = ov.Core()\n", "\n", "TEXT_ENCODER_ONNX_PATH = Path('text_encoder.onnx')\n", "TEXT_ENCODER_OV_PATH = TEXT_ENCODER_ONNX_PATH.with_suffix('.xml')\n", @@ -233,9 +233,8 @@ "\n", "if not TEXT_ENCODER_OV_PATH.exists():\n", " convert_encoder_onnx(text_encoder, TEXT_ENCODER_ONNX_PATH)\n", - " text_encoder = mo.convert_model(\n", - " TEXT_ENCODER_ONNX_PATH, compress_to_fp16=True)\n", - " serialize(text_encoder, str(TEXT_ENCODER_OV_PATH))\n", + " text_encoder = ov.convert_model(TEXT_ENCODER_ONNX_PATH)\n", + " ov.save_model(text_encoder, str(TEXT_ENCODER_OV_PATH))\n", " print('Text Encoder successfully converted to IR')\n", "else:\n", " print(f\"Text encoder will be loaded from {TEXT_ENCODER_OV_PATH}\")\n", @@ -324,8 +323,8 @@ "\n", "if not VAE_ENCODER_OV_PATH.exists():\n", " convert_vae_encoder_onnx(vae, VAE_ENCODER_ONNX_PATH)\n", - " vae_encoder = mo.convert_model(VAE_ENCODER_ONNX_PATH, compress_to_fp16=True)\n", - " serialize(vae_encoder, str(VAE_ENCODER_OV_PATH))\n", + " vae_encoder = ov.convert_model(VAE_ENCODER_ONNX_PATH)\n", + " ov.save_model(vae_encoder, str(VAE_ENCODER_OV_PATH))\n", " print('VAE encoder successfully converted to IR')\n", " del vae_encoder\n", "else:\n", @@ -400,9 +399,9 @@ "\n", "if not VAE_DECODER_OV_PATH.exists():\n", " convert_vae_decoder_onnx(vae, VAE_DECODER_ONNX_PATH)\n", - " vae_decoder = mo.convert_model(VAE_DECODER_ONNX_PATH, compress_to_fp16=True)\n", + " vae_decoder = ov.convert_model(VAE_DECODER_ONNX_PATH)\n", " print('VAE decoder successfully converted to IR')\n", - " serialize(vae_decoder, str(VAE_DECODER_OV_PATH))\n", + " ov.save_model(vae_decoder, str(VAE_DECODER_OV_PATH))\n", " del vae_decoder\n", "else:\n", " print(f\"VAE decoder will be loaded from {VAE_DECODER_OV_PATH}\")\n", @@ -489,8 +488,8 @@ "\n", "if not UNET_OV_PATH.exists():\n", " convert_unet_onnx(unet, UNET_ONNX_PATH)\n", - " unet = mo.convert_model(UNET_ONNX_PATH, compress_to_fp16=True)\n", - " serialize(unet, str(UNET_OV_PATH)) \n", + " unet = ov.convert_model(UNET_ONNX_PATH)\n", + " ov.save_model(unet, str(UNET_OV_PATH)) \n", " print('Unet successfully converted to IR')\n", "else:\n", " print(f\"Unet successfully loaded from {UNET_OV_PATH}\")\n", @@ -533,7 +532,6 @@ ], "source": [ "from diffusers.pipeline_utils import DiffusionPipeline\n", - "from openvino.runtime import Model, Core\n", "from transformers import CLIPTokenizer\n", "from typing import Union, List, Optional, Tuple\n", "import PIL\n", @@ -612,11 +610,11 @@ " self,\n", " tokenizer: CLIPTokenizer,\n", " scheduler: EulerAncestralDiscreteScheduler,\n", - " core: Core,\n", - " text_encoder: Model,\n", - " vae_encoder: Model,\n", - " unet: Model,\n", - " vae_decoder: Model,\n", + " core: ov.Core,\n", + " text_encoder: ov.Model,\n", + " vae_encoder: ov.Model,\n", + " unet: ov.Model,\n", + " vae_decoder: ov.Model,\n", " device:str = \"AUTO\"\n", " ):\n", " super().__init__()\n", @@ -626,7 +624,7 @@ " self.load_models(core, device, text_encoder,\n", " vae_encoder, unet, vae_decoder)\n", "\n", - " def load_models(self, core: Core, device: str, text_encoder: Model, vae_encoder: Model, unet: Model, vae_decoder: 
Model):\n", + " def load_models(self, core: ov.Core, device: str, text_encoder: ov.Model, vae_encoder: ov.Model, unet: ov.Model, vae_decoder: ov.Model):\n", " \"\"\"\n", " Function for loading models on device using OpenVINO\n", " \n",