Skip to content

Commit

Permalink
Assert correct device when dequantizing (like we do for quantizing) (v…
Browse files Browse the repository at this point in the history
…llm-project#90)

* Update forward.py

* Update forward.py
  • Loading branch information
dbogunowicz authored Jun 19, 2024
1 parent d611303 commit 318569d
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ def dequantize(
:param args: quantization args used to quantize x_q
:return: dequantized float tensor
"""
# ensure all tensors are on the same device
# assumes that the target device is the input
# tensor's device
if x_q.device != scale.device:
scale = scale.to(x_q.device)
if x_q.device != zero_point.device:
zero_point = zero_point.to(x_q.device)

if args is None:
if scale.ndim == 0 or scale.ndim == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
Expand Down

0 comments on commit 318569d

Please sign in to comment.