From 318569d53b28cd191838c948d1065fcb571b43ca Mon Sep 17 00:00:00 2001 From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com> Date: Wed, 19 Jun 2024 16:28:20 +0200 Subject: [PATCH] Assert correct device when dequantizing (like we do for quantizing) (#90) * Update forward.py * Update forward.py --- src/compressed_tensors/quantization/lifecycle/forward.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 453c9df79..b0b952c03 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -93,6 +93,14 @@ def dequantize( :param args: quantization args used to quantize x_q :return: dequantized float tensor """ + # ensure all tensors are on the same device + # assumes that the target device is the input + # tensor's device + if x_q.device != scale.device: + scale = scale.to(x_q.device) + if x_q.device != zero_point.device: + zero_point = zero_point.to(x_q.device) + if args is None: if scale.ndim == 0 or scale.ndim == 1: args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)