Partial GPTQ int4 models conversion support + Swap nibbles in u4/i4 element type (openvinotoolkit#20371)

* Reference implementation for u4 constant compression from a PyTorch model based on a bitwise-ops pattern

* Fixed order of 4-bit halves in a byte

* Switched PyTorch FE to dev mode: if a model cannot be fully converted, produce a partially converted model with PTFrameworkNode ops and print a warning (normally an exception would be raised).

* Moved u4 compression to utils_quantize. Implemented a non-interleaved version of u4 compression

* Removed debug output

* Added aten::matmul to the list of exceptions in may_produce_alias as a workaround for gptq models

* Added patching for gptq models applied automatically in convert_model

* Workaround for an issue with u4 when converting to fp16 earlier

* U4 blocked repacking for gptq patched model layout

* Deleted obsolete u4 re-packing based on aten::cat. Fixed the resulting u4 constant shape. Removed debug output.

* Revert "Switched PyTorch FE to dev mode: in case if model cannot be fully converted, give partially converted model with PTFrameworkNode's with a printed warning (normally would raise an exception in case)."

This reverts commit 0ef1455.

* Update src/frontends/pytorch/src/op/cat.cpp

* Check mask and shift values in u4 pattern. deque -> OutputVector for u4_compression_stack

* Convert to a given floating type instead of half in gptq patching. Better structured code.

* Code style fix

* Removed deque include

* Code style fixes

* Trailing space removed

* Fixed patched_forward and ts_decoder after unvalidated commits.

* Swap nibbles in u4/i4 (low nibble now holds the first element; see the packing sketch after the changed-files summary)

* Better exception handling around jit.trace and gptq.patch_model

* Update src/bindings/python/src/openvino/frontend/pytorch/gptq.py

Co-authored-by: Alexander Kozlov <[email protected]>

* Update src/bindings/python/src/openvino/frontend/pytorch/gptq.py

Co-authored-by: Alexander Kozlov <[email protected]>

* Code style

* Reverse int4 byte order

* Fixed core tests

* Fixed unguarded dynamic_cast result

Co-authored-by: Evgenya Nugmanova <[email protected]>

* Fixed transformation tests

* Update src/bindings/python/src/openvino/frontend/pytorch/gptq.py

Co-authored-by: Maxim Vafin <[email protected]>

* Prevent patching of non-gptq models

* Removed redundant invocation of quantized weights decompression patterns

* Better detection of supported AutoGPTQ models + more diagnostics

* Accurate diagnostics when aten::stack has multiple axes

---------

Co-authored-by: Alexander Kozlov <[email protected]>
Co-authored-by: Ilya Churaev <[email protected]>
Co-authored-by: Evgenya Nugmanova <[email protected]>
Co-authored-by: Maxim Vafin <[email protected]>
5 people authored Oct 18, 2023
1 parent cf9791e commit 46935e0
Showing 19 changed files with 516 additions and 236 deletions.
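
The nibble order this commit switches to stores element 2*i in the low nibble of byte i and element 2*i+1 in the high nibble. The following minimal Python sketch of that convention is an informal illustration, not part of the diff (pack_u4/unpack_u4 are made-up names); it reproduces the byte values expected by the updated constant tests.

# Informal model of the u4 packing after this commit:
# even-indexed elements go to the low nibble, odd-indexed to the high nibble.
def pack_u4(values):
    packed = bytearray((len(values) + 1) // 2)
    for idx, val in enumerate(values):
        packed[idx // 2] |= (val & 0x0F) << (4 * (idx % 2))
    return bytes(packed)

def unpack_u4(packed, count):
    return [(packed[idx // 2] >> (4 * (idx % 2))) & 0x0F for idx in range(count)]

assert pack_u4([1, 2, 5]) == bytes([0x21, 0x05])      # matches the updated constant.cpp expectations
assert unpack_u4(bytes([0x21, 0x05]), 3) == [1, 2, 5]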
@@ -370,7 +370,7 @@ def inlined_inputs(self, index):
return result

def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d"]:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::matmul"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that
return False
try:
140 changes: 140 additions & 0 deletions src/bindings/python/src/openvino/frontend/pytorch/gptq.py
@@ -0,0 +1,140 @@

# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# flake8: noqa
# mypy: ignore-errors

import torch
from functools import partial

# Wraps a single tensor in a module to prevent it from being folded by jit.freeze
# Whether the tensor is preserved from freezing depends on its dtype. Refer to the decoder code to learn which types will be preserved.
class KeepWeight(torch.nn.Module):

def __init__(self, weight):
super().__init__()
self.weight = torch.nn.Parameter(weight, requires_grad=False)

def forward(self):
return self.weight


# Produces a pattern that can be captured later and represented as a single u4 constant node
def decompression_pattern(weights):
mask = torch.tensor(15, dtype=torch.uint8).to(weights.device)
return torch.stack((torch.bitwise_and(weights, mask), torch.bitwise_right_shift(weights, 4)), dim=-1)


def patched_forward(self, *args, **kwargs):
if hasattr(self, '_hf_hook'):
args, kwargs = self._hf_hook.pre_forward(self, *args, **kwargs)

x = args[0]
dtype = x.dtype
outshape = x.shape[:-1] + (self.width,)
x = x.view(-1, x.shape[-1])
groups = self.qzeros.shape[0]
height = self.qweight.shape[0]

unpacked_weights = decompression_pattern(
self._openvino_u4_compression_submodule_qweights()).contiguous().view(height, -1, 8)
unpacked_weights = torch.transpose(
unpacked_weights, 1, 2).contiguous().view(-1, self.group_size, self.width)
unpacked_zp = decompression_pattern(
self._openvino_u4_compression_submodule_qzeros()).contiguous().view(groups, 1, -1)

unpacked_zp = unpacked_zp.to(dtype) + 1

unpacked_weights = (unpacked_weights.to(dtype) - unpacked_zp) * self.scales
unpacked_weights = unpacked_weights.view(-1, self.width)

out = x @ unpacked_weights

out = out.view(outshape)
if self.bias is not None:
out.add_(self.bias)

if hasattr(self, '_hf_hook'):
out = self._hf_hook.post_forward(self, out)
return out


# All the following AutoGPTQ's quant types are supposed to have the same weights packing schema
supported_quant_types = ['triton', 'exllama', 'cuda', 'exllamav2', 'cuda-old']


def patch_model(model):
for name, m in model.named_modules():
if hasattr(m, '_openvino_patch_orig_forward'):
# already patched, skipping
continue
# TODO: Check module type
is_quantized = getattr(m, 'is_quantized', None)
if is_quantized is not None:
m.is_quantized = False
m.float() # enables tracing on CPU, applied for all modules
if hasattr(m, 'QUANT_TYPE'):
if m.QUANT_TYPE not in supported_quant_types:
raise ValueError(
f'Unsupported QUANT_TYPE == {m.QUANT_TYPE} is discovered for AutoGPTQ model, only the following types are supported: {supported_quant_types}')
if m.bits != 4:
raise ValueError(
f'Unsupported bits == {m.bits} is discovered in module {name} in AutoGPTQ model, only bits == 4 is supported.')

int4_in_int32 = 8
groups = m.qzeros.shape[0]
m.width = m.qweight.shape[1]
assert m.group_size == m.qweight.shape[0] * int4_in_int32 // groups

m._openvino_patch_orig_forward = m.forward
m.forward = partial(patched_forward, m)

# Keep original field properties to be used when model is returned back to its original state
m._openvino_patch_orig_qweights_type = m.qweight.dtype
m._openvino_patch_orig_qzeros_type = m.qzeros.dtype
m._openvino_patch_orig_scale_shape = m.scales.shape

m.qweight = m.qweight.view(dtype=torch.uint8)
m.qzeros = m.qzeros.view(dtype=torch.uint8)

# TODO: Redundant tensor copy? Try to remove m.qweight and m.qzeros after keeping modified values as submodules
m.add_module(
'_openvino_u4_compression_submodule_qweights', KeepWeight(m.qweight))
m.add_module('_openvino_u4_compression_submodule_qzeros',
KeepWeight(m.qzeros))

m.scales = m.scales.view(-1, 1, m.width)


def unpatch_model(model):
for _, m in model.named_modules():
if hasattr(m, '_openvino_patch_orig_forward'):
try:
m.forward = m._openvino_patch_orig_forward
del m._openvino_patch_orig_forward

m.qweight = m.qweight.view(
dtype=m._openvino_patch_orig_qweights_type)
del m._openvino_patch_orig_qweights_type

m.qzeros = m.qzeros.view(
dtype=m._openvino_patch_orig_qzeros_type)
del m._openvino_patch_orig_qzeros_type

m.scales = m.scales.view(m._openvino_patch_orig_scale_shape)
del m._openvino_patch_orig_scale_shape

del m._openvino_u4_compression_submodule_qweights
del m._openvino_u4_compression_submodule_qzeros
except Exception as error:
print('[ WARNING ] Exception raised during GPTQ model unpatching. Depending on the exact issue it may lead to a broken original model')
print(error)


def detect_gptq_model_raw(model):
return model and getattr(model, 'config', None) and getattr(model.config, 'quantization_config', None) and model.config.quantization_config.quant_method == 'gptq'


def detect_gptq_model(model):
return detect_gptq_model_raw(model) or getattr(model, 'model', None) and detect_gptq_model_raw(model.model)
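
As a rough illustration of the decompression pattern above (toy values, not real model weights): each packed uint8 byte is split into its low and high nibbles and stacked along a new trailing axis, which the PyTorch frontend later recognizes and folds into a single u4 constant.

import torch

# Toy example (assumed values) of the bitwise pattern emitted by decompression_pattern
packed = torch.tensor([[0x21, 0x05]], dtype=torch.uint8)
mask = torch.tensor(15, dtype=torch.uint8)
unpacked = torch.stack((torch.bitwise_and(packed, mask),
                        torch.bitwise_right_shift(packed, 4)), dim=-1)
print(unpacked.tolist())  # [[[1, 2], [5, 0]]] -> u4 elements 1, 2, 5, 0 in storage order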
26 changes: 23 additions & 3 deletions src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -9,6 +9,7 @@
from openvino.runtime import op, PartialShape, Type as OVType, OVAny
from openvino.frontend.pytorch.utils import ivalue_to_constant, get_value_from_getattr, pt_to_ov_type_map, prepare_example_inputs_and_model, convert_quantized_tensor
from openvino.runtime import opset11 as ops
from openvino.frontend.pytorch import gptq

import typing
import torch
@@ -84,8 +85,27 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False)
if example_inputs is None:
scripted = torch.jit.script(pt_module)
else:
input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(example_inputs, input_params, pt_module)
scripted = torch.jit.trace(pt_module, **input_parameters, strict=False)
input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
example_inputs, input_params, pt_module)
gptq_patched = False

if gptq.detect_gptq_model(pt_module):
try:
gptq.patch_model(pt_module)
gptq_patched = True
except Exception as error:
print('[ WARNING ] Failed patching of AutoGPTQ model. Error message:\n', error)
print('[ WARNING ] Tracing of the model will likely be unsuccessful or incorrect')
gptq.unpatch_model(pt_module)
gptq_patched = False

try:
scripted = torch.jit.trace(
pt_module, **input_parameters, strict=False)
finally:
if gptq_patched:
gptq.unpatch_model(pt_module)

if not skip_freeze:
for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models
@@ -341,7 +361,7 @@ def input_is_none(self, index: int) -> bool:
return False

def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d"]:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::matmul"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that
return False
try:
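
With GPTQ detection and patching wired into the TorchScript decoder, conversion is expected to follow the usual PyTorch flow. A hedged usage sketch follows; the model name and example inputs are placeholders, not taken from this change.

import torch
import openvino as ov
from transformers import AutoModelForCausalLM

# Hypothetical model id: any AutoGPTQ 4-bit checkpoint with quant_method == 'gptq' in its config
model = AutoModelForCausalLM.from_pretrained("some-org/some-gptq-int4-model")
example = {"input_ids": torch.ones((1, 8), dtype=torch.long),
           "attention_mask": torch.ones((1, 8), dtype=torch.long)}
ov_model = ov.convert_model(model, example_input=example)  # GPTQ patching is applied automatically
ov.save_model(ov_model, "model.xml")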
4 changes: 2 additions & 2 deletions src/common/transformations/tests/utils/convert_precision.cpp
@@ -1188,7 +1188,7 @@ void constant_convert_test(element::Type type_from,
}
ASSERT_TRUE(actual.size() >= expected.size());
for (size_t i = 0; i < expected.size(); i++) {
ASSERT_EQ(expected[i], actual[i]);
EXPECT_EQ(expected[i], actual[i]) << "Elements with index " << i << " are not equal.";
}
}

@@ -1378,7 +1378,7 @@ TEST(TransformationTests, ConvertPrecision_ConstantConversion_U1ToU4) {
constant_convert_test<uint8_t, uint8_t>(element::u1,
element::u4,
std::vector<uint8_t>{171},
{1, 0, 1, 0, 1, 0, 1, 1});
{0, 1, 0, 1, 0, 1, 1, 1});
}

TEST(TransformationTests, ConvertPrecision_keep_precission_sensitive_fp32_with_exp) {
32 changes: 7 additions & 25 deletions src/core/include/openvino/op/constant.hpp
@@ -426,7 +426,7 @@ class OPENVINO_API Constant : public Op {
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4, bool>::type = true>
StorageDataType get_element_value(size_t index) const {
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 4 : 0)) & 0x0F;
}

template <element::Type_t Type,
Expand All @@ -440,7 +440,7 @@ class OPENVINO_API Constant : public Op {
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::i4, bool>::type = true>
StorageDataType get_element_value(size_t index) const {
const uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
const uint8_t i4data = (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 4 : 0)) & 0x0F;
const bool is_negative_number = (i4data >> 3) & 0x01;
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
return data;
@@ -530,7 +530,7 @@ class OPENVINO_API Constant : public Op {
const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
output.reserve(round_element_no); // adds 1 more elements here?
std::for_each(source_begin, source_end, [&](IN_T c) {
for (const auto i : {4, 0}) {
for (const auto i : {0, 4}) {
const uint8_t data = (c >> i) & 0x0F;
output.push_back(data);
}
@@ -548,7 +548,7 @@
const auto round_element_no = element_number % 2 ? element_number + 1 : element_number;
output.reserve(round_element_no); // adds 1 more elements here?
std::for_each(source_begin, source_end, [&](IN_T c) {
for (const auto i : {4, 0}) {
for (const auto i : {0, 4}) {
const uint8_t i4data = (c >> i) & 0x0F;
const bool is_negative_number = (i4data >> 3) & 0x01;
const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
@@ -663,27 +663,9 @@ class OPENVINO_API Constant : public Op {
template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4, bool>::type = true>
void write_buffer(const std::vector<T>& source) {
auto p = get_data_ptr_nc<Type>();
size_t i = 0;
for (; i < source.size() / 2; i++) {
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
const auto v = (v1 << 4) | v2;
p[i] = static_cast<StorageDataType>(v);
}
if (source.size() % 2) {
const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
const auto v = v1 << 4;
p[i] = static_cast<StorageDataType>(v);
}
}

template <element::Type_t Type,
typename T,
typename StorageDataType = fundamental_type_for<Type>,
typename std::enable_if<Type == element::Type_t::nf4 && std::is_integral<T>::value, bool>::type = true>
typename std::enable_if<Type == element::Type_t::u4 || Type == element::Type_t::i4 ||
(Type == element::Type_t::nf4 && std::is_integral<T>::value),
bool>::type = true>
void write_buffer(const std::vector<T>& source) {
auto p = get_data_ptr_nc<Type>();
size_t i = 0;
12 changes: 6 additions & 6 deletions src/core/reference/include/openvino/reference/convert.hpp
@@ -14,7 +14,7 @@ namespace reference {
namespace detail {
inline void set_u1(uint8_t* buf, size_t idx, uint8_t val) {
const size_t byte_idx = idx / 8;
const uint8_t bit_idx = 7 - (idx % 8);
const uint8_t bit_idx = 7 - (idx % 8); // Reversed order of bits
if (val) {
buf[byte_idx] |= (1 << bit_idx);
} else {
@@ -24,33 +24,33 @@ inline void set_u1(uint8_t* buf, size_t idx, uint8_t val) {

inline uint8_t get_u1(const uint8_t* buf, size_t idx) {
const size_t byte_idx = idx / 8;
const uint8_t bit_idx = 7 - (idx % 8);
const uint8_t bit_idx = 7 - (idx % 8); // Reversed order of bits
return (buf[byte_idx] & (1 << bit_idx)) ? 1 : 0;
}

inline void set_u4(uint8_t* buf, size_t idx, uint8_t val) {
const size_t byte_idx = idx / 2;
const uint8_t bit_shift = 4 * (++idx % 2);
const uint8_t bit_shift = 4 * (idx % 2);
buf[byte_idx] &= ~(0xF << bit_shift); // half byte zeroed
buf[byte_idx] |= ((val & 0xF) << bit_shift); // set 1's
}

inline uint8_t get_u4(const uint8_t* buf, size_t idx) {
const size_t byte_idx = idx / 2;
const uint8_t bit_shift = 4 * (++idx % 2);
const uint8_t bit_shift = 4 * (idx % 2);
return (buf[byte_idx] >> bit_shift) & 0xF;
}

inline void set_i4(uint8_t* buf, size_t idx, int8_t val) {
const size_t byte_idx = idx / 2;
const uint8_t bit_shift = 4 * (++idx % 2);
const uint8_t bit_shift = 4 * (idx % 2);
buf[byte_idx] &= ~(0xF << bit_shift); // half byte zeroed
buf[byte_idx] |= ((val & 0xF) << bit_shift); // set 1's
}

inline int8_t get_i4(const uint8_t* buf, size_t idx) {
const size_t byte_idx = idx / 2;
const uint8_t bit_shift = 4 * (++idx % 2);
const uint8_t bit_shift = 4 * (idx % 2);
uint8_t val = (buf[byte_idx] >> bit_shift) & 0xF;
if (val & 0x08) { // negative number
val |= 0xF0;
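
The i4 read path additionally sign-extends the nibble. An informal Python mirror of the updated get_i4 reference (low nibble first, then sign extension) behaves as shown below; it is an illustration, not code from the diff.

# Informal Python mirror of the updated get_i4: even index -> low nibble, then sign-extend
def get_i4(buf, idx):
    val = (buf[idx // 2] >> (4 * (idx % 2))) & 0x0F
    return val - 16 if val & 0x08 else val

assert [get_i4(bytes([0xEF]), i) for i in (0, 1)] == [-1, -2]  # matches int4_vector_negative_number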
20 changes: 10 additions & 10 deletions src/core/tests/constant.cpp
@@ -266,8 +266,8 @@ TEST(constant, int4_string) {
EXPECT_EQ(v[2], -1);

const auto p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(0x10, p[0]);
EXPECT_EQ(0xF0, p[1] & 0xF0);
EXPECT_EQ(0x01, p[0]);
EXPECT_EQ(0x0F, p[1] & 0x0F);

EXPECT_EQ(input, c.get_value_strings());

@@ -318,8 +318,8 @@ TEST(constant, int4_vector_negative_number) {
EXPECT_EQ(v[2], int8_t(-1));

const auto p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(0xFE, p[0]);
EXPECT_EQ(0xF0, p[1] & 0xF0);
EXPECT_EQ(0xEF, p[0]);
EXPECT_EQ(0x0F, p[1] & 0x0F);
}

TEST(constant, int4_vector_positive_number) {
@@ -332,8 +332,8 @@ TEST(constant, int4_vector_positive_number) {
EXPECT_EQ(v[2], int8_t(5));

const auto p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(0x12, p[0]);
EXPECT_EQ(0x50, p[1] & 0xF0);
EXPECT_EQ(0x21, p[0]);
EXPECT_EQ(0x05, p[1] & 0x0F);
}

TEST(constant, int4_vector_broadcast_negative_number) {
@@ -795,8 +795,8 @@ TEST(constant, uint4_string) {
EXPECT_EQ(v[3], 0);

const auto p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 0x10);
EXPECT_EQ(p[1], 0x10);
EXPECT_EQ(p[0], 0x01);
EXPECT_EQ(p[1], 0x01);

EXPECT_EQ(input, c.get_value_strings());

@@ -831,8 +831,8 @@ TEST(constant, uint4_vector) {
EXPECT_EQ(v[3], 0);

const auto p = c.get_data_ptr<uint8_t>();
EXPECT_EQ(p[0], 0x10);
EXPECT_EQ(p[1], 0x10);
EXPECT_EQ(p[0], 0x01);
EXPECT_EQ(p[1], 0x01);
}

TEST(constant, uint4_vector_broadcast) {
4 changes: 2 additions & 2 deletions src/core/tests/int4.cpp
@@ -15,9 +15,9 @@ TEST(int4, convert_i4_to_string) {
vector<uint8_t> values{171, 16};
auto constant = make_shared<ov::op::v0::Constant>(element::i4, Shape{3}, &values[0]);

vector<string> ref{"-6", "-5", "1"};
vector<string> ref{"-5", "-6", "0"};
for (size_t i = 0; i < 3; ++i) {
ASSERT_EQ(constant->convert_value_to_string(i), ref[i]);
EXPECT_EQ(constant->convert_value_to_string(i), ref[i]);
}
}

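
For reference, the new expectation follows directly from reading the low nibble first: 171 = 0xAB yields 0xB -> -5 and 0xA -> -6 after sign extension, and 16 = 0x10 yields 0x0 -> 0 for the third element.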