diff --git a/.github/workflows/dockerfile_sanity.yml b/.github/workflows/dockerfile_sanity.yml index 060b80ca45..738be80319 100644 --- a/.github/workflows/dockerfile_sanity.yml +++ b/.github/workflows/dockerfile_sanity.yml @@ -5,13 +5,13 @@ on: branches: - main paths: - - "docker/Dockerfile.intel" - + - 'Dockerfile.ipex' + pull_request: branches: - main paths: - - "docker/Dockerfile.intel" + - 'Dockerfile.ipex' jobs: build_and_run: @@ -27,7 +27,7 @@ jobs: - name: Build and Run Docker Image run: | IMAGE_NAME="intel_image:latest" - docker build -f docker/Dockerfile.intel -t $IMAGE_NAME . + docker build -f Dockerfile.ipex -t $IMAGE_NAME . if [ $? -ne 0 ]; then echo "Docker image build failed." exit 1 diff --git a/.github/workflows/test_inc.yml b/.github/workflows/test_inc.yml index b31f2055f3..7f622a2da9 100644 --- a/.github/workflows/test_inc.yml +++ b/.github/workflows/test_inc.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - torch-version: ["2.4.*", "2.5.0"] + torch-version: ["2.4.0", "2.5.*"] runs-on: ubuntu-22.04 @@ -35,7 +35,7 @@ jobs: run: | pip install --upgrade pip pip install torch==${{ matrix.torch-version }} torchaudio torchvision --index-url https://download.pytorch.org/whl/cpu - pip install .[neural-compressor,ipex,diffusers,peft,tests] transformers[testing] intel-extension-for-pytorch==${{ matrix.torch-version }} + pip install .[neural-compressor,diffusers,peft,tests] transformers[testing] intel-extension-for-pytorch==${{ matrix.torch-version }} - name: Assert versions run: | diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml index 59a3161528..de933e3795 100644 --- a/.github/workflows/test_ipex.yml +++ b/.github/workflows/test_ipex.yml @@ -18,8 +18,8 @@ jobs: strategy: fail-fast: false matrix: - torch-version: ["2.2.0", "2.3.*", "2.4.*"] - transformers-version: ["4.39.0", "4.44.*"] + transformers-version: ["4.46.0", "4.46.3"] + torch-version: ["2.4.0", "2.5.*"] runs-on: ubuntu-22.04 @@ -38,10 +38,6 @@ jobs: pip install torch==${{ matrix.torch-version }} torchaudio torchvision --extra-index-url https://download.pytorch.org/whl/cpu pip install .[ipex,tests] transformers[testing]==${{ matrix.transformers-version }} intel_extension_for_pytorch==${{ matrix.torch-version }} - - if: ${{ matrix.torch-version == '2.2.0' }} - name: Downgrade Numpy - run: pip install numpy==1.* - - name: Assert versions run: | python -c "import torch; print(torch.__version__); assert torch.__version__.startswith('${{ matrix.torch-version }}'.replace('.*', ''))" @@ -50,4 +46,4 @@ jobs: - name: Test with Pytest run: | - pytest tests/ipex \ No newline at end of file + pytest tests/ipex diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml index e2889cb4e0..db35324a9e 100644 --- a/.github/workflows/test_openvino.yml +++ b/.github/workflows/test_openvino.yml @@ -1,6 +1,7 @@ name: OpenVINO - Test on: + workflow_dispatch: push: branches: - main @@ -46,9 +47,9 @@ jobs: pip install .[openvino,openvino-tokenizers,diffusers,tests] transformers[testing] - if: ${{ matrix.transformers-version != 'latest' }} - name: Downgrade Transformers and Accelerate + name: Install specific dependencies and versions required for older transformers run: | - pip install transformers==${{ matrix.transformers-version }} accelerate==0.* + pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.* diffusers==0.30.* transformers_stream_generator - if: ${{ matrix.test-pattern == '*modeling*' }} name: Uninstall NNCF diff --git 
a/.github/workflows/test_openvino_full.yml b/.github/workflows/test_openvino_full.yml new file mode 100644 index 0000000000..914035b750 --- /dev/null +++ b/.github/workflows/test_openvino_full.yml @@ -0,0 +1,88 @@ +name: OpenVINO - Full Test + +on: + workflow_dispatch: + schedule: + - cron: "41 3 * * *" # run every day at 3:41 + push: + branches: + - v*-release + pull_request: + types: [opened, synchronize, reopened, labeled] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + if: ${{ (github.event_name == 'workflow_dispatch') || (github.event_name == 'schedule') || (github.event_name == 'push') || contains( github.event.pull_request.labels.*.name, 'openvino-test') }} + strategy: + fail-fast: false + matrix: + include: + - python-version: "3.9" + os: "ubuntu-22.04" + transformers-version: "latest" + openvino: "ov-stable" + nncf: "nncf-stable" + - python-version: "3.9" + os: "ubuntu-22.04" + transformers-version: "latest" + openvino: "ov-nightly" + nncf: "nncf-stable" + - python-version: "3.9" + os: "ubuntu-22.04" + transformers-version: "latest" + openvino: "ov-stable" + nncf: "nncf-develop" + - python-version: "3.9" + os: "ubuntu-22.04" + transformers-version: "latest" + openvino: "ov-nightly" + nncf: "nncf-develop" + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # Install PyTorch CPU to prevent unnecessary downloading/installing of CUDA packages + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install .[tests] + + - name: Install openvino-nightly + if: ${{ matrix.openvino == 'ov-nightly' }} + run: pip install --pre -U openvino openvino-tokenizers --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + + - name: Install openvino release + if: ${{ matrix.openvino == 'ov-stable' }} + run: pip install .[openvino] + + - name: Install nncf develop + if: ${{ matrix.nncf == 'nncf-develop' }} + run: pip install git+https://github.com/openvinotoolkit/nncf.git + + - name: Install nncf release + if: ${{ matrix.nncf == 'nncf-stable' }} + run: pip install .[nncf] + + - name: Install the lowest compatible transformers version + if: ${{ matrix.transformers-version != 'latest' }} + run: pip install transformers==${{ matrix.transformers-version }} + + - name: Pip freeze + run: pip freeze + + - name: OpenVINO tests + run: pytest tests/openvino --durations=0 + env: + RUN_SLOW: 1 + HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml index bf52413a7d..8c3d9b2d3f 100644 --- a/.github/workflows/test_openvino_slow.yml +++ b/.github/workflows/test_openvino_slow.yml @@ -25,9 +25,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-22.04", "windows-2019"] - openvino-version: ["stable", "nightly"] transformers-version: ["4.36.0", "latest"] - nncf: ["nncf", "git+https://github.com/openvinotoolkit/nncf.git"] runs-on: ${{ matrix.os }} @@ -47,14 +45,9 @@ jobs: pip install .[openvino,tests] transformers[testing] pip uninstall -y nncf - - if: ${{ matrix.openvino-version == 'nightly' }} - name: Install nightly OpenVINO - run: | - pip install openvino openvino-tokenizers --pre --upgrade --extra-index-url 
https://storage.openvinotoolkit.org/simple/wheels/nightly - - if: ${{ matrix.transformers-version != 'latest' }} - name: Downgrade Transformers and Accelerate - run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* + name: Install specific dependencies and versions required for older transformers + run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator - name: Pip freeze run: pip freeze @@ -65,7 +58,11 @@ jobs: - name: Install dependencies (slow) run: | - pip install ${{ matrix.nncf }} + pip install .[nncf] + + - if: ${{ matrix.transformers-version != 'latest' }} + name: Downgrade Transformers and Accelerate + run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* - name: Test with Pytest (slow) run: | diff --git a/Dockerfile.ipex b/Dockerfile.ipex new file mode 100644 index 0000000000..a03b1d26a3 --- /dev/null +++ b/Dockerfile.ipex @@ -0,0 +1,73 @@ +ARG PLATFORM=cpu + +FROM ubuntu:22.04 as cpu +WORKDIR /usr/src/ +RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + ca-certificates \ + git \ + curl \ + vim \ + build-essential \ + ccache \ + libgoogle-perftools-dev \ + numactl \ + cmake \ + libjpeg-dev \ + pybind11-dev \ + libpng-dev \ + python3 \ + python3-pip \ + && rm -rf /var/lib/apt/lists/*" +RUN /usr/sbin/update-ccache-symlinks +RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache + +ARG IPEX_VERSION=2.5.0 +ARG PYTORCH_VERSION=2.5.1 +ARG TORCHVISION_VERSION=0.20.1+cpu +ARG TORCHAUDIO_VERSION=2.5.1+cpu + +RUN python3 -m pip install --no-cache-dir \ + torch==${PYTORCH_VERSION}+cpu \ + torchvision==${TORCHVISION_VERSION} \ + torchaudio==${TORCHAUDIO_VERSION} \ + --index-url https://download.pytorch.org/whl/cpu && \ + python3 -m pip install intel-openmp -f https://download.pytorch.org/whl/torch_stable.html && \ + python3 -m pip install intel-extension-for-pytorch==$IPEX_VERSION && \ + python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/ && \ + python3 -m pip install --no-cache-dir py-libnuma + +ARG KMP_BLOCKTIME=1 +ENV KMP_BLOCKTIME=${KMP_BLOCKTIME} +ARG KMP_HW_SUBSET=1T +ENV KMP_HW_SUBSET=${KMP_HW_SUBSET} +ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so" + +FROM intel/intel-extension-for-pytorch:2.3.110-xpu as xpu +WORKDIR /usr/src/ + +RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ + ca-certificates \ + git \ + curl \ + vim \ + ccache \ + libgoogle-perftools-dev \ + numactl \ + libjpeg-dev \ + pybind11-dev \ + libpng-dev \ + && rm -rf /var/lib/apt/lists/*" +RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null + +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ +| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y intel-basekit xpu-smi cmake ninja-build pciutils + +FROM ${PLATFORM} + +COPY optimum optimum +COPY Makefile setup.cfg setup.py pyproject.toml 
README.md ./ +RUN pip install . diff --git a/README.md b/README.md index b3879ef380..28c5800684 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ 🤗 Optimum Intel is the interface between the 🤗 Transformers and Diffusers libraries and the different tools and libraries provided by Intel to accelerate end-to-end pipelines on Intel architectures. -[Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) is an open-source library which provides optimizations for both eager mode and graph mode, however, compared to eager mode, graph mode in PyTorch* normally yields better performance from optimization techniques, such as operation fusion. +[Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) is an open-source library which provides optimizations like faster attention and operators fusion. Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) is an open-source library enabling the usage of the most popular compression techniques such as quantization, pruning and knowledge distillation. It supports automatic accuracy-driven tuning strategies in order for users to easily generate quantized model. The users can easily apply static, dynamic and aware-training quantization approaches while giving an expected accuracy criteria. It also supports different weight pruning techniques enabling the creation of pruned model giving a predefined sparsity target. @@ -159,7 +159,7 @@ optimized_model = OVModelForSequenceClassification.from_pretrained(save_dir) ## IPEX -To load your IPEX model, you can just replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class. You can set `export=True` to load a PyTorch checkpoint, export your model via TorchScript and apply IPEX optimizations : both operators optimization (replaced with customized IPEX operators) and graph-level optimization (like operators fusion) will be applied on your model. +To load your IPEX model, you can just replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class. It will load a PyTorch checkpoint, and apply IPEX operators optimization (replaced with customized IPEX operators). 
```diff from transformers import AutoTokenizer, pipeline - from transformers import AutoModelForCausalLM @@ -168,7 +168,7 @@ To load your IPEX model, you can just replace your `AutoModelForXxx` class with model_id = "gpt2" - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) -+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True) ++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) results = pipe("He's a dreadful magician and") diff --git a/docker/Dockerfile.intel b/docker/Dockerfile.intel deleted file mode 100644 index ad4ff63e8c..0000000000 --- a/docker/Dockerfile.intel +++ /dev/null @@ -1,53 +0,0 @@ -# syntax = docker/dockerfile:1 -# based onhttps://github.com/pytorch/pytorch/blob/master/Dockerfile -# -# NOTE: To build this you will need a docker version >= 19.03 and DOCKER_BUILDKIT=1 -# -# If you do not use buildkit you are not going to have a good time -# -# For reference: -# https://docs.docker.com/develop/develop-images/build_enhancements/ - -ARG BASE_IMAGE=ubuntu:22.04 -FROM ${BASE_IMAGE} - -RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \ - sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ - ca-certificates \ - git \ - curl \ - vim \ - build-essential \ - ccache \ - libgoogle-perftools-dev \ - numactl \ - cmake \ - libjpeg-dev \ - pybind11-dev \ - libpng-dev \ - python3 \ - python3-pip \ - && rm -rf /var/lib/apt/lists/*" -RUN /usr/sbin/update-ccache-symlinks -RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache - -ARG IPEX_VERSION=2.3.100 -ARG PYTORCH_VERSION=2.3.1 -ARG TORCHVISION_VERSION=0.18.1+cpu -ARG TORCHAUDIO_VERSION=2.3.1+cpu - -RUN python3 -m pip install --no-cache-dir \ - intel-openmp \ - torch==${PYTORCH_VERSION}+cpu \ - torchvision==${TORCHVISION_VERSION} \ - torchaudio==${TORCHAUDIO_VERSION} \ - -f https://download.pytorch.org/whl/torch_stable.html && \ - python3 -m pip install intel-extension-for-pytorch==$IPEX_VERSION && \ - python3 -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ - python3 -m pip install --no-cache-dir py-libnuma - -ARG KMP_BLOCKTIME=1 -ENV KMP_BLOCKTIME=${KMP_BLOCKTIME} -ARG KMP_HW_SUBSET=1T -ENV KMP_HW_SUBSET=${KMP_HW_SUBSET} -ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so" diff --git a/docs/source/ipex/inference.mdx b/docs/source/ipex/inference.mdx index c712275e42..72826da595 100644 --- a/docs/source/ipex/inference.mdx +++ b/docs/source/ipex/inference.mdx @@ -14,8 +14,8 @@ Optimum Intel can be used to load models from the [Hub](https://huggingface.co/m ## Loading -You can load your model and apply IPEX optimizations (including weight prepacking and graph mode). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators. -For now, support is only enabled for CPUs and the original model will be exported via TorchScript. In the future `torch.compile` will be used and model exported via TorchScript will get deprecated. +You can load your model and apply IPEX optimizations (apply torch.compile except text-generation tasks). For supported architectures like LLaMA, BERT and ViT, further optimizations will be applied by patching the model to use custom operators. 
+For now, support is enabled for Intel CPU/GPU. Previous models converted to TorchScript will be deprecated in v1.22. ```diff import torch @@ -25,7 +25,7 @@ For now, support is only enabled for CPUs and the original model will be exporte model_id = "gpt2" - model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) -+ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True) ++ model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) results = pipe("He's a dreadful magician and") @@ -43,3 +43,4 @@ As shown in the table below, each task is associated with a class enabling to au | `IPEXModelForMaskedLM` | `fill-mask` | | `IPEXModelForAudioClassification` | `audio-classification` | | `IPEXModelForCausalLM` | `text-generation` | +| `IPEXModelForSeq2SeqLM` | `text2text-generation` | diff --git a/docs/source/ipex/models.mdx b/docs/source/ipex/models.mdx index 346ca26599..b8cd6c482f 100644 --- a/docs/source/ipex/models.mdx +++ b/docs/source/ipex/models.mdx @@ -40,6 +40,7 @@ Here is the list of the supported architectures : - Roberta - Roformer - SqueezeBert +- T5 - UniSpeech - Vit - Wav2Vec2 diff --git a/docs/source/openvino/export.mdx b/docs/source/openvino/export.mdx index dd542be735..3e7e458c02 100644 --- a/docs/source/openvino/export.mdx +++ b/docs/source/openvino/export.mdx @@ -78,14 +78,15 @@ Optional arguments: --ratio RATIO A parameter used when applying 4-bit quantization to control the ratio between 4-bit and 8-bit quantization. If set to 0.8, 80% of the layers will be quantized to int4 while 20% will be quantized to int8. This helps to achieve better accuracy at the sacrifice of the model size - and inference latency. Default value is 1.0. + and inference latency. Default value is 1.0. Note: If dataset is provided, and the ratio is + less than 1.0, then data-aware mixed precision assignment will be applied. --sym Whether to apply symmetric quantization --group-size GROUP_SIZE The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. --backup-precision {none,int8_sym,int8_asym} - Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight - format. If not provided, backup precision is int8_asym. 'none' stands for original floating- + Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight + formats. If not provided, backup precision is int8_asym. 'none' stands for original floating- point precision of the model weights, in this case weights are retained in their original precision without any quantization. 'int8_sym' stands for 8-bit integer symmetric quantization without zero point. 'int8_asym' stands for 8-bit integer asymmetric quantization with zero @@ -94,7 +95,9 @@ Optional arguments: can use the one from the list ['auto','wikitext2','c4','c4-new']. With 'auto' the dataset will be collected from model's generations. For diffusion models it should be on of ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit']. For - visual language models the dataset must be set to 'contextual'. + visual language models the dataset must be set to 'contextual'. Note: if none of the data-aware + compression algorithms are selected and ratio parameter is omitted or equals 1.0, the dataset + argument will not have an effect on the resulting model. 
--all-layers Whether embeddings and last MatMul layers should be compressed to INT4. If not provided an weight compression is applied, they are compressed to INT8. --awq Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but diff --git a/docs/source/openvino/optimization.mdx b/docs/source/openvino/optimization.mdx index 28de5ffa4b..147421dd4a 100644 --- a/docs/source/openvino/optimization.mdx +++ b/docs/source/openvino/optimization.mdx @@ -166,6 +166,25 @@ calibration_dataset = quantizer.get_calibration_dataset( The `quantize()` method applies post-training static quantization and export the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The resulting model can be run on any target Intel device. +#### Speech-to-text Models Quantization + +The speech-to-text Whisper model can be quantized without the need for preparing a custom calibration dataset. Please see example below. + +```python +model_id = "openai/whisper-tiny" +ov_model = OVModelForSpeechSeq2Seq.from_pretrained( + model_id, + quantization_config=OVQuantizationConfig( + num_samples=10, + dataset="librispeech", + processor=model_id, + matmul_sq_alpha=0.95, + ) +) +``` + +With this, encoder, decoder and decoder-with-past models of the Whisper pipeline will be fully quantized, including activations. + ### Hybrid quantization Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion (SD) models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of activations is comparable to weights. diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py index 7e81072194..55f79b2185 100644 --- a/examples/neural_compressor/language-modeling/run_clm.py +++ b/examples/neural_compressor/language-modeling/run_clm.py @@ -215,6 +215,10 @@ class OptimizationArguments: default="sym", metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."}, ) + use_layer_wise: bool = field( + default=False, + metadata={"help": "Use layer wise to do quantization to save memory."}, + ) quantization_methodology: str = field( default="rtn", metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."}, @@ -659,6 +663,7 @@ def compute_metrics(eval_preds): "bits": optim_args.bits, "sym": optim_args.weight_only_scheme == "sym", "group_size": optim_args.group_size, + "use_layer_wise": optim_args.use_layer_wise, } if optim_args.quantization_methodology == "gptq": @@ -666,6 +671,7 @@ def compute_metrics(eval_preds): damp_percent=optim_args.damp_percent, nsamples=optim_args.num_calibration_samples, blocksize=optim_args.gptq_block_size, + tokenizer=tokenizer, **algorithm_args, ) else: diff --git a/notebooks/ipex/text_generation.ipynb b/notebooks/ipex/text_generation.ipynb index d1a62d9201..4c97d5b6b0 100644 --- a/notebooks/ipex/text_generation.ipynb +++ b/notebooks/ipex/text_generation.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To load your IPEX model, you can just replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class. 
You can set `export=True` to load a PyTorch checkpoint, export your model via TorchScript and apply IPEX optimizations : both operators optimization (replaced with customized IPEX operators) and graph-level optimization (like operators fusion) will be applied on your model." + "To load your IPEX model, you can just replace your `AutoModelForXxx` class with the corresponding `IPEXModelForXxx` class. It could apply IPEX, providing optimizations like faster attention and operators fusion." ] }, { @@ -60,7 +60,7 @@ } ], "source": [ - "model = IPEXModelForCausalLM.from_pretrained(\"gpt2\", torch_dtype=torch.bfloat16, export=True)\n", + "model = IPEXModelForCausalLM.from_pretrained(\"gpt2\", torch_dtype=torch.bfloat16)\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n", "input_sentence = [\"Answer the following yes/no question by reasoning step-by-step please. Can you write a whole Haiku in a single tweet?\"]\n", "model_inputs = tokenizer(input_sentence, return_tensors=\"pt\")\n", diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py index 5e951aa438..6965efcb54 100644 --- a/optimum/commands/export/openvino.py +++ b/optimum/commands/export/openvino.py @@ -102,7 +102,8 @@ def parse_args_openvino(parser: "ArgumentParser"): default=None, help=( "A parameter used when applying 4-bit quantization to control the ratio between 4-bit and 8-bit quantization. If set to 0.8, 80%% of the layers will be quantized to int4 " - "while 20%% will be quantized to int8. This helps to achieve better accuracy at the sacrifice of the model size and inference latency. Default value is 1.0." + "while 20%% will be quantized to int8. This helps to achieve better accuracy at the sacrifice of the model size and inference latency. Default value is 1.0. " + "Note: If dataset is provided, and the ratio is less than 1.0, then data-aware mixed precision assignment will be applied." ), ) optional_group.add_argument( @@ -123,7 +124,7 @@ def parse_args_openvino(parser: "ArgumentParser"): choices=["none", "int8_sym", "int8_asym"], default=None, help=( - "Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight format. " + "Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight formats. " "If not provided, backup precision is int8_asym. 'none' stands for original floating-point precision of " "the model weights, in this case weights are retained in their original precision without any " "quantization. 'int8_sym' stands for 8-bit integer symmetric quantization without zero point. 'int8_asym' " @@ -140,7 +141,9 @@ def parse_args_openvino(parser: "ArgumentParser"): "dataset will be collected from model's generations. " "For diffusion models it should be on of ['conceptual_captions'," "'laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit']. " - "For visual language models the dataset must be set to 'contextual'." + "For visual language models the dataset must be set to 'contextual'. " + "Note: if none of the data-aware compression algorithms are selected and ratio parameter is omitted or " + "equals 1.0, the dataset argument will not have an effect on the resulting model." 
), ) optional_group.add_argument( @@ -354,6 +357,10 @@ def run(self): from optimum.intel import OVStableDiffusion3Pipeline model_cls = OVStableDiffusion3Pipeline + elif class_name == "FluxPipeline": + from optimum.intel import OVFluxPipeline + + model_cls = OVFluxPipeline else: raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.") diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py new file mode 100755 index 0000000000..dec1e81895 --- /dev/null +++ b/optimum/exporters/ipex/cache_utils.py @@ -0,0 +1,238 @@ +from typing import List, Optional, Tuple + +import torch +from intel_extension_for_pytorch.llm.modules import PagedAttention +from transformers import Cache, PretrainedConfig + + +class IPEXPagedCache(Cache): + """ + A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout. + ipex-xpu: + ipex-cpu: + + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from optimum.intel import IPEXModelForCausalLM + >>> from optimum.exporters.ipex.cache_utils import IPEXPagedCache + + >>> model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", export=True) + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") + + >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = IPEXPagedCache() + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation + ``` + """ + + def __init__( + self, + config: PretrainedConfig, + batch_size: int, + max_cache_len: int, + device, + dtype=None, + layer_device_map=None, + **kwargs, + ) -> None: + super().__init__() + self.batch_size = batch_size + # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) + self.block_size = 16 + self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size + self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( + batch_size, -1 + ) + self.free_blocks = torch.arange(self.num_blocks, device=device) + self.max_cache_len = max_cache_len + self.num_kv_heads = config.num_key_value_heads + self.num_hidden_layers = config.num_hidden_layers + if hasattr(config, "head_dim"): + head_size = config.head_dim + else: + head_size = config.hidden_size // config.num_attention_heads + self.head_size = head_size + self.max_seq_len = 0 + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + if device.type == "cpu": + key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) + value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) + elif device.type == "xpu": + key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) + value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) + for i in range(config.num_hidden_layers): + new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) + new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + 
def update_for_prefill( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + batch_size: int, + input_lens: torch.Tensor, + ): + if layer_idx == 0: + all_block_indices = [] + all_slot_offsets = [] + num_blocks = (input_lens + self.block_size - 1) // self.block_size + for i in range(batch_size): + for b_idx in range(num_blocks[i]): + if self.block_tables[i][b_idx] == -1: + # need a free block + self.block_tables[i][b_idx] = self.free_blocks[0] + self.free_blocks = self.free_blocks[1:] + + slots_range = torch.arange(input_lens[i], device=key_states.device) + block_indices = slots_range // self.block_size + slot_offsets = slots_range % self.block_size + all_block_indices.append(self.block_tables[i][block_indices]) + all_slot_offsets.append(slot_offsets) + + all_block_indices = torch.cat(all_block_indices) + all_slot_offsets = torch.cat(all_slot_offsets) + self.slots = all_block_indices * self.block_size + all_slot_offsets + + # Update the cache + PagedAttention.reshape_and_cache( + key_states, + value_states, + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.slots, + ) + + # Update the number of seen tokens + if layer_idx == self.num_hidden_layers - 1: + self._seen_tokens = self._seen_tokens + input_lens + self.max_seq_len, _ = self._seen_tokens.max(dim=0) + + def update_for_decode( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + batch_size: int, + ): + if layer_idx == 0: + start_block_idx = self._seen_tokens // self.block_size + num_blocks = (self._seen_tokens + self.block_size) // self.block_size + slot_offset_in_block = (self._seen_tokens) % self.block_size + self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) + for i in range(batch_size): + for b_idx in range(start_block_idx[i], num_blocks[i]): + if self.block_tables[i][b_idx] == -1: + # need a free block + self.block_tables[i][b_idx] = self.free_blocks[0] + self.free_blocks = self.free_blocks[1:] + + self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] + # Update the cache + PagedAttention.reshape_and_cache( + key_states, + value_states, + self.key_cache[layer_idx], + self.value_cache[layer_idx], + self.slots, + ) + + # Update the number of seen tokens + if layer_idx == self.num_hidden_layers - 1: + self._seen_tokens = self._seen_tokens + 1 + self.max_seq_len = self.max_seq_len + 1 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + attention_mask: torch.Tensor, + input_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + Return: + A tuple containing the updated key and value states. 
+ """ + + batch_size = input_lens.shape[-1] + if self.get_seq_length() == 0: + # prefill + self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens) + else: + # decode + self.update_for_decode(key_states, value_states, layer_idx, batch_size) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states that were seen by the model.""" + return self.max_seq_len + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states.""" + return self.max_cache_len + + def reset(self): + """Resets the cache values while preserving the objects""" + self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) + self.block_tables.fill_(-1) + self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device) + self.max_seq_len = 0 + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + device = self.block_tables.device + origin_table = self.block_tables.clone() + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) + mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) + num_blocks = mask.cumsum(-1)[:, -1] + updated_table = [] + for i in range(beam_idx.shape[0]): + self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1] + updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]]) + updated_table = torch.cat(tuple(updated_table), dim=0) + for layer_idx in range(self.num_hidden_layers): + self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]] + self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]] + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) + self.free_blocks = torch.cat((self.free_blocks, free_table)) + + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + + max_seq_len = self.get_seq_length() + if maximum_length < 0: + maximum_length = max_seq_len - abs(maximum_length) + + if max_seq_len <= maximum_length: + return + origin_table = self.block_tables.clone() + for bs in range(self._seen_tokens.shape[0]): + new_tokens = self._seen_tokens[bs] + maximum_length - max_seq_len + num_blocks = (new_tokens + self.block_size - 1) // self.block_size + self.block_tables[bs, num_blocks:] = -1 + self._seen_tokens[bs] = new_tokens + self.max_seq_len, _ = self._seen_tokens.max(dim=0) + free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) + self.free_blocks = torch.cat((self.free_blocks, free_table)) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 484fd38077..03937754a6 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -13,11 +13,10 @@ # limitations under the License. 
from transformers.models.bert.modeling_bert import BertIntermediate -from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel +from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, - LlamaForCausalLM, LlamaModel, LlamaRMSNorm, ) @@ -28,7 +27,9 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, + _falcon_model_forward, _gpt2_block_forward, + _gpt2_model_forward, _ipex_rms_layer_norm_forward, _IPEXFalconDecoderLayer, _IPEXGPT2Attention, @@ -39,8 +40,8 @@ # Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version -_TRANSFORMERS_MIN_VERSION = "4.39.0" -_TRANSFORMERS_MAX_VERSION = "4.44.99" +_TRANSFORMERS_MIN_VERSION = "4.46.0" +_TRANSFORMERS_MAX_VERSION = "4.46.99" _IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",) @@ -75,7 +76,7 @@ def patch_op(m, target_m, new_op_name, new_op): def _patch_llama_model(model): """ Patch llama model: - 1. Use IPEX Rope and IAKV cache + 1. Use IPEX rope and paged cache 2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add) """ convert_functions(model, LlamaModel, "forward", _llama_model_forward) @@ -87,11 +88,14 @@ def _patch_llama_model(model): def _patch_falcon_model(model): """ Patch falcon model: - 1. Disable SDPA so the attention mask will be compatible to ipex attention. - 2. Use IPEX Rope and IAKV cache - 3. Linear fusion with (Linear + Gelu) and (Linear + Add + Add) + 1. Use IPEX rope and paged cache + 2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add) """ - model.transformer._use_sdpa = False + num_key_value_heads = ( + model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1 + ) + setattr(model.config, "num_key_value_heads", num_key_value_heads) + convert_functions(model, FalconModel, "forward", _falcon_model_forward) replace_customized_linear_with_linear(model) convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.config) return model @@ -100,12 +104,13 @@ def _patch_falcon_model(model): def _patch_gpt2_model(model): """ Patch gpt2 model: - 1. Disable SDPA so the attention mask will be compatible to ipex attention. - 2. Use IAKV cache + 1. Use IPEX paged attention """ - model.transformer._attn_implementation = "eager" - convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) + num_key_value_heads = model.config.num_attention_heads + setattr(model.config, "num_key_value_heads", num_key_value_heads) + convert_functions(model, GPT2Model, "forward", _gpt2_model_forward) convert_functions(model, GPT2Block, "forward", _gpt2_block_forward) + convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) return model @@ -136,11 +141,11 @@ def _patch_model(model): raise ImportError( f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified." 
) - if isinstance(model, LlamaForCausalLM): + if model.config.model_type == "llama": model = _patch_llama_model(model) - elif isinstance(model, FalconForCausalLM): + elif model.config.model_type == "falcon": model = _patch_falcon_model(model) - elif isinstance(model, GPT2LMHeadModel): + elif model.config.model_type == "gpt2": model = _patch_gpt2_model(model) elif model.config.model_type == "bert": model = _patch_bert_model(model) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py old mode 100644 new mode 100755 index 3d28350b86..ca51c47fb4 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -18,19 +18,18 @@ import torch from torch import nn -from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.gpt2.modeling_gpt2 import GPT2Block -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from transformers.cache_utils import Cache +from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions from optimum.intel.utils.import_utils import is_ipex_version from optimum.intel.utils.modeling_utils import _setattr_from_module +from .cache_utils import IPEXPagedCache + logger = logging.getLogger(__name__) -_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" +_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0" if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): @@ -38,28 +37,114 @@ f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model." ) else: + from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention from intel_extension_for_pytorch.llm.modules import ( - IndirectAccessKVCacheAttention, Linear2SiluMul, LinearAdd, LinearAddAdd, LinearGelu, - RotaryEmbedding, + PagedAttention, ) +# TODO: Following XPULinearXXX op classes will be put into ipex after 2.6.0 version +class XPULinear2SiluMul(torch.nn.Module): + def __init__( + self, + gate_proj: torch.nn.Module, + up_proj: torch.nn.Module, + ): + super().__init__() + self.gate_proj_weight = gate_proj.weight.transpose(0, 1).contiguous() + self.up_proj_weight = up_proj.weight.transpose(0, 1).contiguous() + self.gate_proj_bias = gate_proj.bias + self.up_proj_bias = up_proj.bias + + def forward( + self, + hidden_states, + ): + up = torch.ops.torch_ipex.mm_silu(hidden_states, self.gate_proj_weight) + if self.gate_proj_bias is not None: + up += self.gate_proj_bias + hidden_states = torch.ops.torch_ipex.mm_resmul(hidden_states, self.up_proj_weight, up) + if self.up_proj_bias is not None: + hidden_states += self.up_proj_bias + return hidden_states + + +class XPULinearGelu(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward(self, x): + return torch.ops.torch_ipex.matmul_gelu(x, self.weight, self.bias, 1.0, "tanh") + + +class XPULinearAdd(torch.nn.Module): + def __init__( + self, + module: torch.nn.Module, + ): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward( + self, + hidden_states, + residual, + ): + token_len, _ = hidden_states.size() + if residual is None: + hidden_states = torch.matmul(hidden_states, self.weight) + if self.bias is not 
None: + hidden_states += self.bias + else: + if self.bias is not None: + hidden_states = torch.ops.torch_ipex.mm_bias_resadd( + hidden_states, self.weight, self.bias, 1.0, residual, 1.0 + ) + else: + hidden_states = torch.addmm( + residual.flatten(0, -2), + hidden_states.flatten(0, -2), + self.weight, + beta=1.0, + ) + hidden_states = hidden_states.view(token_len, -1) + return hidden_states + + +class XPUlinearAddAdd(torch.nn.Module): + def __init__(self, module: torch.nn.Module): + super().__init__() + self.weight = module.weight.transpose(0, 1).contiguous() + self.bias = module.bias + + def forward(self, x, y, z): + if self.bias is not None: + x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, self.bias, 1.0, y, 1.0) + x += z + else: + x = torch.ops.torch_ipex.mm_bias_resadd(x, self.weight, z, 1.0, y, 1.0) + return x + + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83 def _ipex_rms_layer_norm_forward(self, hidden_states): - return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon) + return rms_norm(hidden_states, self.weight, self.variance_epsilon) -# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130 +# Adapted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llama/modeling_llama.py#L918 def _llama_model_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -85,29 +170,21 @@ def _llama_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if past_key_values is not None and not isinstance(past_key_values, IPEXPagedCache): + raise ValueError("only support IPEXPagedCache input now") + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0) + position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if getattr(self.config, "_flash_attn_2_enabled", False): - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - # embed positions hidden_states = inputs_embeds @@ -116,25 +193,41 @@ def _llama_model_forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None + position_embeddings = self.rotary_emb(hidden_states, position_ids) + if past_key_values_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask 
+ hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + cos = position_embeddings[0] + sin = position_embeddings[1] + cos = (cos.reshape(-1, cos.shape[-1]))[index] + sin = (sin.reshape(-1, sin.shape[-1]))[index] + position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + position_embeddings=position_embeddings, + input_lens=input_lens, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -146,6 +239,10 @@ def _llama_model_forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -156,17 +253,318 @@ def _llama_model_forward( ) -def _gpt2_block_forward(self, hidden_states, *args, **kwargs): - attention_mask = kwargs.get("attention_mask", None) - if attention_mask is not None: - bsz, seq_len, _ = hidden_states.size() - layer_past = kwargs.get("layer_past", None) - past_len = layer_past[0].size(-2) if layer_past is not None else 0 - attention_mask = (1 - attention_mask / torch.finfo(attention_mask.dtype).min).squeeze(1, 2) - attention_mask = _prepare_4d_causal_attention_mask(attention_mask, (bsz, seq_len), hidden_states, past_len) - kwargs["attention_mask"] = attention_mask +# Adapted from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/falcon/modeling_falcon.py#L945 +def _falcon_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + batch_size, seq_length, _ = inputs_embeds.shape + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + if past_key_values_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + cos = position_embeddings[0] + sin = position_embeddings[1] + cos = (cos.reshape(-1, cos.shape[-1]))[index] + sin = (sin.reshape(-1, sin.shape[-1]))[index] + position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + next_decoder_cache = None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=None, + cache_position=cache_position, + position_embeddings=position_embeddings, + input_lens=input_lens, + ) + + hidden_states = outputs[0] + if use_cache is True: + next_decoder_cache = outputs[1] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + 
past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def _gpt2_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + batch_size, seq_length, _ = inputs_embeds.shape + position_embeddings = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeddings + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + if past_length == 0: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + presents = None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, block in enumerate(self.h): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + 
layer_past=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + input_lens=input_lens, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = outputs[1] + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + hidden_states = self.ln_f(hidden_states) + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# To pass input_lens, adapted from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt2/modeling_gpt2.py#L602 +def _gpt2_block_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, +) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + **kwargs, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual 
connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] - return GPT2Block.forward(self, hidden_states, *args, **kwargs) + return outputs # hidden_states, present, (attentions, cross_attentions) class _IPEXAttention(nn.Module): @@ -174,14 +572,11 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - self.ipex_scale_dot_product = IndirectAccessKVCacheAttention(text_max_length=config.max_position_embeddings) - if hasattr(config, "rope_theta"): - self.ipex_rope = RotaryEmbedding( - config.max_position_embeddings, - config.hidden_size // config.num_attention_heads, - config.rope_theta, - config.architectures[0], - ) + self.module_device = next(module.parameters()).device + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device + ).repeat_interleave(self.num_groups) def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") @@ -189,29 +584,8 @@ def qkv_gemm(self, hidden_states): def rope(self, *args, **kwargs): raise NotImplementedError("Need to implement in specific model class") - def sdpa_with_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - # This ipex op pre-allocates buffers for past_key_values and use beam index history - # which to decide which beam should be used to make attention scale dot more efficient. - (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product( - query, - key, - value, - math.sqrt(self.head_dim), - past_key_value, - kwargs.get("head_mask", None), - attention_mask, - kwargs.get("alibi", None), - ) - return attn_output, past_key_value, attn_weights - - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - raise NotImplementedError("Need to implement in specific model class") - - def prepare_attention_mask_float(self, attention_mask, *args): - return attention_mask - - def postprocess_attention_output(self, attn_output, bsz, seq_len): - attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.hidden_size) + def postprocess_attention_output(self, attn_output): + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output def forward( @@ -219,40 +593,60 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[IPEXPagedCache] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # For llama inputs: https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/llama/modeling_llama.py#L308 - # For falcon inputs: https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/falcon/modeling_falcon.py#L370 if past_key_value is None and kwargs.get("layer_past", None) is not None: past_key_value = kwargs.pop("layer_past", None) - bsz, seq_len, _ = hidden_states.size() - past_len = past_key_value[0].size(-2) if past_key_value is not None else 0 - kv_seq_len = seq_len + past_len - - qkv_out = self.qkv_gemm(hidden_states) - if isinstance(qkv_out, tuple) and len(qkv_out) == 3: - query, key, 
value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, kv_seq_len, use_cache, position_ids=position_ids) + input_lens = kwargs.pop("input_lens", None) + past_len = 0 + if past_key_value is not None: + past_len = past_key_value.get_seq_length() + query, key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, **kwargs) + + if past_key_value is not None: + key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) + + attn_output = torch.empty_like(query) + if past_len == 0: + # prefill, remove padding + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + attn_output, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 0.0, + 1.0 / math.sqrt(self.head_dim), + False, + True, + False, + None, + ) else: - query, key, value = self.rope(qkv_out, kv_seq_len, use_cache, past_len=past_len) - - attention_mask = self.prepare_attention_mask_float(attention_mask, query.dtype) - sdpa = self.sdpa_with_cache if use_cache else self.sdpa_without_cache - attn_output, past_key_value, attn_weights = sdpa( - query, - key, - value, - past_key_value, - attention_mask, - position_ids=position_ids, - head_mask=kwargs.get("head_mask", None), - alibi=kwargs.get("alibi", None), - ) - attn_output = self.postprocess_attention_output(attn_output, bsz, seq_len) + # decode + PagedAttention.single_query_cached_kv_attention( + attn_output, + query, + key_cache, + value_cache, + self.kv_head_mapping, + 1.0 / math.sqrt(self.head_dim), + past_key_value.block_tables, + input_lens, + past_key_value.block_size, + input_lens.max(), + None, + ) + attn_output = self.postprocess_attention_output(attn_output) if not output_attentions: attn_weights = None @@ -262,105 +656,83 @@ def forward( class _IPEXLlamaAttention(_IPEXAttention): def __init__(self, module, config) -> None: super().__init__(module, config) - if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mha_linear_add = LinearAdd(module.o_proj) - del self.__dict__["_modules"]["o_proj"] + concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous() + bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias] + use_bias = bias_list != [] + self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias) + self.concat_qkv.weight = nn.Parameter(concat_weight) + if use_bias: + concat_bias = torch.concat(bias_list, 0).contiguous() + self.concat_linear.bias = nn.Parameter(concat_bias) + self.q_slice = self.q_proj.weight.shape[0] + self.k_slice = self.q_slice + self.k_proj.weight.shape[0] + self.v_slice = self.k_slice + self.v_proj.weight.shape[0] + if self.module_device.type == "cpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = LinearAdd(module.o_proj) + + elif self.module_device.type == "xpu": + if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mha_linear_add = XPULinearAdd(module.o_proj) def qkv_gemm(self, hidden_states): - bsz, seq_len, _ = hidden_states.size() - query = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim) - key = self.k_proj(hidden_states).view(bsz, seq_len, self.num_key_value_heads, self.head_dim) - value = 
self.v_proj(hidden_states).view(bsz, seq_len, self.num_key_value_heads, self.head_dim) + qkv_out = self.concat_qkv(hidden_states) + query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) + key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) + value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value - def rope(self, query, key, kv_seq_len, use_cache, position_ids): - if use_cache: - args = (self.head_dim, self.head_dim // 2, self.head_dim, kv_seq_len) - key = self.ipex_rope(key, position_ids, self.num_key_value_heads, *args) - query = self.ipex_rope(query, position_ids, self.num_heads, *args) + def rope(self, query, key, **kwargs): + position_embeddings = kwargs.pop("position_embeddings", None) + rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) return query, key - # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L341 - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, position_ids, **kwargs): - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - cos, sin = self.rotary_emb(value, position_ids) - query, key = apply_rotary_pos_emb(query, key, cos, sin) - # repeat k/v heads if n_kv_heads < n_heads - key = repeat_kv(key, self.num_key_value_groups) - value = repeat_kv(value, self.num_key_value_groups) - attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask) - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value) - - return attn_output, None, attn_weights - class _IPEXFalconAttention(_IPEXAttention): - def qkv_gemm(self, hidden_states): - return self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + def __init__(self, module, config): + self.num_key_value_heads = config.num_key_value_heads + super().__init__(module, config) + self.q_slice = self.head_dim * config.num_kv_heads + self.k_slice = self.q_slice + self.head_dim + self.v_slice = self.k_slice + self.head_dim - def rope(self, fused_qkv, seq_len, use_cache, past_len): - if use_cache: - query, key, value = self.ipex_rope( - fused_qkv, - torch.tensor(past_len), - self.num_heads, - self.head_dim, - self.head_dim // 2, - self.head_dim, - seq_len, - 3, - ) + def qkv_gemm(self, hidden_states): + qkv_out = self.query_key_value(hidden_states) + if self.new_decoder_architecture: + qkv_out = qkv_out.view(qkv_out.shape[0], -1, self.num_heads // self.num_kv_heads + 2, self.head_dim) + query = qkv_out[:, :, :-2, :].flatten(1, 2) + key = qkv_out[:, :, [-2], :].flatten(1, 2) + value = qkv_out[:, :, [-1], :].flatten(1, 2) else: - (query, key, value) = self._split_heads(fused_qkv) + query = qkv_out[:, : self.q_slice].view(-1, self.num_heads, self.head_dim) + key = qkv_out[:, self.q_slice : self.k_slice].view(-1, self.num_key_value_heads, self.head_dim) + value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value - def prepare_attention_mask_float(self, attention_mask, dtype): - attention_mask_float = ( - (attention_mask * 
1.0).masked_fill(attention_mask.to(torch.bool), float("-1e9")).to(dtype) - ) - return attention_mask_float - - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - bs, q_len = query.shape[0], query.shape[1] - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query, key, value, attention_mask, 0.0, is_causal=False) - attn_output = attn_output.view(bs, self.num_heads, q_len, self.head_dim) - - return attn_output, None, None + def rope(self, query, key, **kwargs): + position_embeddings = kwargs.pop("position_embeddings", None) + rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) + return query, key class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, config) -> None: + self.num_key_value_heads = config.num_key_value_heads super().__init__(module, config) - def _split_heads_ipex(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - return tensor.view(new_shape) # (batch, seq_length, head, head_features) - def qkv_gemm(self, hidden_states): - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads_ipex(query, self.num_heads, self.head_dim) - key = self._split_heads_ipex(key, self.num_heads, self.head_dim) - value = self._split_heads_ipex(value, self.num_heads, self.head_dim) + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) + query = query.view(-1, self.num_heads, self.head_dim) + key = key.view(-1, self.num_heads, self.head_dim) + value = value.view(-1, self.num_heads, self.head_dim) return query, key, value def rope(self, query, key, *args, **kwargs): return query, key - def sdpa_without_cache(self, query, key, value, past_key_value, attention_mask, **kwargs): - query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query, key, value, attention_mask, 0.0, is_causal=True) - - return attn_output, None, None - - def postprocess_attention_output(self, attn_output, bsz, seq_len): - attn_output = attn_output.transpose(1, 2).reshape(bsz, seq_len, self.embed_dim) + def postprocess_attention_output(self, attn_output): + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) return attn_output @@ -372,13 +744,17 @@ def __init__(self, module, config) -> None: super().__init__() _setattr_from_module(self, module) self.config = config - # LinearAllreduce and LinearLayer cannot use fused op LinearAdd - if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: - self.mlp_linear_add = LinearAdd(module.down_proj) - del self.__dict__["_modules"]["down_proj"] - self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) - del self.__dict__["_modules"]["gate_proj"] - del self.__dict__["_modules"]["up_proj"] + self.module_device = next(module.parameters()).device + if self.module_device.type == "cpu": + # LinearAllreduce and LinearLayer cannot use fused op LinearAdd + if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mlp_linear_add = LinearAdd(module.down_proj) + self.linear_silu_mul = Linear2SiluMul(module.gate_proj, module.up_proj) + elif self.module_device.type == "xpu": + # LinearAllreduce and LinearLayer cannot use fused op LinearAdd + if 
module.down_proj.__class__.__name__ not in ["LinearAllreduce"]: + self.mlp_linear_add = XPULinearAdd(module.down_proj) + self.linear_silu_mul = XPULinear2SiluMul(module.gate_proj, module.up_proj) def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor = None, **kwargs): if hasattr(self, "linear_silu_mul"): @@ -401,11 +777,16 @@ def __init__(self, module, config) -> None: _setattr_from_module(self, module) self.config = config # LinearAllreduce and LinearLayer cannot use fused op LinearAdd - self.linear_gelu = LinearGelu(module.dense_h_to_4h) - del self.__dict__["_modules"]["dense_h_to_4h"] + self.module_device = next(module.parameters()).device + if self.module_device.type == "cpu": + self.linear_gelu = LinearGelu(module.dense_h_to_4h) + elif self.module_device.type == "xpu": + self.linear_gelu = XPULinearGelu(module.dense_h_to_4h) if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]: - self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) - del self.__dict__["_modules"]["dense_4h_to_h"] + if self.module_device.type == "cpu": + self.linear_add_add = LinearAddAdd(module.dense_4h_to_h) + elif self.module_device.type == "xpu": + self.linear_add_add = XPUlinearAddAdd(module.dense_4h_to_h) def forward( self, @@ -489,8 +870,11 @@ class _IPEXIntermediate(nn.Module): def __init__(self, module, config): super().__init__() _setattr_from_module(self, module) - self.linear_gelu = LinearGelu(module.dense) - del self.__dict__["_modules"]["dense"] + self.module_device = next(module.parameters()).device + if self.module_device.type == "cpu": + self.linear_gelu = LinearGelu(module.dense) + elif self.module_device.type == "xpu": + self.linear_gelu = XPULinearGelu(module.dense) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_gelu(hidden_states) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 3ac8314889..592cd85a4b 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -29,7 +29,6 @@ from optimum.exporters import TasksManager from optimum.exporters.onnx.base import OnnxConfig from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED -from optimum.exporters.openvino.convert import export_from_model from optimum.intel.utils.import_utils import ( is_nncf_available, is_openvino_tokenizers_available, @@ -42,7 +41,12 @@ ) from optimum.utils.save_utils import maybe_load_preprocessors -from .utils import _MAX_UNCOMPRESSED_SIZE, MULTI_MODAL_TEXT_GENERATION_MODELS, clear_class_registry +from .utils import ( + _MAX_UNCOMPRESSED_SIZE, + MULTI_MODAL_TEXT_GENERATION_MODELS, + clear_class_registry, + deduce_diffusers_dtype, +) FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"} @@ -185,6 +189,7 @@ def main_export( >>> main_export("gpt2", output="gpt2_ov/") ``` """ + from optimum.exporters.openvino.convert import export_from_model if use_auth_token is not None: warnings.warn( @@ -332,6 +337,19 @@ class StoreAttr(object): return model GPTQQuantizer.post_init_model = post_init_model + elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"): + dtype = deduce_diffusers_dtype( + model_name_or_path, + revision=revision, + cache_dir=cache_dir, + token=token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) + if dtype in [torch.float16, torch.bfloat16]: + loading_kwargs["torch_dtype"] = dtype + patch_16bit = True if library_name == "open_clip": model = 
_OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir) @@ -456,7 +474,6 @@ class StoreAttr(object): from optimum.intel.openvino.quantization import _weight_only_quantization _weight_only_quantization(submodel, quantization_config) - compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml" save_model(submodel, compressed_submodel_path, compress_to_fp16=False) del submodel diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index e4ece9801b..c9e18cff6a 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -import onnx from transformers.generation import GenerationMixin from transformers.utils import is_tf_available, is_torch_available @@ -28,10 +27,6 @@ from openvino.runtime.exceptions import OVTypeError from openvino.tools.ovc import convert_model from optimum.exporters import TasksManager -from optimum.exporters.onnx.base import OnnxConfig -from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed -from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx -from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx from optimum.exporters.utils import ( _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, ) @@ -89,6 +84,7 @@ if TYPE_CHECKING: + from optimum.exporters.onnx.base import OnnxConfig from optimum.intel.openvino.configuration import OVConfig @@ -99,11 +95,15 @@ def _set_runtime_options( ], task: str, library_name: str, + quantized_model: bool, ): for model_name in models_and_export_configs.keys(): _, sub_export_config = models_and_export_configs[model_name] + sub_export_config.runtime_options = {} if "diffusers" in library_name or "text-generation" in task: - sub_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"} + sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0" + if not quantized_model and "text-generation" in task: + sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16" def _save_model( @@ -111,13 +111,13 @@ def _save_model( path: str, ov_config: Optional["OVConfig"] = None, library_name: Optional[str] = None, - config: OnnxConfig = None, + config: "OnnxConfig" = None, ): compress_to_fp16 = ov_config is not None and ov_config.dtype == "fp16" model = _add_version_info_to_model(model, library_name) - if hasattr(config, "runtime_options"): - model = _add_runtime_options_to_rt_info(model, config.runtime_options) + runtime_options = config.runtime_options if hasattr(config, "runtime_options") else {} + model = _add_runtime_options_to_rt_info(model, runtime_options) save_model(model, path, compress_to_fp16) del model gc.collect() @@ -125,7 +125,7 @@ def _save_model( def export( model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin", "DiffusionPipeline"], - config: OnnxConfig, + config: "OnnxConfig", output: Path, opset: Optional[int] = None, device: str = "cpu", @@ -208,7 +208,7 @@ def export( def export_tensorflow( model: Union["PreTrainedModel", "ModelMixin"], - config: OnnxConfig, + config: "OnnxConfig", opset: int, output: Path, ov_config: Optional["OVConfig"] = None, @@ -228,6 +228,8 @@ def export_tensorflow( output_names: list of output names from ONNX configuration bool: True if the model was exported successfully. 
""" + from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx + onnx_path = Path(output).with_suffix(".onnx") input_names, output_names = export_tensorflow_onnx(model, config, opset, onnx_path) ov_model = convert_model(str(onnx_path)) @@ -248,7 +250,7 @@ def export_tensorflow( def export_pytorch_via_onnx( model: Union["PreTrainedModel", "ModelMixin"], - config: OnnxConfig, + config: "OnnxConfig", opset: int, output: Path, device: str = "cpu", @@ -285,6 +287,8 @@ def export_pytorch_via_onnx( """ import torch + from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx + output = Path(output) orig_torch_onnx_export = torch.onnx.export torch.onnx.export = functools.partial(orig_torch_onnx_export, do_constant_folding=False) @@ -313,7 +317,7 @@ def export_pytorch_via_onnx( def export_pytorch( model: Union["PreTrainedModel", "ModelMixin"], - config: OnnxConfig, + config: "OnnxConfig", opset: int, output: Path, device: str = "cpu", @@ -355,6 +359,8 @@ def export_pytorch( import torch from torch.utils._pytree import tree_map + from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed + logger.info(f"Using framework PyTorch: {torch.__version__}") output = Path(output) @@ -755,7 +761,12 @@ def export_from_model( model.save_config(output) - _set_runtime_options(models_and_export_configs, task, library_name) + _set_runtime_options( + models_and_export_configs, + task, + library_name, + hasattr(ov_config, "quantization_config") and ov_config.quantization_config, + ) export_models( models_and_export_configs=models_and_export_configs, @@ -869,6 +880,8 @@ def _add_version_info_to_model(model: Model, library_name: Optional[str] = None) model.set_rt_info(_nncf_version, ["optimum", "nncf_version"]) input_model = rt_info["conversion_parameters"].get("input_model", None) if input_model is not None and "onnx" in input_model.value: + import onnx + model.set_rt_info(onnx.__version__, ["optimum", "onnx_version"]) except Exception: diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index c07e1544f2..802cd02418 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -1833,8 +1833,9 @@ def __init__( normalized_config: NormalizedVisionConfig, batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], - width: int = DEFAULT_DUMMY_SHAPES["width"], - height: int = DEFAULT_DUMMY_SHAPES["height"], + width: int = DEFAULT_DUMMY_SHAPES["width"] // 4, + height: int = DEFAULT_DUMMY_SHAPES["height"] // 4, + # Reduce img shape by 4 for FLUX to reduce memory usage on conversion **kwargs, ): super().__init__(task, normalized_config, batch_size, num_channels, width, height, **kwargs) diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 4b4374ab51..39d64c2aec 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -20,7 +20,6 @@ import openvino as ov from openvino.runtime import opset13 -from optimum.exporters import TasksManager from optimum.intel.utils.import_utils import _openvino_version, is_openvino_version, is_transformers_version from .utils import MULTI_MODAL_TEXT_GENERATION_MODELS @@ -192,6 +191,8 @@ def ensure_stateful_is_available(warn=True): def ensure_export_task_support_stateful(task: str): + from optimum.exporters import TasksManager + task = TasksManager.map_from_synonym(task) return task in 
["text-generation-with-past"] diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 9891395a38..46b151e7de 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -26,6 +26,7 @@ from optimum.exporters import TasksManager from optimum.exporters.onnx.base import OnnxConfig from optimum.intel.utils import is_transformers_version +from optimum.intel.utils.import_utils import is_safetensors_available from optimum.utils import is_diffusers_available from optimum.utils.save_utils import maybe_save_preprocessors @@ -240,6 +241,41 @@ def save_config(config, save_dir): config.to_json_file(output_config_file, use_diff=True) +def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs): + dtype = None + if is_safetensors_available(): + if Path(model_name_or_path).is_dir(): + path = Path(model_name_or_path) + else: + from diffusers import DiffusionPipeline + + path = Path(DiffusionPipeline.download(model_name_or_path, **loading_kwargs)) + model_part_name = None + if (path / "transformer").is_dir(): + model_part_name = "transformer" + elif (path / "unet").is_dir(): + model_part_name = "unet" + if model_part_name: + directory = path / model_part_name + safetensors_files = [ + filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1 + ] + safetensors_file = None + if len(safetensors_files) > 0: + safetensors_file = safetensors_files.pop(0) + if safetensors_file: + from safetensors import safe_open + + with safe_open(safetensors_file, framework="pt", device="cpu") as f: + if len(f.keys()) > 0: + for key in f.keys(): + tensor = f.get_tensor(key) + if tensor.dtype.is_floating_point: + dtype = tensor.dtype + break + return dtype + + def save_preprocessors( preprocessors: List, config: PretrainedConfig, output: Union[str, Path], trust_remote_code: bool ): diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 0230394d29..ad9fdca078 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -41,11 +41,20 @@ from .utils import dummy_ipex_objects _import_structure["utils.dummy_ipex_objects"] = [ - name for name in dir(dummy_ipex_objects) if not name.startswith("_") + "IPEXModelForCausalLM", + "IPEXModelForSequenceClassification", + "IPEXModelForMaskedLM", + "IPEXModelForTokenClassification", + "IPEXModelForQuestionAnswering", + "IPEXModelForImageClassification", + "IPEXModelForAudioClassification", + "IPEXModel", ] else: + _import_structure["utils.dummy_ipex_objects"] = [] _import_structure["ipex"] = [ "IPEXModelForCausalLM", + "IPEXModelForSeq2SeqLM", "IPEXModelForSequenceClassification", "IPEXModelForMaskedLM", "IPEXModelForTokenClassification", @@ -55,6 +64,15 @@ "IPEXModel", ] +try: + if not (is_ipex_available() and is_sentence_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + _import_structure["utils.dummy_ipex_objects"].extend(["IPEXSentenceTransformer"]) +else: + _import_structure["ipex"].extend(["IPEXSentenceTransformer"]) + + try: if not (is_openvino_available() and is_nncf_available()): raise OptionalDependencyNotAvailable() @@ -212,15 +230,9 @@ if not (is_openvino_available() and is_sentence_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - _import_structure["utils.dummy_openvino_and_sentence_transformers_objects"] = [ - "OVSentenceTransformer", - ] + _import_structure["utils.dummy_openvino_and_sentence_transformers_objects"] = 
["OVSentenceTransformer"] else: - _import_structure["openvino"].extend( - [ - "OVSentenceTransformer", - ] - ) + _import_structure["openvino"].extend(["OVSentenceTransformer"]) if TYPE_CHECKING: @@ -237,10 +249,19 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) + try: + if not (is_ipex_available() and is_sentence_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_ipex_objects import IPEXSentenceTransformer + else: + from .ipex import IPEXSentenceTransformer + try: if not (is_openvino_available() and is_nncf_available()): raise OptionalDependencyNotAvailable() @@ -372,13 +393,9 @@ if not (is_openvino_available() and is_sentence_transformers_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils.dummy_openvino_and_sentence_transformers_objects import ( - OVSentenceTransformer, - ) + from .utils.dummy_openvino_and_sentence_transformers_objects import OVSentenceTransformer else: - from .openvino import ( - OVSentenceTransformer, - ) + from .openvino import OVSentenceTransformer else: import sys diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index 22a4745f0c..a6e8a76f4f 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -373,6 +373,7 @@ def _from_pretrained( file_name: Optional[str] = WEIGHTS_NAME, local_files_only: bool = False, use_cache: bool = True, + subfolder: str = None, **kwargs, ): if use_auth_token is not None: @@ -402,6 +403,7 @@ def _from_pretrained( cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, + subfolder=subfolder, ) model_save_dir = Path(model_cache_path).parent model = cls.load_model(model_cache_path) diff --git a/optimum/intel/ipex/__init__.py b/optimum/intel/ipex/__init__.py index c1f711acfc..9aae96b08a 100644 --- a/optimum/intel/ipex/__init__.py +++ b/optimum/intel/ipex/__init__.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from optimum.intel.ipex.modeling_base import ( +from ..utils.import_utils import is_sentence_transformers_available +from .modeling_base import ( IPEXModel, IPEXModelForAudioClassification, IPEXModelForCausalLM, IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) + + +if is_sentence_transformers_available(): + from .modeling_sentence_transformers import IPEXSentenceTransformer diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 739a2f2b44..af36d06f4d 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -16,17 +16,12 @@ import inspect import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Tuple, Union -import intel_extension_for_pytorch as ipex import torch import transformers -from huggingface_hub import hf_hub_download -from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE -from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp from transformers import ( AutoConfig, AutoModel, @@ -35,34 +30,29 @@ AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, GenerationConfig, GenerationMixin, PretrainedConfig, - is_torch_xpu_available, ) from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.generation.candidate_generator import _crop_past_key_values -from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput +from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.auto.auto_factory import _get_model_class as get_model_class -from transformers.utils import WEIGHTS_NAME -from optimum.exporters import TasksManager -from optimum.exporters.tasks import make_backend_config_constructor_for_task from optimum.modeling_base import OptimizedModel from optimum.utils import NormalizedConfigManager -from ...exporters.ipex.model_config import ipex_onnx_config +from ...exporters.ipex.cache_utils import IPEXPagedCache from ...exporters.ipex.model_patcher import ( _IPEX_EXPORTED_GENERATION_TASKS, _IPEX_MINIMUM_VERSION_FOR_PATCHING, _patch_model, ) -from ..generation.modeling import get_float_type -from ..utils.constant import _TASK_ALIASES -from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version -from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, recursive_to_device +from ..generation.modeling import prepare_jit_inputs +from ..utils.import_utils import is_ipex_version, is_transformers_version logger = logging.getLogger(__name__) @@ -70,91 +60,19 @@ _IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2") _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") +_IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" +# TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6 +_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2") -def _is_patched_with_ipex(model, task): +def _is_patched_with_ipex(model, task, use_cache: bool = True): if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): return False - - if isinstance(model, torch.jit.ScriptModule): - for node in 
model.graph.nodes(): - # Only patched model enabled fusion linear. - if "/fusions/" in node.__str__(): - return True - return False - elif task in _IPEX_EXPORTED_GENERATION_TASKS and model.config.hidden_size < 64: - # The ipex IAKV op in patched model requires the hidden size at least 64 + if not use_cache and task in _IPEX_EXPORTED_GENERATION_TASKS: return False - return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES -def _prepare_inputs_for_ipex_model(model, task, use_cache): - task = _TASK_ALIASES.get(task, task) - signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.__call__) - if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config: - onnx_config_class = make_backend_config_constructor_for_task( - ipex_onnx_config[model.config.model_type], task=task - ) - else: - onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - float_dtype = get_float_type(model.dtype) - if "text-generation" in task: - onnx_config = onnx_config_class( - model.config, use_past=use_cache, use_past_in_inputs=use_cache, float_dtype=float_dtype - ) - else: - onnx_config = onnx_config_class(model.config) - - dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - - # Check attention_mask shape - if _is_patched_with_ipex(model, task) and model.config.model_type in ipex_onnx_config and use_cache: - past_len = dummy_inputs["past_key_values"][0][0].shape[-2] - input_len = dummy_inputs["input_ids"].shape[-1] - attention_len = dummy_inputs["attention_mask"].shape[-1] - if attention_len != input_len + past_len: - dummy_inputs["attention_mask"] = torch.ones([dummy_inputs["input_ids"].shape[0], input_len + past_len]).to( - dummy_inputs["input_ids"].dtype - ) - - return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} - - -def ipex_jit_trace(model, task, use_cache): - # Only support torch version >= 2.1.0 to support example_kwarg_inputs in jit.trace - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.1.0` is needed to trace your model") - - if _is_patched_with_ipex(model, task): - model = _patch_model(model) - - sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache) - - model.config.return_dict = False - model.config.use_cache = use_cache - - # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. - # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks. 
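# A minimal sketch, for context, of the trace-and-freeze flow that the removed ipex_jit_trace()
# implemented, assuming hypothetical stand-ins: `ToyModel` and the dummy inputs are illustration-only,
# and the ipex.optimize()/TPP steps are omitted. The old path traced the eager model with example
# keyword inputs, froze it, and ran two warm-up calls; the replacement introduced by this patch keeps
# the eager model and relies on the paged KV cache plus optional torch.compile instead.
import torch


class ToyModel(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return input_ids * attention_mask


model = ToyModel().eval()
sample_inputs = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
}
with torch.no_grad():
    traced = torch.jit.trace(model, example_kwarg_inputs=sample_inputs, strict=False, check_trace=False)
    traced = torch.jit.freeze(traced)
    # Two warm-up calls mirror the removed code: the first runs let TorchScript's profiling and
    # fusion passes settle before real inputs are seen.
    traced(**sample_inputs)
    traced(**sample_inputs)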
- if is_ipex_version(">=", "2.3.0") and task in _IPEX_EXPORTED_GENERATION_TASKS: - _enable_tpp() - model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) - # Disable repack while jit tracing to reduce the memory - ipex._C.disable_jit_linear_repack() - with torch.no_grad(): - trace_model = torch.jit.trace( - model, - example_kwarg_inputs=sample_inputs, - strict=False, - check_trace=False, - ) - trace_model = torch.jit.freeze(trace_model) - trace_model(**sample_inputs) - trace_model(**sample_inputs) - - return trace_model - - class IPEXModel(OptimizedModel): auto_model_class = AutoModel export_feature = "feature-extraction" @@ -166,47 +84,35 @@ def __init__( self, model, config: PretrainedConfig = None, - export: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - warmup: bool = True, + warmup: Optional[bool] = True, **kwargs, ): - if is_torch_xpu_available(check_device=True): - self._device = torch.device("xpu:0") - elif torch.cuda.is_available(): - self._device = torch.device("cuda:0") - else: - self._device = torch.device("cpu") - - # CPU only support jit model for now. - if export: - if isinstance(model, torch.jit.RecursiveScriptModule): - logger.warning("The model has been exported already.") - else: - config = model.config if config is None else config - use_cache = kwargs.get("use_cache", True) - model = ipex_jit_trace(model, self.export_feature, use_cache) - config.torchscript = True - + config = config or model.config OptimizedModel.__init__(self, model=model, config=config) - self.model.to(self._device) - self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32 + self._supports_cache_class = getattr(model, "_supports_cache_class", None) + self._supports_sdpa = getattr(model, "_supports_sdpa", None) + self._supports_quantized_cache = getattr(model, "_supports_quantized_cache", None) + self._supports_static_cache = getattr(model, "_supports_static_cache", None) + self._dtype = self.model.dtype if self.model.dtype is not None else torch.float32 + self.use_cache = kwargs.get("use_cache", False) self.model_save_dir = model_save_dir - self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature) + self._add_patch = _is_patched_with_ipex(model, self.export_feature, self.use_cache) + self.compiled = False - if isinstance(model, torch.jit.RecursiveScriptModule): - self.input_names = { - inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self" - } - else: - self.input_names = set(inspect.signature(model.forward).parameters) + self.input_names = set(inspect.signature(model.forward).parameters) + if self._add_patch: + model = _patch_model(model) # Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating # a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863 AutoConfig.register(self.base_model_prefix, AutoConfig) if hasattr(self.auto_model_class, "register"): self.auto_model_class.register(AutoConfig, self.__class__) + + self.maybe_apply_torch_compile() + if warmup: self._init_warmup() @@ -219,16 +125,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, - token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Union[str, Path] = HUGGINGFACE_HUB_CACHE, - subfolder: 
str = "", - local_files_only: bool = False, - torch_dtype: Optional[Union[str, "torch.dtype"]] = None, - trust_remote_code: bool = False, - file_name: Optional[str] = WEIGHTS_NAME, **kwargs, ): """ @@ -240,121 +136,23 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (Optional[Union[bool, str]], defaults to `None`): - Deprecated. Please use `token` instead. - token (Optional[Union[bool, str]], defaults to `None`): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*): - The specific model version to use. It can be a branch name, a tag name, or a commit id. - force_download (`bool`, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, Path]`, *optional*): - The path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - subfolder (`str`, *optional*) - In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, you can specify the folder name here. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). - torch_dtype (`Optional[Union[str, "torch.dtype"]]`, *optional*) - float16 or bfloat16 or float32: load in a specified dtype, ignoring the model config.torch_dtype if one exists. If not specified, the model will get loaded in float32. - trust_remote_code (`bool`, *optional*) - Allows to use custom code for the modeling hosted in the model repository. This option should only be set for repositories you trust and in which you have read the code, as it will execute on your local machine arbitrary code present in the model repository. - file_name (`str`, *optional*): - The file name of the model to load. Overwrites the default file name and allows one to load the model - with a different name. """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "Both the arguments `use_auth_token` and `token` were specified, which is not supported. Please specify only `token`." - ) - token = use_auth_token - - commit_hash = kwargs.pop("_commit_hash", None) - - model_kwargs = { - "revision": revision, - "token": token, - "cache_dir": cache_dir, - "subfolder": subfolder, - "local_files_only": local_files_only, - "force_download": force_download, - } - - if not getattr(config, "torchscript", False): - logger.warning("Detect torchscript is false. 
Convert to torchscript model!") - - if is_torch_version("<", "2.1.0"): - raise ImportError("`torch>=2.0.0` is needed to trace your model") - - task = cls.export_feature - config.torch_dtype = torch_dtype - model = TasksManager.get_model_from_task( - task, - model_id, - library_name="transformers", - trust_remote_code=trust_remote_code, - torch_dtype=torch_dtype, - _commit_hash=commit_hash, - **model_kwargs, - ) - - return cls(model, config=config, export=True, **kwargs) - - # Load the model from local directory - if os.path.isdir(model_id): - model_cache_path = os.path.join(model_id, file_name) - model_save_dir = model_id - # Download the model from the hub - else: - model_cache_path = hf_hub_download(repo_id=model_id, filename=file_name, **model_kwargs) - model_save_dir = Path(model_cache_path).parent + if getattr(config, "torchscript", False): + raise ValueError("IPEXModel is no longer support torchscript models.") - model = torch.jit.load(model_cache_path) - torch.jit.freeze(model.eval()) - - return cls(model, config=config, model_save_dir=model_save_dir, **kwargs) + model = cls.auto_model_class.from_pretrained(model_id, **kwargs) + return cls(model, config=model.config, **kwargs) def _save_pretrained(self, save_directory: Union[str, Path]): - output_path = os.path.join(save_directory, WEIGHTS_NAME) - if getattr(self.config, "torchscript", None): - torch.jit.save(self.model, output_path) - else: - logger.warning("The module is not a torchscript model, will be treated as a transformers model.") - self.model.save_pretrained(output_path) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids + self.model.save_pretrained(save_directory, safe_serialization=False) - if "position_ids" in self.input_names: - inputs["position_ids"] = position_ids + def push_to_hub(self, *args, **kwargs): + kwargs["safe_serialization"] = False + return self.model.push_to_hub(*args, **kwargs) - outputs = self._call_model(**inputs) - if isinstance(outputs, dict): - model_output = ModelOutput(**outputs) - else: - model_output = ModelOutput() - model_output[self.output_name] = outputs[0] - return model_output + @torch.no_grad() + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) def eval(self): self.model.eval() @@ -362,7 +160,7 @@ def eval(self): @property def device(self) -> torch.device: - return self._device + return self.model.device @property def dtype(self) -> torch.dtype: @@ -375,33 +173,41 @@ def model_dtype(self): ) return self._dtype + @property + def add_patch(self) -> bool: + return self._add_patch + def to(self, device: Union[torch.device, str]): - self._device = device if isinstance(device, torch.device) else torch.device(device) - self.model.to(self._device) + self.model.to(device) return self def can_generate(self): return isinstance(self, GenerationMixin) - def _call_model(self, *args, **kwargs): - try: - with torch.autocast(self.device.type, self.dtype), torch.no_grad(): - out = self.model(*args, **kwargs) - except RuntimeError: - out = self.model(*args, **kwargs) - return out + def maybe_apply_torch_compile(self): + if ( + self.model.device.type != "cpu" + or self.config.model_type in _COMPILE_NOT_READY_MODEL_TYPES + or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_COMPILE) + ): + return + 
if self.use_cache and not self._supports_static_cache: + return + from torch._inductor import config as inductor_config + + # System level optimization + inductor_config.cpp_wrapper = True + os.environ["TORCHINDUCTOR_FREEZING"] = "1" + logger.info("Enable torch.compile optimization") + self.model.forward = torch.compile(self.model.forward) + self.compiled = True def _init_warmup(self): - # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and - # the results of the compute are unpredictable - # TODO : add warmup for IPEX exported model - if not self._is_ipex_exported: - use_cache = "past_key_values" in self.input_names - dummy_inputs = _prepare_inputs_for_ipex_model(self, self.export_feature, use_cache) - if self._device.type != "cpu": - dummy_inputs = recursive_to_device(value=dummy_inputs, device=self._device) - for _ in range(2): - self(**dummy_inputs) + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + with torch.no_grad(): + self.model(**inputs) + self.model(**inputs) + logger.info("Warm up end") class IPEXModelForSequenceClassification(IPEXModel): @@ -426,98 +232,38 @@ class IPEXModelForImageClassification(IPEXModel): auto_model_class = AutoModelForImageClassification export_feature = "image-classification" - def forward( - self, - pixel_values: torch.Tensor, - **kwargs, - ): - inputs = { - "pixel_values": pixel_values, - } - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForAudioClassification(IPEXModel): auto_model_class = AutoModelForAudioClassification export_feature = "audio-classification" - def forward( - self, - input_values: torch.Tensor, - attention_mask: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_values": input_values, - } - - if "attention_mask" in self.input_names: - inputs["attention_mask"] = attention_mask - - outputs = self._call_model(**inputs) - return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) - class IPEXModelForQuestionAnswering(IPEXModel): auto_model_class = AutoModelForQuestionAnswering export_feature = "question-answering" - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - token_type_ids: torch.Tensor = None, - **kwargs, - ): - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "token_type_ids" in self.input_names: - inputs["token_type_ids"] = token_type_ids - - outputs = self._call_model(**inputs) - start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] - end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] - return ModelOutput(start_logits=start_logits, end_logits=end_logits) - class IPEXModelForCausalLM(IPEXModel, GenerationMixin): auto_model_class = AutoModelForCausalLM export_feature = "text-generation" - _supports_cache_class = False - _is_stateful = False def __init__( self, model, config: PretrainedConfig = None, - export: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, use_cache: bool = True, - warmup: bool = True, + warmup: Optional[bool] = True, **kwargs, ): - # Perform the initial warmup at the end of __init__ - super().__init__( - model, config, export=export, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache - ) + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) + if self._add_patch: + self._supports_cache_class = 
True GenerationMixin.__init__(self) model_type = self.config.model_type.replace("_", "-") self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config) - self.use_cache = "past_key_values" in self.input_names - if isinstance(model, torch.jit.RecursiveScriptModule) and use_cache ^ self.use_cache: - raise ValueError( - f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " - f"Please load your current model with `use_cache={self.use_cache}` or export the original model " - f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " - "To export your model, simply set `export=True`." - ) self.config.is_decoder = True self.config.is_encoder_decoder = False @@ -529,146 +275,33 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - if self._is_ipex_exported: - self._reorder_cache = _ipex_reorder_cache - else: - # Check if _reorder_cache is a static method - if "_reorder_cache" in self.model_cls.__dict__ and isinstance( - self.model_cls.__dict__["_reorder_cache"], staticmethod - ): - self._reorder_cache = self.model_cls._reorder_cache - elif "_reorder_cache" in self.model_cls.__dict__: - self._reorder_cache = self.model_cls._reorder_cache.__get__(self) - - if is_transformers_version(">=", "4.38.0") and model_type in { - "llama", - "phi", - "persimmon", - "mistral", - "falcon", - "gpt2", - }: - self.prepare_inputs_for_generation = _ipex_prepare_inputs_for_generation - else: - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) - if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache if hasattr(self.model_cls, "_convert_to_bloom_cache"): self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache + if warmup: self._init_warmup() - def _prepare_past_key_values(self, input_ids): - model_type = self.config.model_type.replace("_", "-") - nb_pkv = 2 - num_layers = self.normalized_config.num_layers - d_k = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads - batch_size = input_ids.shape[0] - - if model_type in {"mistral", "llama", "falcon"}: - num_attention_heads = getattr(self.normalized_config, "num_key_value_heads", 1) - else: - num_attention_heads = self.normalized_config.num_attention_heads - - if self._is_ipex_exported: - # Indirect access kv cache has a different data layout compared with most transformers model, - # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache - beam_idx_tmp = torch.zeros( - (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long - ).contiguous() - past_key_values = tuple( - [ - ( - torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(), - torch.zeros([1, 1, 1, 1]).contiguous(), - torch.zeros([1, 1, 1, 1]).contiguous(), - beam_idx_tmp, - ) - for i in range(num_layers) - ] - ) - return past_key_values - elif model_type == "bloom" and is_transformers_version("<", "4.44"): - shape_key = (batch_size * num_attention_heads, d_k, 0) - shape_value = (batch_size * num_attention_heads, 0, d_k) - key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device) - value = torch.empty(size=shape_value, dtype=self.model_dtype, device=self._device) - past_key_values = tuple( - tuple(key if idx % 2 == 0 else value for idx in range(nb_pkv)) for _ in 
range(num_layers) - ) - elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: - shape = (batch_size, 0, d_k * 2) - pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) - past_key_values = tuple(pkv for _ in range(num_layers)) - else: - shape = (batch_size, num_attention_heads, 0, d_k) - pkv = torch.empty(size=shape, dtype=self.model_dtype, device=self._device) - past_key_values = tuple(tuple(pkv for _ in range(nb_pkv)) for _ in range(num_layers)) - - return past_key_values - - # Temporary fix, will delete when https://github.com/huggingface/transformers/pull/31226 release. - def _get_initial_cache_position(self, input_ids, model_kwargs): - """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" - if not model_kwargs.get("use_cache", True): - model_kwargs["cache_position"] = None - return model_kwargs - - past_length = 0 - if "past_key_values" in model_kwargs: - past_length = model_kwargs["past_key_values"][0][0].shape[-2] - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - else: - cur_len = input_ids.shape[-1] - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) - return model_kwargs - + @torch.no_grad() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - position_ids: Optional[torch.FloatTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: - # 1. Prepare model inputs - if attention_mask is None: + if self.add_patch and input_ids is not None and attention_mask is None: attention_mask = torch.ones_like(input_ids) - - inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - - if "position_ids" in self.input_names or not self.input_names: - inputs["position_ids"] = position_ids - - if self.use_cache: - if past_key_values is None: - past_key_values = self._prepare_past_key_values(input_ids) - - inputs["past_key_values"] = past_key_values - - # 2. Model forward - outputs = self._call_model(**inputs) - - # 3. 
Process model outputs - if isinstance(outputs, (list, tuple)): - logits = outputs[0] - past_key_values = outputs[1] if self.use_cache else None - else: - logits = outputs["logits"] - past_key_values = outputs["past_key_values"] if self.use_cache else None - - return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) + return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) def _prepare_generation_config( self, generation_config: Optional[GenerationConfig], **kwargs: Dict ) -> Tuple[GenerationConfig, Dict]: generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) generation_method = generation_config.get_generation_mode().value + if self.compiled and generation_config.cache_implementation != "ipex_paged" and self._supports_static_cache: + # Use static cache for torch compile + generation_config.cache_implementation = "static" if generation_method not in _IPEX_EXPORTED_GENERATION_METHODS: raise ValueError( f"The generation method {generation_method} is not supported for IPEXModelForCausalLM for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" @@ -676,15 +309,32 @@ def _prepare_generation_config( return generation_config, model_kwargs + def _reorder_cache(self, *args, **kwargs): + return self.model._reorder_cache(*args, **kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + def generate(self, *args, **kwargs): - if is_ipex_version("<", "2.4.0") and self._is_ipex_exported and kwargs.get("assistant_model", None): + if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None): raise ValueError( f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" ) - # Patch functions to support IAKV cache - if self._is_ipex_exported and kwargs.get("assistant_model", None): + # Patch functions to support ipex_paged cache + if self._add_patch: + transformers.generation.utils.NEED_SETUP_CACHE_CLASSES_MAPPING["ipex_paged"] = IPEXPagedCache + self.generation_config.cache_implementation = "ipex_paged" + if is_transformers_version(">=", "4.45.0"): + if "ipex_paged" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS: + transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("ipex_paged") + if kwargs.get("generation_config", None): + # Change cache implementation temporarily + orig_cache_implementation = kwargs["generation_config"].cache_implementation + kwargs["generation_config"].cache_implementation = "ipex_paged" + + if self._add_patch and kwargs.get("assistant_model", None): transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values - elif self._is_ipex_exported: + elif self._add_patch: transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values try: @@ -694,100 +344,100 @@ def generate(self, *args, **kwargs): transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values raise e - if self._is_ipex_exported and kwargs.get("assistant_model", None): + if self._add_patch and kwargs.get("assistant_model", None): transformers.generation.utils._crop_past_key_values = _crop_past_key_values transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values + # change back cache_implementation + if self._add_patch and kwargs.get("generation_config", None): + 
kwargs["generation_config"].cache_implementation = orig_cache_implementation + return result + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") -def _ipex_prepare_inputs_for_generation( - input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs -): - from transformers.cache_utils import Cache - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - -def _ipex_reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor -) -> Tuple[Tuple[torch.Tensor]]: - # Ipex patched model uses indirect access kv cache which has a different shape with other transformers models - if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1: - for layer_past in past_key_values: - layer_past[3][layer_past[0].size(-2) - 1] = beam_idx - return past_key_values - elif len(past_key_values[0]) == 8: - for layer_past in past_key_values: - layer_past[3][layer_past[0].size(-2) - 1] = beam_idx - layer_past[7][layer_past[0].size(-2) - 1] = beam_idx - return past_key_values - else: - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) +class IPEXModelForSeq2SeqLM(IPEXModel, GenerationMixin): + auto_model_class = AutoModelForSeq2SeqLM + export_feature = "text2text-generation" + + def __init__( + self, + model, + config: PretrainedConfig = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + use_cache: bool = True, + warmup: Optional[bool] = True, + **kwargs, + ): + super().__init__(model, config, model_save_dir=model_save_dir, warmup=False, use_cache=use_cache) + GenerationMixin.__init__(self) + + model_type = self.config.model_type.replace("_", "-") + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(self.config) + + self.config.is_decoder = False + self.config.is_encoder_decoder = True + + self.generation_config = GenerationConfig.from_model_config(self.config) + try: + self.model_cls = get_class_from_dynamic_module( + self.config.auto_map["AutoModelForSeq2SeqLM"], model_save_dir + ) + except AttributeError: + self.model_cls = get_model_class(self.config, AutoModelForSeq2SeqLM._model_mapping) + + if hasattr(self.model_cls, "_convert_to_standard_cache"): + self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache + + if warmup: + self._init_warmup() + + @torch.no_grad() + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + def _prepare_generation_config( + self, generation_config: Optional[GenerationConfig], **kwargs: Dict + ) -> Tuple[GenerationConfig, Dict]: + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) + # Use static cache 
for torch.compile + if self.compiled: + generation_config.cache_implementation = "static" + + return generation_config, model_kwargs + + def _reorder_cache(self, *args, **kwargs): + return self.model._reorder_cache(*args, **kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def get_encoder(self, *args, **kwargs): + return self.model.get_encoder(*args, **kwargs) + + def _init_warmup(self): + inputs = prepare_jit_inputs(self.model, self.export_feature, False) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + self.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=4) + logger.info("Warm up end") def _ipex_crop_past_key_values(model, past_key_values, max_length): if isinstance(model, IPEXModel) and _is_patched_with_ipex(model, "text-generation"): - new_past_key_values = [] - for i in range(len(past_key_values)): - pkv = [] - pkv.append(past_key_values[i][0][:, :max_length, :max_length, :]) - pkv += [past_key_values[i][_] for _ in range(1, 4)] - new_past_key_values.append(tuple(pkv)) - new_past_key_values = tuple(new_past_key_values) - return new_past_key_values + if isinstance(past_key_values, IPEXPagedCache): + # .crop is an inplace op, returns None + past_key_values = past_key_values.crop(max_length) + return past_key_values + else: + raise ValueError("only support IPEXPagedCache input now") return _crop_past_key_values(model, past_key_values, max_length) diff --git a/optimum/intel/ipex/modeling_sentence_transformers.py b/optimum/intel/ipex/modeling_sentence_transformers.py new file mode 100644 index 0000000000..8a4f3704c3 --- /dev/null +++ b/optimum/intel/ipex/modeling_sentence_transformers.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
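For illustration, the new optimum/intel/ipex/modeling_sentence_transformers.py module introduced here routes sentence-transformers model loading through the IPEX backend via the IPEXTransformer and IPEXSentenceTransformer classes defined below. A minimal usage sketch, assuming an ordinary embedding checkpoint; the model id and input sentences are illustrative assumptions, not part of this patch:

from optimum.intel import IPEXSentenceTransformer

# Hypothetical checkpoint chosen for illustration; any standard sentence-transformers model should work.
model = IPEXSentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["IPEX-backed sentence embeddings", "a second example sentence"])
print(embeddings.shape)  # e.g. (2, 384) for a MiniLM-L6 checkpoint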
+ + +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from sentence_transformers import SentenceTransformer +from sentence_transformers.models import Transformer +from sentence_transformers.models.Transformer import _save_pretrained_wrapper +from sentence_transformers.util import import_from_string +from transformers import MT5Config, T5Config +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from .modeling_base import IPEXModel + + +class IPEXTransformer(Transformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backend = "ipex" + + def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None: + self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args) + + def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None: + if isinstance(config, T5Config) or isinstance(config, MT5Config): + raise ValueError("T5 models are not yet supported by the IPEX backend.") + + export = model_args.pop("export", None) + + if export is None: + export = not getattr(config, "torchscript", False) + + load_path = Path(model_name_or_path) + is_local = load_path.exists() + + self.auto_model = IPEXModel.from_pretrained( + model_name_or_path, + config=config, + cache_dir=cache_dir, + export=export, + **model_args, + ) + + # Wrap the save_pretrained method to save the model in the correct subfolder + self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, "ipex") + + # Warn the user to save the model if they haven't already + if export: + self._backend_warn_to_save(model_name_or_path, is_local, "IPEX") + + +class IPEXSentenceTransformer(SentenceTransformer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.backend = "ipex" + + def _load_module_class_from_ref( + self, + class_ref: str, + model_name_or_path: str, + trust_remote_code: bool, + revision: Optional[str] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.nn.Module: + if class_ref.startswith("sentence_transformers."): + if class_ref == "sentence_transformers.models.Transformer": + class_ref = "optimum.intel.ipex.modeling_sentence_transformers.IPEXTransformer" + return import_from_string(class_ref) + + if trust_remote_code: + code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None + try: + return get_class_from_dynamic_module( + class_ref, + model_name_or_path, + revision=revision, + code_revision=code_revision, + ) + except OSError: + # Ignore the error if the file does not exist, and fall back to the default import + pass + + return import_from_string(class_ref) diff --git a/optimum/intel/ipex/utils.py b/optimum/intel/ipex/utils.py index 3d3feb3db2..23126bcd4c 100644 --- a/optimum/intel/ipex/utils.py +++ b/optimum/intel/ipex/utils.py @@ -16,6 +16,7 @@ _HEAD_TO_AUTOMODELS = { "feature-extraction": "IPEXModel", "text-generation": "IPEXModelForCausalLM", + "text2text-generation": "IPEXModelForSeq2SeqLM", "text-classification": "IPEXModelForSequenceClassification", "token-classification": "IPEXModelForTokenClassification", "question-answering": "IPEXModelForQuestionAnswering", diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py index 6ca9fd661d..92e7fc57b9 100644 --- a/optimum/intel/neural_compressor/quantization.py +++ b/optimum/intel/neural_compressor/quantization.py @@ -374,22 +374,21 @@ def 
_weight_only_quantization( } low_cpu_mem_usage = True - if use_xpu: - try: - # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device. - model = model_class.from_pretrained( - model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs - ) - except NotImplementedError: - logger.info( - "Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption." - ) - low_cpu_mem_usage = False - model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs) - quantization_config.update(**{"device": "xpu"}) - quantization_config.post_init_xpu() + + if getattr(quantization_config, "use_layer_wise", False): + if is_neural_compressor_version(">=", "3.2"): + from neural_compressor.torch import load_empty_model + + model = load_empty_model(model_id, cls=model_class, **loading_kwargs) + else: + raise ValueError("INC version must be >= 3.2 when use_layer_wise is set to True in quantization_config.") else: model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs) + + if use_xpu: + quantization_config.update(**{"device": "xpu"}) + quantization_config.post_init_xpu() + else: quantization_config.post_init_cpu() model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage}) diff --git a/optimum/intel/neural_compressor/utils.py b/optimum/intel/neural_compressor/utils.py index c7a7ceda72..80c5f78528 100644 --- a/optimum/intel/neural_compressor/utils.py +++ b/optimum/intel/neural_compressor/utils.py @@ -30,10 +30,9 @@ CONFIG_NAME = "best_configure.yaml" QUANTIZATION_CONFIG_NAME = "quantize_config.json" +IPEX_MINIMUM_VERSION = "2.4.0" NEURAL_COMPRESSOR_MINIMUM_VERSION = "2.1.0" NEURAL_COMPRESSOR_WEIGHT_ONLY_MINIMUM_VERSION = "2.3.0" -IPEX_MINIMUM_VERSION = "2.3.1" - _HEAD_TO_AUTOMODELS = { "fill-mask": "INCModelForMaskedLM", diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index b34cd84cd0..4fdfe368a2 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -26,7 +26,7 @@ from optimum.configuration_utils import BaseConfig from ..utils.import_utils import is_nncf_available -from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS +from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS if is_nncf_available(): @@ -123,11 +123,18 @@ class OVQuantizationMethod(str, Enum): "mistralai/Mistral-7B-v0.1": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.9}, "baichuan-inc/Baichuan2-7B-Chat": { "bits": 4, - "sym": True, + "sym": False, "group_size": 128, "ratio": 0.8, + }, + "baichuan-inc/Baichuan2-13B-Chat": { + "bits": 4, + "sym": False, + "group_size": 128, + "ratio": 1.0, "dataset": "wikitext2", "quant_method": OVQuantizationMethod.AWQ, + "scale_estimation": True, }, "lmsys/longchat-7b-16k": { "bits": 4, @@ -255,6 +262,10 @@ def __init__( sym: bool = False, ignored_scope: Optional[dict] = None, num_samples: Optional[int] = None, + dataset: Optional[Optional[Union[str, List[str]]]] = None, + tokenizer: Optional[str] = None, + processor: Optional[str] = None, + trust_remote_code: bool = False, **kwargs, ): """ @@ -272,6 +283,10 @@ def __init__( self.bits = bits self.sym = sym self.num_samples = num_samples + self.dataset = dataset + self.tokenizer = tokenizer + self.processor = processor + self.trust_remote_code = trust_remote_code if isinstance(ignored_scope, 
nncf.IgnoredScope): ignored_scope = ignored_scope.__dict__ @@ -313,6 +328,10 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase): user or organization name, like `dbmdz/bert-base-german-cased`. - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set + for repositories you trust and in which you have read the code, as it will execute on your local machine + arbitrary code present in the model repository. dataset (`str or List[str]`, *optional*): The dataset used for data-aware compression with NNCF. - For language models you can provide your own dataset in a list of strings or just use one from the list @@ -325,6 +344,8 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase): ratio (`float`, defaults to 1.0): The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM and the rest to INT8_ASYM). + Note: If dataset is provided, and the ratio is less than 1.0, then data-aware mixed precision assignment + will be applied. all_layers (`bool`, *optional*): Defines how many layers are compressed to 4-bits while the rest are kept in 8-bit precision. sensitivity_metric (`str`, *optional*): @@ -395,10 +416,16 @@ def __init__( backup_precision: Optional[str] = None, **kwargs, ): - super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples) - self.tokenizer = tokenizer - self.trust_remote_code = trust_remote_code - self.dataset = dataset + super().__init__( + bits=bits, + sym=sym, + ignored_scope=ignored_scope, + num_samples=num_samples, + dataset=dataset, + tokenizer=tokenizer, + processor=processor, + trust_remote_code=trust_remote_code, + ) self.group_size = group_size or (-1 if bits == 8 else 128) self.ratio = ratio self.all_layers = all_layers @@ -407,7 +434,6 @@ def __init__( self.scale_estimation = scale_estimation self.weight_format = weight_format self.gptq = gptq - self.processor = processor self.lora_correction = lora_correction self.backup_precision = backup_precision self.post_init() @@ -417,7 +443,7 @@ def post_init(self): Safety checker that arguments are correct """ super().post_init() - if self.ratio is not None and not (0 <= self.ratio <= 1): + if not (0 <= self.ratio <= 1): raise ValueError("`ratio` must between 0 and 1.") if self.group_size is not None and self.group_size != -1 and self.group_size <= 0: raise ValueError("`group_size` must be greater than 0 or equal to -1") @@ -437,6 +463,18 @@ def post_init(self): or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}""" ) + if self.dataset is not None and not ( + self.quant_method == OVQuantizationMethod.AWQ + or self.scale_estimation + or self.gptq + or self.lora_correction + or (self.ratio < 1.0 and self.sensitivity_metric != nncf.SensitivityMetric.WEIGHT_QUANTIZATION_ERROR) + ): + logger.warning( + "The provided dataset won't have any effect on the resulting compressed model because no data-aware " + "quantization algorithm is selected and compression ratio is 1.0." 
+ ) + if self.bits not in [4, 8]: raise ValueError(f"Only support quantization to [4,8] bits but found {self.bits}") @@ -535,6 +573,11 @@ def __init__( model_type: str = "transformer", fast_bias_correction: bool = True, overflow_fix: str = "disable", + dataset: Optional[str] = None, + tokenizer: Optional[str] = None, + processor: Optional[str] = None, + trust_remote_code: bool = False, + smooth_quant_alpha: Optional[float] = None, **kwargs, ): """ @@ -557,11 +600,42 @@ def __init__( Whether to apply fast or full bias correction algorithm. overflow_fix (`str`, default to "disable"): Parameter for controlling overflow fix setting. + dataset (`str`, *optional*): + The dataset used for quantization. For speech-to-text model quantization the allowed value is 'librispeech'. + tokenizer (`str`, *optional*): + The tokenizer used to process the dataset. You can pass either: + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + processor (`str`, *optional*): + A transformers processor used to process inputs for multi-modal models. You can pass either: + - A string, the *model id* of a predefined processor hosted inside a model repo on huggingface.co. + - A path to a *directory* containing files required by the processor, for instance saved + using the [`~AutoProcessor.save_pretrained`] method, e.g., `./my_model_directory/`. + trust_remote_code (`bool`, defaults to `False`): + Allows to use custom code for the modeling hosted in the model repository. This option should only be set + for repositories you trust and in which you have read the code, as it will execute on your local machine + arbitrary code present in the model repository. + smooth_quant_alpha (`float`, *optional*): + SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and + reduces quantization error. """ - super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples) + super().__init__( + bits=bits, + sym=sym, + ignored_scope=ignored_scope, + num_samples=num_samples, + dataset=dataset, + tokenizer=tokenizer, + processor=processor, + trust_remote_code=trust_remote_code, + ) self.model_type = model_type self.fast_bias_correction = fast_bias_correction self.overflow_fix = overflow_fix + self.smooth_quant_alpha = smooth_quant_alpha self.post_init() def post_init(self): @@ -573,6 +647,18 @@ def post_init(self): if self.bits != 8: raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}") + if self.dataset is not None: + if self.dataset not in PREDEFINED_SPEECH_TO_TEXT_DATASETS: + raise ValueError( + f"You have entered the following string value for dataset: {self.dataset}. But it is not supported." + f" Currently you can only choose {list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())}."
+ ) + + if self.smooth_quant_alpha is not None and not (0 <= self.smooth_quant_alpha <= 1): + raise ValueError( + f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}" + ) + class OVConfig(BaseConfig): CONFIG_NAME = "openvino_config.json" diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index b765d2f9f4..9ee35f3687 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -50,8 +50,6 @@ XVectorOutput, ) -from optimum.exporters import TasksManager - from ..utils.import_utils import is_timm_available, is_timm_version from .modeling_base import OVBaseModel from .utils import _is_timm_ov_dir @@ -695,7 +693,7 @@ class OVModelForCTC(OVModel): """ auto_model_class = AutoModelForCTC - export_feature = TasksManager.infer_task_from_model(auto_model_class) + export_feature = "automatic-speech-recognition" @add_start_docstrings_to_model_forward( AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") @@ -775,7 +773,7 @@ class OVModelForAudioXVector(OVModel): """ auto_model_class = AutoModelForAudioXVector - export_feature = TasksManager.infer_task_from_model(auto_model_class) + export_feature = "audio-xvector" @add_start_docstrings_to_model_forward( AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") @@ -851,7 +849,7 @@ class OVModelForAudioFrameClassification(OVModel): """ auto_model_class = AutoModelForAudioFrameClassification - export_feature = TasksManager.infer_task_from_model(auto_model_class) + export_feature = "audio-frame-classification" @add_start_docstrings_to_model_forward( AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index d5d5666891..99422f1a54 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -30,7 +30,7 @@ from transformers.generation import GenerationMixin from transformers.utils import is_offline_mode -from optimum.exporters.onnx import OnnxConfig +from optimum.exporters.base import ExportConfig from optimum.modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel from ...exporters.openvino import export, main_export @@ -279,7 +279,6 @@ def _compile_model( compiled_model = core.compile_model(model, device.upper() if device is not None else device, config=ov_config) if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: - logger.info(f"{device if device is not None else 'AUTO'} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(compiled_model) return compiled_model @@ -624,7 +623,7 @@ def _to_load( cls, model, config: PretrainedConfig, - onnx_config: OnnxConfig, + onnx_config: ExportConfig, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -802,7 +801,6 @@ def _compile(self): self.request = core.compile_model(self.model, self._device, self.ov_config) # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: - logger.info(f"{self._device} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(self.request) @property diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 9c53994b8c..e86c5a8f02 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ 
b/optimum/intel/openvino/modeling_diffusion.py @@ -979,7 +979,6 @@ def _compile(self): self.request = core.compile_model(self.model, self._device, self.ov_config) # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: - logger.info(f"{self._device} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(self.request) def to(self, *args, device: Optional[str] = None, dtype: Optional[torch.dtype] = None): diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index 0ccf78a361..fa48430a77 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy import logging import os from pathlib import Path @@ -35,7 +35,9 @@ from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from .. import OVConfig, OVQuantizer from ..utils import is_transformers_version +from .configuration import OVQuantizationConfig, OVQuantizationConfigBase from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties @@ -550,7 +552,6 @@ def _compile(self): self.request = core.compile_model(self.model, self._device, ov_config) # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: - logger.info(f"{self._device} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(self.request) @@ -691,7 +692,6 @@ def _compile(self): self.request = compiled_model.create_infer_request() # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: - logger.info(f"{self._device} SUPPORTED_PROPERTIES:") _print_compiled_model_properties(compiled_model) @@ -975,9 +975,25 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", + load_in_8bit: bool = False, + quantization_config: Union[dict, OVQuantizationConfigBase] = None, **kwargs, ): - return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) + compile_only = kwargs.get("compile_only", False) + + if not compile_only and isinstance(quantization_config, OVQuantizationConfig): + model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained( + model_id, config, load_in_8bit=False, **kwargs + ) + quantization_config_copy = copy.deepcopy(quantization_config) + quantization_config_copy.processor = quantization_config.processor or model_id + OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy)) + else: + model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained( + model_id, config, load_in_8bit=load_in_8bit, quantization_config=quantization_config, **kwargs + ) + + return model class DummyWhisperModel: def __init__(self): diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index 41c879e481..77781f61c0 100644 --- 
a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -28,7 +28,6 @@ from ...exporters.openvino import main_export from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name from ...exporters.openvino.utils import save_config -from .. import OVQuantizer from ..utils.import_utils import is_transformers_version from .configuration import OVConfig, OVWeightQuantizationConfig from .modeling_base import OVBaseModel, OVModelPart @@ -561,6 +560,8 @@ def _from_pretrained( ) if to_quantize: + from optimum.intel.openvino.quantization import OVQuantizer + quantization_config_copy = copy.deepcopy(quantization_config) quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id potential_processor_id = config.mm_vision_tower if isinstance(model, _OVNanoLlavaForCausalLM) else model_id diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index c6a625bd7b..6f739e2543 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -59,7 +59,13 @@ is_diffusers_available, ) from ..utils.modeling_utils import get_model_device -from .configuration import OVConfig, OVQuantizationConfig, OVQuantizationMethod, OVWeightQuantizationConfig +from .configuration import ( + OVConfig, + OVQuantizationConfig, + OVQuantizationConfigBase, + OVQuantizationMethod, + OVWeightQuantizationConfig, +) from .modeling_base import OVBaseModel from .utils import ( MAX_ONNX_OPSET, @@ -67,6 +73,7 @@ ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, PREDEFINED_SD_DATASETS, + PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS, ) @@ -319,6 +326,7 @@ def _quantize_ovbasemodel( remove_unused_columns: bool = True, **kwargs, ): + from optimum.intel.openvino.modeling_seq2seq import _OVModelForWhisper from optimum.intel.openvino.modeling_visual_language import OVModelForVisualCausalLM if is_diffusers_available(): @@ -344,7 +352,7 @@ def _quantize_ovbasemodel( data_collator=data_collator, ) if self.model.export_feature == "text-generation" and self.model.use_cache: - calibration_dataset = self._prepare_text_generation_dataset( + calibration_dataset = self._prepare_text_generation_calibration_data( quantization_config, calibration_dataloader ) else: @@ -357,31 +365,31 @@ def _quantize_ovbasemodel( f"`nncf.Dataset` or `datasets.Dataset`. Found: {type(calibration_dataset)}." ) - if isinstance(quantization_config, OVWeightQuantizationConfig): - if quantization_config.dataset is not None and calibration_dataset is not None: - logger.info( - "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " - "quantization. Will rely on `calibration_dataset`." - ) - - if calibration_dataset is None and quantization_config.dataset is not None: - from optimum.intel import OVModelForCausalLM + if quantization_config.dataset is not None and calibration_dataset is not None: + logger.info( + "Both `quantization_config.dataset` and `calibration_dataset` were provided for weight only " + "quantization. Will rely on `calibration_dataset`." 
+ ) - if isinstance(self.model, OVModelForCausalLM): - calibration_dataset = self._prepare_causal_lm_dataset(quantization_config) - elif isinstance(self.model, OVModelForVisualCausalLM): - calibration_dataset = self._prepare_visual_causal_lm_dataset(quantization_config) - elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): - if not isinstance(quantization_config.dataset, str): - raise ValueError("Please provide dataset as one of the accepted dataset labels.") - calibration_dataset = self._prepare_unet_dataset( - quantization_config.num_samples, dataset_name=quantization_config.dataset - ) - else: - raise ValueError( - f"Can't create weight compression calibration dataset from string for {type(self.model)}" - ) + if calibration_dataset is None and quantization_config.dataset is not None: + from optimum.intel import OVModelForCausalLM + + if isinstance(self.model, OVModelForCausalLM): + calibration_dataset = self._prepare_causal_lm_calibration_data(quantization_config) + elif isinstance(self.model, OVModelForVisualCausalLM): + calibration_dataset = self._prepare_visual_causal_lm_calibration_data(quantization_config) + elif isinstance(self.model, _OVModelForWhisper): + calibration_dataset = self._prepare_speech_to_text_calibration_data(quantization_config) + elif is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): + if not isinstance(quantization_config.dataset, str): + raise ValueError("Please provide dataset as one of the accepted dataset labels.") + calibration_dataset = self._prepare_unet_dataset( + quantization_config.num_samples, dataset_name=quantization_config.dataset + ) + else: + raise ValueError(f"Can't create quantization calibration dataset from string for {type(self.model)}") + if isinstance(quantization_config, OVWeightQuantizationConfig): if quantization_config.quant_method == OVQuantizationMethod.HYBRID: if calibration_dataset is None: raise ValueError("Calibration dataset is required to run hybrid quantization.") @@ -399,22 +407,24 @@ def _quantize_ovbasemodel( ] sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config_copy) + _weight_only_quantization(sub_model.model, quantization_config_copy, **kwargs) if self.model.unet is not None: # Apply hybrid quantization to UNet self.model.unet.model = _hybrid_quantization( - self.model.unet.model, quantization_config, calibration_dataset + self.model.unet.model, quantization_config, calibration_dataset, **kwargs ) else: self.model.transformer.model = _hybrid_quantization( - self.model.transformer.model, quantization_config, calibration_dataset + self.model.transformer.model, quantization_config, calibration_dataset, **kwargs ) self.model.clear_requests() else: # The model may be for example OVModelForImageClassification, OVModelForAudioClassification, etc. 
- self.model.model = _hybrid_quantization(self.model.model, quantization_config, calibration_dataset) + self.model.model = _hybrid_quantization( + self.model.model, quantization_config, calibration_dataset, **kwargs + ) self.model.request = None else: if is_diffusers_available() and isinstance(self.model, OVDiffusionPipeline): @@ -429,47 +439,36 @@ def _quantize_ovbasemodel( ] sub_models = filter(lambda x: x, (getattr(self.model, name) for name in sub_model_names)) for sub_model in sub_models: - _weight_only_quantization(sub_model.model, quantization_config) + _weight_only_quantization(sub_model.model, quantization_config, **kwargs) self.model.clear_requests() elif isinstance(self.model, OVModelForVisualCausalLM): language_model = self.model.language_model - _weight_only_quantization(language_model.model, quantization_config, calibration_dataset) + _weight_only_quantization(language_model.model, quantization_config, calibration_dataset, **kwargs) sub_model_names = ["vision_embeddings", "text_embeddings"] + self.model.additional_parts sub_models = [getattr(self.model, f"{name}_model") for name in sub_model_names] for sub_model in sub_models: - _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True)) + _weight_only_quantization(sub_model, OVWeightQuantizationConfig(bits=8, sym=True), **kwargs) self.model.clear_requests() else: - _weight_only_quantization(self.model.model, quantization_config, calibration_dataset) + _weight_only_quantization(self.model.model, quantization_config, calibration_dataset, **kwargs) self.model.request = None - if save_directory is not None: - self.model.save_pretrained(save_directory) - ov_config.save_pretrained(save_directory) - return + else: + if not isinstance(quantization_config, OVQuantizationConfig): + raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") - if not isinstance(quantization_config, OVQuantizationConfig): - raise ValueError(f"Unsupported type of quantization config: {type(quantization_config)}") - - if calibration_dataset is None: - raise ValueError("Calibration dataset is required to run quantization.") - - # Actual model quantization - quantized_model = nncf.quantize( - self.model.model, - calibration_dataset, - subset_size=quantization_config.num_samples, - ignored_scope=quantization_config.get_ignored_scope_instance(), - model_type=nncf.ModelType(quantization_config.model_type), - preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED, - fast_bias_correction=quantization_config.fast_bias_correction, - advanced_parameters=nncf.AdvancedQuantizationParameters( - overflow_fix=OverflowFix(quantization_config.overflow_fix) - ), - **kwargs, - ) + if calibration_dataset is None: + raise ValueError("Calibration dataset is required to run quantization.") + + # Quantize model(s) + if isinstance(self.model, _OVModelForWhisper): + self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs) + else: + quantized_model = _full_quantization( + self.model.model, quantization_config, calibration_dataset, **kwargs + ) + self.model.model = quantized_model + self.model.request = None - self.model.model = quantized_model - self.model.request = None if save_directory is not None: self.model.save_pretrained(save_directory) ov_config.save_pretrained(save_directory) @@ -725,7 +724,7 @@ def _remove_unused_columns(self, dataset: "Dataset"): ignored_columns = list(set(dataset.column_names) - set(self._signature_columns)) return 
dataset.remove_columns(ignored_columns) - def _prepare_causal_lm_dataset(self, quantization_config: OVWeightQuantizationConfig): + def _prepare_causal_lm_calibration_data(self, quantization_config: OVQuantizationConfigBase): from optimum.gptq.data import get_dataset, prepare_dataset tokenizer = AutoTokenizer.from_pretrained( @@ -748,7 +747,7 @@ def _prepare_causal_lm_dataset(self, quantization_config: OVWeightQuantizationCo return calibration_dataset - def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): + def _prepare_visual_causal_lm_calibration_data(self, config: OVQuantizationConfigBase): dataset_name = config.dataset if dataset_name not in PREDEFINED_VISUAL_LM_DATASETS: raise ValueError( @@ -770,8 +769,8 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): tokenizer = None dataset_metadata = PREDEFINED_VISUAL_LM_DATASETS[dataset_name] - dataset = datasets.load_dataset(dataset_metadata["name"], split=dataset_metadata["split"]).shuffle(seed=0) - num_samples = min(config.num_samples or 128, len(dataset)) + dataset = datasets.load_dataset(dataset_metadata["id"], split=dataset_metadata["split"]).shuffle(seed=0) + num_samples = min(config.num_samples or 32, len(dataset)) dataset = islice(dataset, num_samples) calibration_dataset = [] @@ -809,8 +808,75 @@ def _prepare_visual_causal_lm_dataset(self, config: OVWeightQuantizationConfig): calibration_dataset = nncf.Dataset(calibration_dataset) return calibration_dataset - def _prepare_text_generation_dataset( - self, quantization_config: OVQuantizationConfig, calibration_dataloader: OVDataLoader + def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigBase): + if not is_datasets_available(): + raise ValueError(DATASETS_IMPORT_ERROR.format("OVQuantizer._prepare_whisper_calibration_data")) + + from datasets import load_dataset + + encoder_calibration_data = [] + encoder_model = self.model.encoder + encoder_model._compile() + encoder_model.request = InferRequestWrapper( + encoder_model.request, encoder_calibration_data, apply_caching=True + ) + + decoder_calibration_data = [] + decoder_model = self.model.decoder + decoder_model._compile() + decoder_model.request = InferRequestWrapper( + decoder_model.request, decoder_calibration_data, apply_caching=True + ) + + decoder_w_p_calibration_data = [] + decoder_w_p_model = self.model.decoder_with_past + decoder_w_p_model._compile() + decoder_w_p_model.request = InferRequestWrapper( + decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True + ) + + dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset] + + processor = AutoProcessor.from_pretrained(config.processor) + + try: + dataset = load_dataset( + dataset_metadata["id"], + dataset_metadata["name"], + split=dataset_metadata["split"], + streaming=True, + trust_remote_code=config.trust_remote_code, + ) + num_samples = config.num_samples or 128 + + audio_inputs = [] + # Download audio inputs beforehand to avoid possible connection issues + for item in tqdm(islice(dataset, num_samples), desc="Downloading audio inputs", total=num_samples): + audio = item + for key_name in dataset_metadata["inputs"]["audio"]: + audio = audio[key_name] + + sampling_rate = item + for key_name in dataset_metadata["inputs"]["sampling_rate"]: + sampling_rate = sampling_rate[key_name] + audio_inputs.append((audio, sampling_rate)) + + for audio, sampling_rate in tqdm(audio_inputs, desc="Collecting calibration data"): + input_features = processor(audio, 
sampling_rate=sampling_rate, return_tensors="pt").input_features + self.model.generate(input_features) + finally: + encoder_model.request = encoder_model.request.request + decoder_model.request = decoder_model.request.request + decoder_w_p_model.request = decoder_w_p_model.request.request + + return ( + nncf.Dataset(encoder_calibration_data), + nncf.Dataset(decoder_calibration_data), + nncf.Dataset(decoder_w_p_calibration_data), + ) + + def _prepare_text_generation_calibration_data( + self, quantization_config: OVQuantizationConfigBase, calibration_dataloader: OVDataLoader ) -> nncf.Dataset: # Prefetch past_key_values self.model.update_pkv_precision(True) @@ -898,11 +964,44 @@ def transform_fn(data_item): calibration_dataset = nncf.Dataset(calibration_data[:num_samples]) return calibration_dataset + def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kwargs): + # Quantize encoder model + # quantization_config.num_samples of audio samples result in more actual model inputs + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[0].get_length() + quantized_encoder_model = _full_quantization( + self.model.encoder_model, config, calibration_dataset[0], **kwargs + ) + self.model.encoder_model = quantized_encoder_model + self.model.encoder.model = quantized_encoder_model + self.model.encoder.request = None + + # Quantize decoder model + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[1].get_length() + quantized_decoder_model = _full_quantization( + self.model.decoder_model, config, calibration_dataset[1], **kwargs + ) + self.model.decoder_model = quantized_decoder_model + self.model.decoder.model = quantized_decoder_model + self.model.decoder.request = None + + # Quantize decoder with past model + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[2].get_length() + quantized_decoder_w_p_model = _full_quantization( + self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs + ) + self.model.decoder_with_past_model = quantized_decoder_w_p_model + self.model.decoder_with_past.model = quantized_decoder_w_p_model + self.model.decoder_with_past.request = None + def _weight_only_quantization( model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict], calibration_dataset: Optional[Union[nncf.Dataset, Iterable]] = None, + **kwargs, ) -> openvino.runtime.Model: config = quantization_config if isinstance(config, dict): @@ -950,9 +1049,40 @@ def _weight_only_quantization( gptq=config.gptq, lora_correction=config.lora_correction, backup_mode=None if config.backup_precision is None else nncf.BackupMode(config.backup_precision), + **kwargs, ) +def _full_quantization( + model: openvino.runtime.Model, + quantization_config: OVQuantizationConfig, + calibration_dataset: nncf.Dataset, + **kwargs, +): + advanced_parameters_kwargs = {} + if quantization_config.smooth_quant_alpha is not None: + advanced_parameters_kwargs["smooth_quant_alphas"] = AdvancedSmoothQuantParameters( + matmul=quantization_config.smooth_quant_alpha + ) + + quantized_model = nncf.quantize( + model, + calibration_dataset, + subset_size=quantization_config.num_samples, + ignored_scope=quantization_config.get_ignored_scope_instance(), + model_type=nncf.ModelType(quantization_config.model_type), + preset=nncf.QuantizationPreset.PERFORMANCE if quantization_config.sym else nncf.QuantizationPreset.MIXED, + fast_bias_correction=quantization_config.fast_bias_correction, 
+ advanced_parameters=nncf.AdvancedQuantizationParameters( + overflow_fix=OverflowFix(quantization_config.overflow_fix), + **advanced_parameters_kwargs, + ), + **kwargs, + ) + + return quantized_model + + def _get_operation_const_op(operation, const_port_id: int): node = operation.input_value(const_port_id).get_node() queue = deque([node]) @@ -999,7 +1129,7 @@ def _collect_ops_with_weights(model): def _hybrid_quantization( - model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: nncf.Dataset + model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: nncf.Dataset, **kwargs ) -> openvino.runtime.Model: """ Quantize a model in hybrid mode with NNCF which means that we quantize: @@ -1020,8 +1150,10 @@ def _hybrid_quantization( wc_config = copy.deepcopy(quantization_config) wc_config.ignored_scope = wc_config.ignored_scope or {} - wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + ["Convolution"] - compressed_model = _weight_only_quantization(model, wc_config) + + wc_ignored_types = ["Convolution"] if any(op.get_type_name() == "Convolution" for op in model.get_ops()) else [] + wc_config.ignored_scope["types"] = wc_config.ignored_scope.get("types", []) + wc_ignored_types + compressed_model = _weight_only_quantization(model, wc_config, **kwargs) ptq_ignored_scope = quantization_config.get_ignored_scope_instance() ptq_ignored_scope.names += ops_to_compress @@ -1037,5 +1169,6 @@ def _hybrid_quantization( smooth_quant_alphas=AdvancedSmoothQuantParameters(matmul=-1) ), subset_size=subset_size, + **kwargs, ) return quantized_model diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py index 1ba740c3ba..e54503e83d 100644 --- a/optimum/intel/openvino/utils.py +++ b/optimum/intel/openvino/utils.py @@ -142,12 +142,21 @@ PREDEFINED_VISUAL_LM_DATASETS = { "contextual": { - "name": "ucla-contextual/contextual_test", + "id": "ucla-contextual/contextual_test", "split": "test", "inputs": {"image_url": "image_url", "instruction": "instruction"}, } } +PREDEFINED_SPEECH_TO_TEXT_DATASETS = { + "librispeech": { + "id": "openslr/librispeech_asr", + "name": "clean", + "split": "validation", + "inputs": {"audio": ("audio", "array"), "sampling_rate": ("audio", "sampling_rate")}, + } +} + NEED_CONVERT_TO_FAST_TOKENIZER: Tuple[Type[PreTrainedTokenizer]] = (CLIPTokenizer,) @@ -200,6 +209,8 @@ def _is_timm_ov_dir(model_dir): def _print_compiled_model_properties(compiled_model): + cur_log_level = logger.getEffectiveLevel() + logger.setLevel(logging.INFO) supported_properties = properties.supported_properties() skip_keys = {"SUPPORTED_METRICS", "SUPPORTED_CONFIG_KEYS", supported_properties} keys = set(compiled_model.get_property(supported_properties)) - skip_keys @@ -222,6 +233,7 @@ def _print_compiled_model_properties(compiled_model): logger.info(f" {device}: {Core().get_property(device, 'FULL_DEVICE_NAME')}") except Exception: logger.error("[error] Get FULL_DEVICE_NAME failed") + logger.setLevel(cur_log_level) def np_to_pt_generators(np_object, device): diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py index d26d8c42b6..04390ba3b1 100644 --- a/optimum/intel/pipelines/pipeline_base.py +++ b/optimum/intel/pipelines/pipeline_base.py @@ -58,6 +58,7 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) @@ -69,6 +70,24 @@ "default": "gpt2", 
"type": "text", }, + "summarization": { + "impl": SummarizationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-base", + "type": "text", + }, + "translation": { + "impl": TranslationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-small", + "type": "text", + }, + "text2text-generation": { + "impl": Text2TextGenerationPipeline, + "class": (IPEXModelForSeq2SeqLM,), + "default": "t5-small", + "type": "text", + }, "fill-mask": { "impl": FillMaskPipeline, "class": (IPEXModelForMaskedLM,), @@ -246,6 +265,7 @@ def load_ipex_model( SUPPORTED_TASKS, hub_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None, + device_map: Optional[torch.device] = None, ): hub_kwargs = hub_kwargs or {} model_kwargs = model_kwargs or {} @@ -253,7 +273,9 @@ def load_ipex_model( if model is None: model_id = SUPPORTED_TASKS[targeted_task]["default"] - model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs) + model = ipex_model_class.from_pretrained( + model_id, export=True, **hub_kwargs, **model_kwargs, device_map=device_map + ) elif isinstance(model, str): model_id = model try: @@ -262,7 +284,9 @@ def load_ipex_model( except RuntimeError: logger.warning("We will use IPEXModel with export=True to export the model") export = True - model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs) + model = ipex_model_class.from_pretrained( + model, export=export, **hub_kwargs, **model_kwargs, device_map=device_map + ) elif isinstance(model, IPEXModel): model_id = getattr(model.config, "name_or_path", None) else: diff --git a/optimum/intel/utils/dummy_ipex_objects.py b/optimum/intel/utils/dummy_ipex_objects.py index 4bd7eee630..7c1922305b 100644 --- a/optimum/intel/utils/dummy_ipex_objects.py +++ b/optimum/intel/utils/dummy_ipex_objects.py @@ -70,6 +70,17 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["ipex"]) +class IPEXModelForSeq2SeqLM(metaclass=DummyObject): + _backends = ["ipex"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex"]) + + class IPEXModelForQuestionAnswering(metaclass=DummyObject): _backends = ["ipex"] @@ -101,3 +112,14 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["ipex"]) + + +class IPEXSentenceTransformer(metaclass=DummyObject): + _backends = ["ipex", "sentence_transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["ipex", "sentence_transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["ipex", "sentence_transformers"]) diff --git a/optimum/intel/version.py b/optimum/intel/version.py index 16bf124e0e..5677397928 100644 --- a/optimum/intel/version.py +++ b/optimum/intel/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "1.21.0.dev0" +__version__ = "1.22.0.dev0" diff --git a/setup.py b/setup.py index fde9150143..d9b3b8642b 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,8 @@ "sentence-transformers", "open_clip_torch>=2.26.1", "peft", + "datasets[audio]>=1.4.0", + "tbb", ] QUALITY_REQUIRE = ["black~=23.1", "ruff==0.4.4"] @@ -64,7 +66,7 @@ "nncf": ["nncf>=2.14.0"], "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"], "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"], - "ipex": ["intel-extension-for-pytorch", "transformers>=4.39,<4.45"], + "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47", "accelerate"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 53c733c4f5..419e1bb42a 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -17,7 +17,7 @@ import tempfile import time import unittest - +import os import numpy as np import requests import torch @@ -26,6 +26,7 @@ from transformers import ( AutoFeatureExtractor, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, AutoTokenizer, GenerationConfig, @@ -33,23 +34,29 @@ pipeline, set_seed, ) - from optimum.intel import ( IPEXModel, IPEXModelForAudioClassification, IPEXModelForCausalLM, + IPEXModelForSeq2SeqLM, IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, + IPEXSentenceTransformer, ) -from optimum.intel.utils.import_utils import is_ipex_version -from optimum.utils.testing_utils import grid_parameters -from utils_tests import MODEL_NAMES +from optimum.utils.testing_utils import grid_parameters, require_sentence_transformers +from optimum.intel.utils.import_utils import is_sentence_transformers_available, is_torch_version + +if is_sentence_transformers_available(): + from sentence_transformers import SentenceTransformer +from utils_tests import MODEL_NAMES, IS_XPU_AVAILABLE SEED = 42 +torch.use_deterministic_algorithms(True) +DEVICE = "xpu:0" if IS_XPU_AVAILABLE else "cpu" class Timer(object): @@ -74,17 +81,20 @@ class IPEXModelTest(unittest.TestCase): "squeezebert", "xlm", ) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("bert",) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(ipex_model.add_patch) self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id, device_map=DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") + tokens = tokenizer(inputs, return_tensors="pt").to(DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) @@ -92,24 +102,23 @@ def test_compare_to_transformers(self, model_arch): # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = 
self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, device_map=DEVICE) loaded_model_outputs = loaded_model(**tokens) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**tokens) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) # Compare tensor outputs for output_name in {"logits", "last_hidden_state"}: if output_name in transformers_outputs: - self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-4)) + self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-3)) self.assertTrue(torch.allclose(outputs[output_name], loaded_model_outputs[output_name])) self.assertTrue(torch.allclose(outputs[output_name], init_model_outputs[output_name])) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline(self.IPEX_MODEL_CLASS.export_feature, model=model, tokenizer=tokenizer) text = "This restaurant is awesome" @@ -144,12 +153,12 @@ class IPEXModelForQuestionAnsweringTest(unittest.TestCase): def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True) + ipex_model = IPEXModelForQuestionAnswering.from_pretrained(model_id, device_map=DEVICE) self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id) + transformers_model = AutoModelForQuestionAnswering.from_pretrained(model_id, device_map=DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") + tokens = tokenizer(inputs, return_tensors="pt").to(DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) @@ -157,13 +166,12 @@ def test_compare_to_transformers(self, model_arch): # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, device_map=DEVICE) loaded_model_outputs = loaded_model(**tokens) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**tokens) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) self.assertIn("start_logits", outputs) self.assertIn("end_logits", outputs) @@ -178,7 +186,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForQuestionAnswering.from_pretrained(model_id, export=True) + model = IPEXModelForQuestionAnswering.from_pretrained(model_id, device_map=DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("question-answering", model=model, tokenizer=tokenizer) question = "What's my name?" 
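As the hunks above show, the IPEX tests now load models eagerly with device_map instead of exporting them to TorchScript through export=True (the RecursiveScriptModule assertions are dropped accordingly). A minimal sketch of the pattern the question-answering tests exercise, shown here on CPU (the suite substitutes xpu:0 when an XPU device is available) and using the tiny checkpoint referenced in the diff:

    from transformers import AutoTokenizer, pipeline
    from optimum.intel import IPEXModelForQuestionAnswering

    model_id = "hf-internal-testing/tiny-random-bert"
    model = IPEXModelForQuestionAnswering.from_pretrained(model_id, device_map="cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # The IPEX model plugs directly into the standard transformers pipeline.
    pipe = pipeline("question-answering", model=model, tokenizer=tokenizer)
    print(pipe(question="What's my name?", context="My name is Omar and I live in Zürich."))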
@@ -188,15 +196,16 @@ def test_pipeline(self, model_arch): self.assertGreaterEqual(outputs["score"], 0.0) self.assertIsInstance(outputs["answer"], str) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") def test_patched_model(self): ipex_model = IPEXModelForQuestionAnswering.from_pretrained( - "Jiqing/patched_tiny_random_bert_for_question_answering" + "Intel/tiny-random-bert_ipex_model", device_map=DEVICE + ) + transformers_model = AutoModelForQuestionAnswering.from_pretrained( + "hf-internal-testing/tiny-random-bert", device_map=DEVICE ) - transformers_model = AutoModelForQuestionAnswering.from_pretrained("hf-internal-testing/tiny-random-bert") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") + tokens = tokenizer(inputs, return_tensors="pt").to(DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) outputs = ipex_model(**tokens) @@ -225,7 +234,7 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "mpt", "opt", ) - IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "distilgpt2", "falcon") + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2") GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 @@ -233,35 +242,65 @@ class IPEXModelForCausalLMTest(unittest.TestCase): def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) self.assertIsInstance(ipex_model.config, PretrainedConfig) - self.assertTrue(ipex_model.use_cache) tokenizer = AutoTokenizer.from_pretrained(model_id) tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch in ("llama", "llama2") else None, - ) + ).to(DEVICE) inputs = ipex_model.prepare_inputs_for_generation(**tokens) outputs = ipex_model(**inputs) self.assertIsInstance(outputs.logits, torch.Tensor) - self.assertIsInstance(outputs.past_key_values, (tuple, list)) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE) loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) + + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + # To avoid float pointing error + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7)) + self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_forward(self, model_arch): + model_id = MODEL_NAMES[model_arch] 
+ set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + input_ids = torch.Tensor([[1, 2, 3], [4, 5, 6]]).to(torch.long) + outputs = ipex_model(input_ids) + + self.assertIsInstance(outputs.logits, torch.Tensor) + + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) + with torch.no_grad(): + transformers_outputs = transformers_model(input_ids) + + # Test re-load model + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE) + loaded_model_outputs = loaded_model(input_ids) + + # Test init method + init_model = self.IPEX_MODEL_CLASS(transformers_model) + init_model_outputs = init_model(input_ids) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) @@ -271,26 +310,28 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) model.config.encoder_no_repeat_ngram_size = 0 - model.to("cpu") pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) outputs = pipe("This is a sample", max_new_tokens=10) self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) @parameterized.expand(SUPPORTED_ARCHITECTURES) + @unittest.skip(reason="Paged attention do not support assisted decoding for now") def test_assisted_decoding(self, model_arch): - # Patched models are not support assisted decoding if ipex < 2.5. 
- if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES and is_ipex_version("<", "2.4.0"): + # assist decoding does not support static cache now + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: return model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) + tokens = tokenizer("This is a sample input", return_tensors="pt").to(DEVICE) ipex_output = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=4) ipex_output_assisted = ipex_model.generate( **tokens, do_sample=False, assistant_model=transformers_model, max_new_tokens=4 @@ -309,17 +350,24 @@ def test_assisted_decoding(self, model_arch): @parameterized.expand( grid_parameters( { - "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES, + "model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True, False], } ) ) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") - def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): + def test_ipex_beam_search(self, test_name, model_arch, use_cache): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache) - trasnformers_model = AutoModelForCausalLM.from_pretrained(model_id) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model = IPEXModelForCausalLM.from_pretrained( + model_id, use_cache=use_cache, torch_dtype=dtype, device_map=DEVICE + ) + # It will be removed when torch 2.6 released + if model_arch == "opt" and not use_cache and model.compiled and is_torch_version("<", "2.6.0"): + return + if use_cache and model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(model.add_patch) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE) self.assertEqual(model.use_cache, use_cache) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token @@ -335,37 +383,21 @@ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache): ), ) for text in texts: - tokens = tokenizer(text, padding=True, return_tensors="pt") + tokens = tokenizer(text, padding=True, return_tensors="pt").to(DEVICE) for generation_config in generation_configs: outputs = model.generate(**tokens, generation_config=generation_config) - transformers_outputs = trasnformers_model.generate(**tokens, generation_config=generation_config) + transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config) self.assertIsInstance(outputs, torch.Tensor) self.assertTrue(torch.equal(outputs, transformers_outputs)) - @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") - def test_patched_model(self, model_arch): - model_id = MODEL_NAMES[model_arch] - patched_model_id = MODEL_NAMES["patched_" + model_arch] - ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - 
exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokens = tokenizer( - "This is a sample", - return_tensors="pt", - return_token_type_ids=False if model_arch in ("llama", "llama2") else None, - ) - inputs = ipex_model.prepare_inputs_for_generation(**tokens) - ipex_outputs = ipex_model(**inputs) - exported_outputs = exported_model(**inputs) - self.assertTrue(torch.allclose(ipex_outputs.logits, exported_outputs.logits, atol=1e-7)) - def test_compare_with_and_without_past_key_values(self): - model_id = "echarlaix/tiny-random-gpt2-torchscript" + model_id = "Intel/tiny_random_llama2_ipex_model" + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_with_pkv = IPEXModelForCausalLM.from_pretrained( + model_id, use_cache=True, torch_dtype=dtype, device_map=DEVICE + ) tokenizer = AutoTokenizer.from_pretrained(model_id) - tokens = tokenizer("This is a sample input", return_tensors="pt") - - model_with_pkv = IPEXModelForCausalLM.from_pretrained(model_id, use_cache=True, subfolder="model_with_pkv") + tokens = tokenizer("This is a sample input", return_tensors="pt").to(DEVICE) # Warmup model_with_pkv.generate(**tokens) with Timer() as with_pkv_timer: @@ -373,7 +405,7 @@ def test_compare_with_and_without_past_key_values(self): **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 ) model_without_pkv = IPEXModelForCausalLM.from_pretrained( - model_id, use_cache=False, subfolder="model_without_pkv" + model_id, use_cache=False, torch_dtype=dtype, device_map=DEVICE ) # Warmup model_without_pkv.generate(**tokens) @@ -385,6 +417,22 @@ def test_compare_with_and_without_past_key_values(self): self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) + @parameterized.expand(IPEX_PATCHED_SUPPORTED_ARCHITECTURES) + def test_patched_model(self, model_arch): + model_id = MODEL_NAMES[model_arch] + patched_model_id = MODEL_NAMES["patched_" + model_arch] + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, device_map=DEVICE) + exported_model = IPEXModelForCausalLM.from_pretrained(patched_model_id, device_map=DEVICE) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample", return_tensors="pt").to(DEVICE) + ipex_outputs = ipex_model.generate( + **tokens, max_new_tokens=1, return_dict_in_generate=True, output_logits=True + ) + exported_outputs = exported_model.generate( + **tokens, max_new_tokens=1, return_dict_in_generate=True, output_logits=True + ) + self.assertTrue(torch.allclose(ipex_outputs.logits[0], exported_outputs.logits[0], atol=1e-7)) + class IPEXModelForAudioClassificationTest(unittest.TestCase): IPEX_MODEL_CLASS = IPEXModelForAudioClassification @@ -403,11 +451,11 @@ def _generate_random_audio_data(self): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id, device_map=DEVICE) preprocessor = 
AutoFeatureExtractor.from_pretrained(model_id) - inputs = preprocessor(self._generate_random_audio_data(), return_tensors="pt") + inputs = preprocessor(self._generate_random_audio_data(), return_tensors="pt").to(DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) @@ -415,13 +463,12 @@ def test_compare_to_transformers(self, model_arch): # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, device_map=DEVICE) loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3)) @@ -431,7 +478,7 @@ def test_compare_to_transformers(self, model_arch): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor) outputs = pipe([np.random.random(16000)]) @@ -443,25 +490,27 @@ class IPEXModelForImageClassificationIntegrationTest(unittest.TestCase): IPEX_MODEL_CLASS = IPEXModelForImageClassification SUPPORTED_ARCHITECTURES = ( "beit", - # "levit", "mobilenet_v1", "mobilenet_v2", "mobilevit", "resnet", "vit", ) + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("vit",) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) + if model_arch in self.IPEX_PATCHED_SUPPORTED_ARCHITECTURES: + self.assertTrue(ipex_model.add_patch) self.assertIsInstance(ipex_model.config, PretrainedConfig) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained(model_id, device_map=DEVICE) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) - inputs = preprocessor(images=image, return_tensors="pt") + inputs = preprocessor(images=image, return_tensors="pt").to(DEVICE) with torch.no_grad(): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) @@ -469,24 +518,23 @@ def test_compare_to_transformers(self, model_arch): # Test re-load model with tempfile.TemporaryDirectory() as tmpdirname: ipex_model.save_pretrained(tmpdirname) - loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, device_map=DEVICE) loaded_model_outputs = loaded_model(**inputs) # Test init method - init_model = self.IPEX_MODEL_CLASS(transformers_model, export=True) + init_model = self.IPEX_MODEL_CLASS(transformers_model) init_model_outputs = 
init_model(**inputs) - self.assertIsInstance(init_model.model, torch.jit.RecursiveScriptModule) self.assertIn("logits", outputs) # Compare tensor outputs self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) - self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-4)) self.assertTrue(torch.allclose(init_model_outputs.logits, transformers_outputs.logits, atol=1e-4)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, device_map=DEVICE) preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor) outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg") @@ -494,12 +542,13 @@ def test_pipeline(self, model_arch): self.assertGreaterEqual(outputs[0]["score"], 0.0) self.assertTrue(isinstance(outputs[0]["label"], str)) - @unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") def test_patched_model(self): ipex_model = IPEXModelForImageClassification.from_pretrained( - "Jiqing/patched_tiny_random_vit_for_image_classification" + "Intel/tiny-random-vit_ipex_model", device_map=DEVICE + ) + transformers_model = self.IPEX_MODEL_CLASS.from_pretrained( + "hf-internal-testing/tiny-random-vit", device_map=DEVICE ) - transformers_model = self.IPEX_MODEL_CLASS.from_pretrained("hf-internal-testing/tiny-random-vit") preprocessor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-vit") url = "http://images.cocodataset.org/val2017/000000039769.jpg" image = Image.open(requests.get(url, stream=True).raw) @@ -508,3 +557,153 @@ def test_patched_model(self): transformers_outputs = transformers_model(**inputs) outputs = ipex_model(**inputs) self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + + +class IPEXModelForSeq2SeqLMTest(unittest.TestCase): + IPEX_MODEL_CLASS = IPEXModelForSeq2SeqLM + SUPPORTED_ARCHITECTURES = ("t5",) + GENERATION_LENGTH = 2 + SPEEDUP_CACHE = 1.0 + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + # Test model forward do not need cache. 
+ ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype) + self.assertIsInstance(ipex_model.config, PretrainedConfig) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer( + "This is a sample", + return_tensors="pt", + return_token_type_ids=False if model_arch in ("llama", "llama2") else None, + ) + decoder_start_token_id = transformers_model.config.decoder_start_token_id if model_arch != "mbart" else 2 + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} + outputs = ipex_model(**tokens, **decoder_inputs) + + self.assertIsInstance(outputs.logits, torch.Tensor) + + with torch.no_grad(): + transformers_outputs = transformers_model(**tokens, **decoder_inputs) + + # Test re-load model + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype) + loaded_model_outputs = loaded_model(**tokens, **decoder_inputs) + + # Test init method + init_model = self.IPEX_MODEL_CLASS(transformers_model) + init_model_outputs = init_model(**tokens, **decoder_inputs) + + # Compare tensor outputs + self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-4)) + # To avoid float pointing error + self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7)) + self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, torch_dtype=dtype) + model.config.encoder_no_repeat_ngram_size = 0 + # model.to("cpu") + pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer) + outputs = pipe("This is a sample", max_new_tokens=10, do_sample=False) + self.assertEqual(pipe.device, model.device) + + def test_compare_with_and_without_past_key_values(self): + model_id = "hf-internal-testing/tiny-random-t5" + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model_with_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=True, torch_dtype=dtype) + device = model_with_pkv.device + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt").to(device) + # Warmup + model_with_pkv.generate(**tokens) + with Timer() as with_pkv_timer: + outputs_model_with_pkv = model_with_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + model_without_pkv = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=False, torch_dtype=dtype) + # Warmup + model_without_pkv.generate(**tokens) + with Timer() as without_pkv_timer: + outputs_model_without_pkv = model_without_pkv.generate( + **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 + ) + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) + self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1) + self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1) + + @parameterized.expand( + grid_parameters( + { + "model_arch": SUPPORTED_ARCHITECTURES, 
+ "use_cache": [True, False], + } + ) + ) + def test_ipex_beam_search(self, test_name, model_arch, use_cache): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, use_cache=use_cache, torch_dtype=dtype) + device = model.device + transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=dtype).to(device) + self.assertEqual(model.use_cache, use_cache) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + # Test with batch_size is 1 and 2. + texts = ["This is a sample", ["This is the first input", "This is the second input"]] + generation_configs = ( + GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=False), + GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=False), + GenerationConfig( + max_new_tokens=4, do_sample=False, top_p=0.9, top_k=0, pad_token_id=tokenizer.eos_token_id + ), + ) + for text in texts: + tokens = tokenizer(text, padding=True, return_tensors="pt").to(device) + for generation_config in generation_configs: + outputs = model.generate(**tokens, generation_config=generation_config) + transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config) + self.assertIsInstance(outputs, torch.Tensor) + self.assertTrue(torch.equal(outputs, transformers_outputs)) + + +class IPEXSTModel(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "st-bert", + "st-mpnet", + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_sentence_transformers + def test_compare_to_original_model(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ipex_model = IPEXSentenceTransformer(model_id) + st_model = SentenceTransformer(model_id) + sentences = ["This is an example sentence", "Each sentence is converted"] + st_embeddings = st_model.encode(sentences) + ov_embeddings = ipex_model.encode(sentences) + self.assertTrue(np.allclose(ov_embeddings, st_embeddings, atol=1e-4)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_sentence_transformers + def test_sentence_transformers_save_and_infer(self, model_arch): + model_id = MODEL_NAMES[model_arch] + ipex_model = IPEXSentenceTransformer(model_id) + with tempfile.TemporaryDirectory() as tmpdirname: + ipex_model.save_pretrained(tmpdirname) + model = IPEXSentenceTransformer(tmpdirname, model_kwargs={"subfolder": "ipex"}) + sentences = ["This is an example sentence", "Each sentence is converted"] + model.encode(sentences) diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index 767097a5dd..f376c6050a 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -20,7 +20,7 @@ from parameterized import parameterized from transformers import AutoTokenizer from transformers.pipelines import pipeline as transformers_pipeline -from utils_tests import MODEL_NAMES +from utils_tests import IS_XPU_AVAILABLE, MODEL_NAMES from optimum.intel.ipex.modeling_base import ( IPEXModelForAudioClassification, @@ -28,12 +28,17 @@ IPEXModelForImageClassification, IPEXModelForMaskedLM, IPEXModelForQuestionAnswering, + IPEXModelForSeq2SeqLM, IPEXModelForSequenceClassification, IPEXModelForTokenClassification, ) from optimum.intel.pipelines import pipeline as ipex_pipeline +torch.use_deterministic_algorithms(True) +DEVICE = "xpu:0" if 
IS_XPU_AVAILABLE else "cpu" + + class PipelinesIntegrationTest(unittest.TestCase): COMMON_SUPPORTED_ARCHITECTURES = ( "albert", @@ -79,12 +84,13 @@ class PipelinesIntegrationTest(unittest.TestCase): "resnet", "vit", ) + TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES = ("t5",) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_token_classification_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("token-classification", model_id) - ipex_generator = ipex_pipeline("token-classification", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("token-classification", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("token-classification", model_id, accelerator="ipex", device_map=DEVICE) inputs = "Hello I'm Omar and I live in Zürich." with torch.inference_mode(): transformers_output = transformers_generator(inputs) @@ -92,22 +98,20 @@ def test_token_classification_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForTokenClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_sequence_classification_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("text-classification", model_id) - ipex_generator = ipex_pipeline("text-classification", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("text-classification", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("text-classification", model_id, accelerator="ipex", device_map=DEVICE) inputs = "This restaurant is awesome" with torch.inference_mode(): transformers_output = transformers_generator(inputs) with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertEqual(transformers_output[0]["label"], ipex_output[0]["label"]) self.assertAlmostEqual(transformers_output[0]["score"], ipex_output[0]["score"], delta=1e-4) @@ -115,8 +119,8 @@ def test_sequence_classification_pipeline_inference(self, model_arch): def test_fill_mask_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] inputs = "The Milky Way is a galaxy." 
- transformers_generator = transformers_pipeline("fill-mask", model_id) - ipex_generator = ipex_pipeline("fill-mask", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("fill-mask", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("fill-mask", model_id, accelerator="ipex", device_map=DEVICE) mask_token = transformers_generator.tokenizer.mask_token inputs = inputs.replace("", mask_token) with torch.inference_mode(): @@ -125,7 +129,6 @@ def test_fill_mask_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForMaskedLM)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertEqual(transformers_output[i]["token"], ipex_output[i]["token"]) self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @@ -133,22 +136,26 @@ def test_fill_mask_pipeline_inference(self, model_arch): @parameterized.expand(TEXT_GENERATION_SUPPORTED_ARCHITECTURES) def test_text_generation_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("text-generation", model_id) - ipex_generator = ipex_pipeline("text-generation", model_id, accelerator="ipex") + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline( + "text-generation", model_id, torch_dtype=dtype, device_map=DEVICE + ) + ipex_generator = ipex_pipeline( + "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE + ) inputs = "Describe a real-world application of AI." with torch.inference_mode(): - transformers_output = transformers_generator(inputs, max_new_tokens=10) + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) with torch.inference_mode(): - ipex_output = ipex_generator(inputs, max_new_tokens=10) + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) @parameterized.expand(QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES) def test_question_answering_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("question-answering", model_id) - ipex_generator = ipex_pipeline("question-answering", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("question-answering", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("question-answering", model_id, accelerator="ipex", device_map=DEVICE) question = "How many programming languages does BLOOM support?" context = "BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages." 
with torch.inference_mode(): @@ -156,7 +163,6 @@ def test_question_answering_pipeline_inference(self, model_arch): with torch.inference_mode(): ipex_output = ipex_generator(question=question, context=context) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForQuestionAnswering)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertAlmostEqual(transformers_output["score"], ipex_output["score"], delta=1e-4) self.assertEqual(transformers_output["start"], ipex_output["start"]) self.assertEqual(transformers_output["end"], ipex_output["end"]) @@ -164,23 +170,22 @@ def test_question_answering_pipeline_inference(self, model_arch): @parameterized.expand(AUDIO_CLASSIFICATION_SUPPORTED_ARCHITECTURES) def test_audio_classification_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("audio-classification", model_id) - ipex_generator = ipex_pipeline("audio-classification", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("audio-classification", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("audio-classification", model_id, accelerator="ipex", device_map=DEVICE) inputs = [np.random.random(16000)] with torch.inference_mode(): transformers_output = transformers_generator(inputs) with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForAudioClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertAlmostEqual(transformers_output[0][0]["score"], ipex_output[0][0]["score"], delta=1e-2) self.assertAlmostEqual(transformers_output[0][1]["score"], ipex_output[0][1]["score"], delta=1e-2) @parameterized.expand(IMAGE_CLASSIFICATION_SUPPORTED_ARCHITECTURES) def test_image_classification_pipeline_inference(self, model_arch): model_id = MODEL_NAMES[model_arch] - transformers_generator = transformers_pipeline("image-classification", model_id) - ipex_generator = ipex_pipeline("image-classification", model_id, accelerator="ipex") + transformers_generator = transformers_pipeline("image-classification", model_id, device_map=DEVICE) + ipex_generator = ipex_pipeline("image-classification", model_id, accelerator="ipex", device_map=DEVICE) inputs = "http://images.cocodataset.org/val2017/000000039769.jpg" with torch.inference_mode(): transformers_output = transformers_generator(inputs) @@ -188,7 +193,6 @@ def test_image_classification_pipeline_inference(self, model_arch): ipex_output = ipex_generator(inputs) self.assertEqual(len(transformers_output), len(ipex_output)) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForImageClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) for i in range(len(transformers_output)): self.assertEqual(transformers_output[i]["label"], ipex_output[i]["label"]) self.assertAlmostEqual(transformers_output[i]["score"], ipex_output[i]["score"], delta=1e-4) @@ -196,27 +200,71 @@ def test_image_classification_pipeline_inference(self, model_arch): @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_pipeline_load_from_ipex_model(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + model = IPEXModelForSequenceClassification.from_pretrained(model_id, device_map=DEVICE) tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_generator = 
ipex_pipeline("text-classification", model, tokenizer=tokenizer, accelerator="ipex") + ipex_generator = ipex_pipeline( + "text-classification", model, tokenizer=tokenizer, accelerator="ipex", device_map=DEVICE + ) inputs = "This restaurant is awesome" with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) @parameterized.expand(COMMON_SUPPORTED_ARCHITECTURES) def test_pipeline_load_from_jit_model(self, model_arch): model_id = MODEL_NAMES[model_arch] - model = IPEXModelForSequenceClassification.from_pretrained(model_id, export=True) + model = IPEXModelForSequenceClassification.from_pretrained(model_id, device_map=DEVICE) save_dir = TemporaryDirectory().name model.save_pretrained(save_dir) tokenizer = AutoTokenizer.from_pretrained(model_id) - ipex_generator = ipex_pipeline("text-classification", save_dir, tokenizer=tokenizer, accelerator="ipex") + ipex_generator = ipex_pipeline( + "text-classification", save_dir, tokenizer=tokenizer, accelerator="ipex", device_map=DEVICE + ) inputs = "This restaurant is awesome" with torch.inference_mode(): ipex_output = ipex_generator(inputs) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSequenceClassification)) - self.assertTrue(isinstance(ipex_generator.model.model, torch.jit.RecursiveScriptModule)) self.assertGreaterEqual(ipex_output[0]["score"], 0.0) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_text2text_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("text2text-generation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("text2text-generation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_summarization_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("summarization", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("summarization", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." 
+ with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["summary_text"], ipex_output[0]["summary_text"]) + + @parameterized.expand(TEXT2TEXT_GENERATION_SUPPORTED_ARCHITECTURES) + def test_translation_generation_pipeline_inference(self, model_arch): + model_id = MODEL_NAMES[model_arch] + dtype = torch.float16 if IS_XPU_AVAILABLE else torch.float32 + transformers_generator = transformers_pipeline("translation", model_id, torch_dtype=dtype) + ipex_generator = ipex_pipeline("translation", model_id, accelerator="ipex", torch_dtype=dtype) + inputs = "Describe a real-world application of AI." + with torch.inference_mode(): + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + with torch.inference_mode(): + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + self.assertTrue(isinstance(ipex_generator.model, IPEXModelForSeq2SeqLM)) + self.assertEqual(transformers_output[0]["translation_text"], ipex_output[0]["translation_text"]) diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py index 595bc0246f..8cd93516da 100644 --- a/tests/ipex/utils_tests.py +++ b/tests/ipex/utils_tests.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from transformers import is_torch_xpu_available +IS_XPU_AVAILABLE = is_torch_xpu_available(check_device=True) + MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-albert", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", @@ -25,18 +28,18 @@ "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "convnext": "hf-internal-testing/tiny-random-convnext", "distilbert": "hf-internal-testing/tiny-random-distilbert", - "distilgpt2": "Jiqing/tiny_random_distilgpt2", + "distilgpt2": "Intel/tiny-random-distilgpt2", "electra": "hf-internal-testing/tiny-random-electra", "flaubert": "hf-internal-testing/tiny-random-flaubert", - "falcon": "Jiqing/tiny_random_falcon", + "falcon": "Intel/tiny-random-falcon", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", - "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt2": "Intel/tiny-random-gpt2", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", "gptj": "hf-internal-testing/tiny-random-GPTJModel", "levit": "hf-internal-testing/tiny-random-LevitModel", "llama": "fxmarty/tiny-llama-fast-tokenizer", - "llama2": "Jiqing/tiny_random_llama2", + "llama2": "Intel/tiny-random-llama2", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", "mistral": "echarlaix/tiny-random-mistral", @@ -50,13 +53,15 @@ "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", + "st-bert": "sentence-transformers-testing/stsb-bert-tiny-safetensors", + "st-mpnet": "sentence-transformers/all-mpnet-base-v2", "squeezebert": "hf-internal-testing/tiny-random-squeezebert", "t5": "hf-internal-testing/tiny-random-t5", "unispeech": "hf-internal-testing/tiny-random-unispeech", "vit": 
"hf-internal-testing/tiny-random-vit", "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "xlm": "hf-internal-testing/tiny-random-xlm", - "patched_falcon": "Jiqing/patched_tiny_random_falcon_for_causal_lm", - "patched_distilgpt2": "Jiqing/patched_tiny_random_distilgpt2_for_causal_lm", - "patched_llama2": "Jiqing/patched_tiny_random_llama2_for_causal_lm", + "patched_falcon": "Intel/tiny-random-falcon_ipex_model", + "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model", + "patched_llama2": "Intel/tiny-random-llama2_ipex_model", } diff --git a/tests/neural_compressor/test_ipex.py b/tests/neural_compressor/test_ipex.py index ef1f19812e..2a230f23dd 100644 --- a/tests/neural_compressor/test_ipex.py +++ b/tests/neural_compressor/test_ipex.py @@ -52,7 +52,7 @@ class IPEXQuantizationTest(INCTestMixin): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("text-classification", "bert", 21),) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) - def test_ipex_static_quantization_with_smoothquant(self, task, model_arch, expected_quantized_matmuls): + def test_static_quantization_with_smoothquant(self, task, model_arch, expected_quantized_matmuls): recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5}} num_samples = 10 model_name = MODEL_NAMES[model_arch] @@ -79,5 +79,5 @@ def test_ipex_static_quantization_with_smoothquant(self, task, model_arch, expec is_static=True, num_samples=num_samples, load_inc_model=False, - load_ipex_model=True, + load_ipex_model=False, ) diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py index 75f2845c78..6b01baf705 100644 --- a/tests/neural_compressor/test_optimization.py +++ b/tests/neural_compressor/test_optimization.py @@ -45,7 +45,7 @@ set_seed, ) from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset -from optimum.intel.utils.import_utils import is_torch_version +from optimum.intel.utils.import_utils import is_neural_compressor_version from optimum.intel import ( INCConfig, @@ -467,12 +467,16 @@ def _compute_metrics(pred): class WeightOnlyQuantizationTest(INCTestMixin): WEIGHT_ONLY_CONFIG = ( - ("rtn", 4), - ("gptq", 4), + ("rtn", 4, False), + ("rtn", 4, True), + ("gptq", 4, False), + ("gptq", 4, True), ) @parameterized.expand(WEIGHT_ONLY_CONFIG) - def test_weight_only_quantization(self, methodology, bits): + def test_weight_only_quantization(self, methodology, bits, use_layer_wise): + if use_layer_wise and is_neural_compressor_version("<", "3.2"): + self.skipTest("INC version < 3.2 doesn't support layer-wise feature.") from neural_compressor.transformers import GPTQConfig, RtnConfig model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM" @@ -489,9 +493,10 @@ def test_weight_only_quantization(self, methodology, bits): batch_size=5, seq_len=32, block_size=16, + use_layer_wise=use_layer_wise, ) else: - quantization_config = RtnConfig(bits=bits, group_size=8) + quantization_config = RtnConfig(bits=bits, group_size=8, use_layer_wise=use_layer_wise) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) @@ -503,6 +508,7 @@ def test_weight_only_quantization(self, methodology, bits): with torch.no_grad(): quantizer_outputs = quantized_model(**tokens) quantized_model.save_pretrained(tmp_dir) + loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir) with torch.no_grad(): loaded_outputs = loaded_model(**tokens) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 
80a45cab6e..2d57f92d0e 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -132,6 +132,7 @@ def _openvino_export( ov_model.model.get_rt_info()["optimum"]["transformers_version"], _transformers_version ) self.assertTrue(ov_model.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])) + self.assertTrue(ov_model.model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"])) if library_name == "diffusers": self.assertTrue( diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 702eb19c04..7a3c824eee 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -105,6 +105,7 @@ class OVCLIExportTestCase(unittest.TestCase): if is_transformers_version(">=", "4.45"): SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("stable-diffusion-3", 9, 65)) + SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("flux", 7, 56)) TEST_4BIT_CONFIGURATIONS = [ ("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", {"int8": 4, "int4": 72}), diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 986cda0c47..54261b88f4 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -43,6 +43,7 @@ from optimum.intel import ( OVConfig, + OVFluxPipeline, OVLatentConsistencyModelPipeline, OVModelForAudioClassification, OVModelForCausalLM, @@ -93,6 +94,22 @@ class OVQuantizerTest(unittest.TestCase): (OVModelForSequenceClassification, "bert", 32, 35), (OVModelForCausalLM, "gpt2", 31, 22), ) + SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET = [ + ( + OVModelForSpeechSeq2Seq, + "whisper", + OVQuantizationConfig( + dataset="librispeech", + num_samples=1, + processor=MODEL_NAMES["whisper"], + trust_remote_code=True, + weight_only=False, + smooth_quant_alpha=0.95, + ), + (14, 22, 21) if is_transformers_version("<=", "4.42.4") else (14, 22, 25), + (14, 21, 17) if is_transformers_version("<=", "4.42.4") else (14, 22, 18), + ), + ] @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL) def test_automodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8): @@ -180,6 +197,31 @@ def preprocess_function(examples, tokenizer): loaded_config = OVConfig.from_pretrained(tmp_dir) self.assertEqual(ov_config.quantization_config.to_dict(), loaded_config.quantization_config.to_dict()) + @parameterized.expand(SUPPORTED_ARCHITECTURES_OV_MODEL_WITH_AUTO_DATASET) + def test_ov_model_static_quantization_with_auto_dataset( + self, model_cls, model_name, quantization_config, expected_fake_quantize, expected_int8 + ): + model_id = MODEL_NAMES[model_name] + + with TemporaryDirectory() as tmp_dir: + ov_model = model_cls.from_pretrained(model_id, quantization_config=quantization_config) + ov_model.save_pretrained(tmp_dir) + + if model_cls == OVModelForSpeechSeq2Seq: + for model, expected_fq, expected_i8 in zip( + (ov_model.encoder.model, ov_model.decoder.model, ov_model.decoder_with_past.model), + expected_fake_quantize, + expected_int8, + ): + num_fake_quantize, num_weight_nodes = get_num_quantized_nodes(model) + self.assertEqual(expected_fq, num_fake_quantize) + self.assertEqual(expected_i8, num_weight_nodes["int8"]) + + input_features = torch.randn((1, 128, 3000), dtype=torch.float32) + ov_model.generate(input_features) + else: + raise Exception("Unexpected model class.") + class OVWeightCompressionTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = ( @@ -465,6 +507,7 @@ class 
OVWeightCompressionTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION.extend( [ (OVStableDiffusion3Pipeline, "stable-diffusion-3", 9, 65), + (OVFluxPipeline, "flux", 7, 56), ] ) @@ -1069,9 +1112,9 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(num_samples=100), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), (dict(abc="def"), OVWeightQuantizationConfig, "Can't determine type of OV quantization config"), ( - dict(bits=4, fast_bias_correction=True, dataset="wikitext2"), - OVWeightQuantizationConfig, - "Can't determine type of OV quantization config", + dict(bits=8, fast_bias_correction=True, dataset="librispeech"), + OVQuantizationConfig, + None, ), (dict(model_type="transformer"), OVQuantizationConfig, None), ( @@ -1091,7 +1134,12 @@ class OVQuantizationConfigTest(unittest.TestCase): (dict(abc="def", weight_only=False), OVQuantizationConfig, None), (dict(abc="def", weight_only=True), OVWeightQuantizationConfig, None), ( - dict(bits=4, fast_bias_correction=True, dataset="wikitext2", weight_only=True), + dict(bits=8, fast_bias_correction=True, dataset="librispeech", weight_only=True), + OVQuantizationConfig, + None, + ), + ( + dict(bits=4, dataset="wikitext2", weight_only=True), OVWeightQuantizationConfig, None, ), @@ -1151,7 +1199,7 @@ def test_for_no_short_id_duplicates(self): class InferRequestWrapperTest(unittest.TestCase): - MODEL_ID = ("openai/whisper-tiny.en",) + MODEL_NAME = ("whisper",) APPLY_CACHING = (False, True) @staticmethod @@ -1165,8 +1213,9 @@ def _generate_random_audio_data(processor): ).input_features return input_features - @parameterized.expand(itertools.product(MODEL_ID, APPLY_CACHING)) - def test_calibration_data_uniqueness(self, model_id, apply_caching): + @parameterized.expand(itertools.product(MODEL_NAME, APPLY_CACHING)) + def test_calibration_data_uniqueness(self, model_name, apply_caching): + model_id = MODEL_NAMES[model_name] ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True) processor = AutoProcessor.from_pretrained(model_id) diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index e37ad5baeb..dc87760828 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -148,7 +148,7 @@ "wav2vec2": "anton-l/wav2vec2-random-tiny-classifier", "wav2vec2-hf": "hf-internal-testing/tiny-random-Wav2Vec2Model", "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", - "whisper": "openai/whisper-tiny.en", + "whisper": "yujiepan/whisper-v3-tiny-random", "xlm": "hf-internal-testing/tiny-random-xlm", "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta", "xglm": "hf-internal-testing/tiny-random-XGLMForCausalLM",
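The new test_ov_model_static_quantization_with_auto_dataset case above drives full static quantization of a Whisper model from the predefined librispeech calibration set added to PREDEFINED_SPEECH_TO_TEXT_DATASETS. A minimal sketch of the corresponding user-facing flow, assuming a public Whisper checkpoint (the test itself uses a tiny random model and additionally passes trust_remote_code=True):

    from optimum.intel import OVModelForSpeechSeq2Seq, OVQuantizationConfig

    model_id = "openai/whisper-tiny.en"
    quantization_config = OVQuantizationConfig(
        dataset="librispeech",   # resolved through the new predefined speech-to-text datasets
        num_samples=1,           # kept tiny here; increase for meaningful calibration
        processor=model_id,
        weight_only=False,       # full static quantization rather than weight-only compression
        smooth_quant_alpha=0.95,
    )
    ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, quantization_config=quantization_config)
    ov_model.save_pretrained("whisper-tiny-int8-ov")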