Skip to content

Commit

Permalink
Fix test for deprecated spmm_coo
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Dec 3, 2024
1 parent bbb7063 commit d25ebb4
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2647,7 +2647,7 @@ def double_quant(
"""

coo_tensor = None
quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant(
quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant(
A,
col_stats,
row_stats,
Expand All @@ -2657,16 +2657,15 @@ def double_quant(
)

if threshold > 0.0:
# Build COO tensor for any outliers.
outlier_mask = A.abs() >= threshold
outlier_locations = outlier_mask.nonzero()
outliers = A[outlier_mask]
# Build a COO tensor including all of the outlier columns.
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
outliers = A[:, outlier_cols]
coo_tensor = COOSparseTensor(
A.shape[0],
A.shape[1],
outliers.numel(),
outlier_locations[:, 0].int(),
outlier_locations[:, 1].int(),
outlier_rows.repeat_interleave(outliers.size(1)),
outlier_cols.repeat(outliers.size(0)).int(),
outliers,
)

Expand Down

0 comments on commit d25ebb4

Please sign in to comment.