Skip to content

Commit

Permalink
Refactor quantization tensor data representation (tracel-ai#2479)
Browse files Browse the repository at this point in the history
* Remove quantization strategy from QFloat to use scheme instead (qparams unknown at compile-time)

Instead, the qparams are stored in the TensorData bytes so we can pack/unpack them based on the scheme

* Change quantization tensor data representation to pack quantized values into u32

* Fix clippy

* Remove comment

* Add alloc vec import

* Remove print

* Rename into_bytes
  • Loading branch information
laggui authored Nov 13, 2024
1 parent 94db460 commit 34b303e
Show file tree
Hide file tree
Showing 19 changed files with 485 additions and 260 deletions.
16 changes: 8 additions & 8 deletions crates/burn-fusion/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{marker::PhantomData, ops::Range};

use burn_tensor::{
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationStrategy},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
repr::{
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
Expand All @@ -21,14 +21,14 @@ use crate::{
impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
match data.dtype {
DType::QFloat(strategy) => {
DType::QFloat(scheme) => {
let client = get_client::<B>(device);
let tensor = B::q_from_data(data, device);
let shape = B::q_shape(&tensor);

let handles = B::quantized_tensor_handle(tensor);
let qparams = match strategy {
QuantizationStrategy::PerTensorAffineInt8(_) => {
let qparams = match scheme {
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
let offset = if let Some(offset) = handles.offset {
offset
} else {
Expand All @@ -49,7 +49,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
)),
}
}
QuantizationStrategy::PerTensorSymmetricInt8(_) => {
QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
assert!(
handles.offset.is_none(),
"Offset should not be provided for symmetric quantization."
Expand All @@ -74,7 +74,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
QFusionTensor {
qtensor,
qparams,
scheme: strategy.scheme(),
scheme,
}
}
_ => panic!(
Expand Down Expand Up @@ -142,7 +142,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
scale: qparams.scale.clone().into_description(),
offset: qparams.offset.clone().map(|x| x.into_description()),
},
scheme: scheme.clone(),
scheme: *scheme,
out: out.to_description_out(),
};

Expand All @@ -157,7 +157,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {

QFusionTensor {
qtensor: out,
scheme: scheme.clone(),
scheme: *scheme,
qparams: qparams.into(),
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
.as_ref()
.map(|x| x.to_relative(converter)),
},
scheme: desc.scheme.clone(),
scheme: desc.scheme,
out: desc.out.to_relative(converter),
})
}
Expand All @@ -561,7 +561,7 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
.as_ref()
.map(|x| x.to_relative(converter)),
},
scheme: desc.qtensor.scheme.clone(),
scheme: desc.qtensor.scheme,
},
out: desc.out.to_relative(converter),
})
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-fusion/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<R: FusionRuntime> Clone for QFusionTensor<R> {
fn clone(&self) -> Self {
Self {
qtensor: self.qtensor.clone(),
scheme: self.scheme.clone(),
scheme: self.scheme,
qparams: self.qparams.clone(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/quantization/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ where

QJitTensor {
qtensor,
scheme: scheme.clone(),
scheme: *scheme,
qparams,
}
}
73 changes: 17 additions & 56 deletions crates/burn-jit/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::ops::Range;

use alloc::vec::Vec;
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme,
QuantizationStrategy, QuantizationType,
QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, QuantizationType,
},
DType, Device, ElementConversion, Shape, TensorData,
DType, Device, Shape, TensorData,
};

use crate::{
Expand All @@ -17,28 +15,14 @@ use crate::{
};
use cubecl::CubeElement;

fn pack_i8s_to_u32s(data: &TensorData) -> Vec<u32> {
// Shift and combine groups of four 8-bit values into a u32.
// Same as doing this:
// let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF);
data.as_bytes()
.chunks(4)
.map(|x| {
x.iter().enumerate().fold(0u32, |acc, (i, x)| {
acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8)
})
})
.collect()
}

/// Create a quantized tensor with packed values (u32).
fn packed_tensor<R: JitRuntime, S: Into<Shape>>(
data: Vec<u32>,
data: &[u8],
shape: S,
device: &R::Device,
) -> JitTensor<R, u32> {
let client = R::client(device);
let buffer = client.create(u32::as_bytes(&data));
let buffer = client.create(data);

JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer)
}
Expand All @@ -51,27 +35,21 @@ where
{
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
match data.dtype {
DType::QFloat(strategy) => match strategy {
QuantizationStrategy::PerTensorAffineInt8(q) => {
DType::QFloat(scheme) => match scheme {
QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
| QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
// Convert quantized values to packed u32s
let qparams = data.get_q_params().unwrap();
QJitTensor {
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
scheme: strategy.scheme(),
qtensor: packed_tensor(data.values_as_bytes(), data.shape.clone(), device),
scheme,
qparams: JitQuantizationParameters::new(
q.scale.elem(),
Some(q.offset.elem()),
qparams.scale,
qparams.offset,
device,
),
}
}
QuantizationStrategy::PerTensorSymmetricInt8(q) => {
// Convert quantized values to packed u32s
QJitTensor {
qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device),
scheme: strategy.scheme(),
qparams: JitQuantizationParameters::new(q.scale.elem(), None, device),
}
}
},
_ => panic!(
"Invalid dtype (expected DType::QFloat, got {:?})",
Expand Down Expand Up @@ -119,35 +97,18 @@ where

async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
let strategy = tensor.strategy();
let numel = tensor.qtensor.shape.num_elements();
let qtensor = kernel::into_contiguous(tensor.qtensor);

let bytes = qtensor.client.read_async(qtensor.handle.binding()).await;

// Convert packed bytes to quantized dtype (TensorData can be used with other backends,
// which don't have the prior knowledge of this packed representation)
// TensorData keeps quantized values packed into 32-bit unsigned integers so we can
// keep the current representation, just cast the bytes as u32.
match &tensor.scheme {
QuantizationScheme::PerTensorAffine(dtype)
| QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
QuantizationType::QInt8 => TensorData::quantized(
u32::from_bytes(&bytes)
.iter()
.enumerate()
.flat_map(|(i, packed)| {
// A single u32 could contain less than four 8-bit values...
let n = core::cmp::min(4, numel - i * 4);
// Extract each 8-bit segment from u32 and cast back to i8
// Same as doing this (when 4 values are fully packed):
// let a = ((packed >> 24) & 0xFF) as i8;
// let b = ((packed >> 16) & 0xFF) as i8;
// let c = ((packed >> 8) & 0xFF) as i8;
// let d = (packed & 0xFF) as i8;
(0..n).map(move |i| (packed >> ((3 - i) * 8) & 0xFF) as i8)
})
.collect(),
qtensor.shape,
strategy,
),
QuantizationType::QInt8 => {
TensorData::quantized(u32::from_bytes(&bytes).to_vec(), qtensor.shape, strategy)
}
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/tensor/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> Clone for QJitTensor<R, F, I
fn clone(&self) -> Self {
Self {
qtensor: self.qtensor.clone(),
scheme: self.scheme.clone(),
scheme: self.scheme,
qparams: self.qparams.clone(),
}
}
Expand Down
Loading

0 comments on commit 34b303e

Please sign in to comment.