diff --git a/ruff.toml b/ruff.toml index f61562c23e..25b412c386 100644 --- a/ruff.toml +++ b/ruff.toml @@ -4,12 +4,13 @@ # Example: To lint all files in every subfolder of 'test', add "test/**/*" include = [ "torchao/float8/**/*.py", - "test/dtypes/test_nf4.py", "torchao/quantization/**/*.py", - "test/quantization/test_observer.py", - "test/dtypes/test_affine_quantized_float.py", "torchao/dtypes/**/*.py", + "torchao/sparsity/**/*.py", "torchao/prototype/low_bit_optim/**.py", + "test/quantization/test_observer.py", + "test/dtypes/test_affine_quantized_float.py", + "test/dtypes/test_nf4.py", "test/prototype/low_bit_optim/**.py", ] diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index c3b10f949a..d8df4ef122 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -4,20 +4,22 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from .wanda import WandaSparsifier # noqa: F403 -from .utils import PerChannelNormObserver # noqa: F403 +from torchao.quantization.quant_api import ( + int8_dynamic_activation_int8_semi_sparse_weight, +) from .sparse_api import ( apply_fake_sparsity, - sparsify_, semi_sparse_weight, - int8_dynamic_activation_int8_semi_sparse_weight + sparsify_, ) +from .utils import PerChannelNormObserver # noqa: F403 +from .wanda import WandaSparsifier # noqa: F403 __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_fake_sparsity", - "sparsify_" + "sparsify_", "semi_sparse_weight", - "int8_dynamic_activation_int8_semi_sparse_weight" + "int8_dynamic_activation_int8_semi_sparse_weight", ] diff --git a/torchao/sparsity/marlin/__init__.py b/torchao/sparsity/marlin/__init__.py index 3cb45a271e..5d2f005e43 100644 --- a/torchao/sparsity/marlin/__init__.py +++ b/torchao/sparsity/marlin/__init__.py @@ -1,11 +1,11 @@ +from typing import Tuple + import torch -from typing import Tuple, Dict, List import torchao.sparsity.marlin.utils as utils from torchao.sparsity.marlin.utils import const from torchao.sparsity.utils import mask_creator - __all__ = [ "inject_24", "marlin_24_workspace", @@ -14,11 +14,13 @@ ] -def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]: +def inject_24( + w: torch.Tensor, size_k: int, size_n: int +) -> Tuple[torch.Tensor, torch.Tensor]: """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the ranked weight values. - + Args: w (torch.Tensor): The weight tensor to inject sparsity into. size_k (int): The number of input features. @@ -32,13 +34,13 @@ def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, def marlin_24_workspace( - out_features: int, - min_thread_n: int = const.MIN_THREAD_N, - max_parallel: int = const.MAX_PARALLEL - ) -> torch.Tensor: - """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks + out_features: int, + min_thread_n: int = const.MIN_THREAD_N, + max_parallel: int = const.MAX_PARALLEL, +) -> torch.Tensor: + """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks during the execution of the kernel. - + Args: out_features (int): The number of output features. min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`. 
@@ -46,19 +48,21 @@ def marlin_24_workspace( Returns: torch.Tensor: The workspace tensor fully initialized with zeros. """ - assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}" - max_workspace_size = ((out_features // min_thread_n) * max_parallel) + assert ( + out_features % min_thread_n == 0 + ), f"out_features = {out_features}, min_thread_n = {min_thread_n}" + max_workspace_size = (out_features // min_thread_n) * max_parallel return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") def pack_to_marlin_24( - q_w_24: torch.Tensor, - scales: torch.Tensor, - num_bits: int, - group_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_w_24: torch.Tensor, + scales: torch.Tensor, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Packs the quantized weights and scales into the marlin 2:4 format. - + Args: q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied. scales (torch.Tensor): The scale tensor. @@ -89,13 +93,13 @@ def pack_to_marlin_24( def unpack_from_marlin_24( - q_w_24_comp: torch.Tensor, - scales: torch.Tensor, - meta: torch.Tensor, - original_shape: torch.Size, - group_size: int, - num_bits: int - ) -> Tuple[torch.Tensor, torch.Tensor]: + q_w_24_comp: torch.Tensor, + scales: torch.Tensor, + meta: torch.Tensor, + original_shape: torch.Size, + group_size: int, + num_bits: int, +) -> Tuple[torch.Tensor, torch.Tensor]: """Unpacks the quantized weights and scales from the marlin 2:4 format. Args: q_w_24_comp (torch.Tensor): The packed quantized weights. @@ -109,10 +113,8 @@ def unpack_from_marlin_24( """ in_features, out_features = original_shape - # Unpacks the scales - unpacked_scales = _from_marlin_scale( - scales, *original_shape, group_size, num_bits - ) + # Unpacks the scales + unpacked_scales = _from_marlin_scale(scales, *original_shape, group_size, num_bits) in_features_comp = in_features // 2 @@ -130,14 +132,11 @@ def unpack_from_marlin_24( def _compress_quantized_24_weight( - q_24: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0 + q_24: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0 before compressing them. - + Args: q_24 (torch.Tensor): The quantized weight tensor. size_k (int): The number of input features. @@ -168,14 +167,10 @@ def _compress_quantized_24_weight( def _decompress_quantized_24_weight( - q_24_comp: torch.Tensor, - meta: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int - ) -> torch.Tensor: + q_24_comp: torch.Tensor, meta: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape. - + Args: q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format. meta (torch.Tensor): The meta tensor. @@ -210,13 +205,13 @@ def _decompress_quantized_24_weight( def _to_marlin_weights( - q_w: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, - ) -> torch.Tensor: + q_w: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format. 
- + Args: q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format. size_k (int): The number of input features. @@ -236,7 +231,11 @@ def _to_marlin_weights( # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32 # does not support rshift_cpu. q_w = q_w.cpu().to(torch.int64) - q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device) + q_packed = torch.zeros( + (q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=torch.int64, + device=q_w.device, + ) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << (num_bits * i) @@ -245,13 +244,10 @@ def _to_marlin_weights( def _from_marlin_weights( - q_packed: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int - ) -> torch.Tensor: + q_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: """Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format. - + Args: q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format. size_k (int): The number of input features. @@ -269,23 +265,27 @@ def _from_marlin_weights( # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32 # does not support rshift_cpu. q_packed = q_packed.cpu().to(torch.int64) - q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device) + q_w_unpacked = torch.zeros( + (q_packed.shape[0], q_packed.shape[1] * pack_factor), + dtype=torch.int64, + device=q_packed.device, + ) for i in range(pack_factor): - q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1) + q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ( + (1 << num_bits) - 1 + ) q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32) - q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24) + q_w_comp = utils.reverse_marlin_permute_weights( + q_w_unpacked, size_k, size_n, perm_24 + ) return q_w_comp def _to_marlin_scales( - scales: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, - num_bits: int - ) -> torch.Tensor: + scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int +) -> torch.Tensor: """Converts a scale tensor to the format necessary for marlin. Args: scales (torch.Tensor): The scale tensor. @@ -293,7 +293,7 @@ def _to_marlin_scales( size_n (int): The number of output features. group_size (int): The group size that was applied during quantization. num_bits (int): The number of bits used for quantization. - + Returns: torch.Tensor: The scale tensor in the marlin format. """ @@ -301,20 +301,18 @@ def _to_marlin_scales( if group_size < size_k and group_size != -1: scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24] else: - scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24] + scales = scales.reshape((-1, len(scale_perm_single_24)))[ + :, scale_perm_single_24 + ] scales = scales.reshape((-1, size_n)).contiguous() return scales def _from_marlin_scale( - scales: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, - num_bits: int - ) -> torch.Tensor: + scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int +) -> torch.Tensor: """Converts a scale tensor from the marlin format to their original format. - + Args: scales (torch.Tensor): The scale tensor in the marlin format. size_k (int): The number of input features. 
@@ -329,5 +327,7 @@ def _from_marlin_scale( scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24] return scales.reshape((size_k // group_size, size_n)) else: - scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24] - return scales.reshape((1, -1)) + scales = scales.reshape((-1, len(scale_perm_single_24)))[ + :, scale_perm_single_24 + ] + return scales.reshape((1, -1)) diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py index 4c55725539..45837773dd 100644 --- a/torchao/sparsity/marlin/utils.py +++ b/torchao/sparsity/marlin/utils.py @@ -1,6 +1,7 @@ -import torch -from typing import List, Tuple from dataclasses import dataclass, field +from typing import List, Tuple + +import torch @dataclass(frozen=True) @@ -12,12 +13,14 @@ class Marlin24Constants: # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTTensorImpl SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4, 8]) SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 32, 64, 128]) + + const = Marlin24Constants() def get_pack_factor(num_bits: int) -> int: """Compute the packing factor for a given number of bits. - + Args: num_bits (int): Number of bits to pack. Returns: @@ -29,14 +32,14 @@ def get_pack_factor(num_bits: int) -> int: def marlin_permute_weights( - q_w: torch.Tensor, - size_k: int, - size_n: int, - perm: torch.Tensor, - tile: int = const.TILE - ) -> torch.Tensor: + q_w: torch.Tensor, + size_k: int, + size_n: int, + perm: torch.Tensor, + tile: int = const.TILE, +) -> torch.Tensor: """Permute weights to 16x64 Marlin tiles. - + Args: q_w (torch.Tensor): Quantized weights. size_k (int): Number of input features. @@ -62,12 +65,12 @@ def marlin_permute_weights( def reverse_marlin_permute_weights( - q_w_unpacked: torch.Tensor, - size_k: int, - size_n: int, - reverse_perm: torch.Tensor, - tile: int = const.TILE, - ) -> torch.Tensor: + q_w_unpacked: torch.Tensor, + size_k: int, + size_n: int, + reverse_perm: torch.Tensor, + tile: int = const.TILE, +) -> torch.Tensor: """Reverse permute weights from 16x64 Marlin tiles. Args: q_w_unpacked (torch.Tensor): Unpacked quantized weights. @@ -79,12 +82,17 @@ def reverse_marlin_permute_weights( torch.Tensor: Weight tensor reverse permuted from Marlin tiles. """ - assert (q_w_unpacked.shape[0], size_n) == (size_k // tile, q_w_unpacked.shape[1] // tile) + assert (q_w_unpacked.shape[0], size_n) == ( + size_k // tile, + q_w_unpacked.shape[1] // tile, + ) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" # Reverse permute weights to original shape - q_w_comp = q_w_unpacked.reshape((-1, reverse_perm.numel()))[:, reverse_perm].reshape(q_w_unpacked.shape) + q_w_comp = q_w_unpacked.reshape((-1, reverse_perm.numel()))[ + :, reverse_perm + ].reshape(q_w_unpacked.shape) q_w_comp = q_w_comp.reshape((size_k // tile, size_n // tile, tile, tile)) q_w_comp = q_w_comp.permute((0, 2, 1, 3)) q_w_comp = q_w_comp.reshape((size_k, size_n)) @@ -94,18 +102,18 @@ def reverse_marlin_permute_weights( def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: """Precompute permutations for Marlin24 weight and scale shuffling - + Marlin works on [16*2,64] tiles. 
The goal of the permutations is to reorder the weight data so that it is compatible with the tensor-core format that is described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - + As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core (without the need to use ldmatrix instructions) - + Args: num_bits (int): Number of bits to pack. Returns: - Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and + Tuple[torch.Tensor, List[int], List[int]]: The weight permutation tensor, scale permutation list, and scale permutation list for a single group. """ perm_list: List[int] = [] @@ -115,16 +123,15 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: col_o = col // 2 for block in [0, 1]: for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + - 4 * block) + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) - + # Convert to torch tensor perm = torch.tensor(perm_list, dtype=torch.int32) @@ -149,13 +156,15 @@ def get_perms_24(num_bits: int) -> Tuple[torch.Tensor, List[int], List[int]]: return perm, scale_perm, scale_perm_single -def get_reverse_perms_24(num_bits: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def get_reverse_perms_24( + num_bits: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reverse permutation for Marlin24 weight and scale shuffling from `get_perms_24`. - + Args: num_bits (int): Number of bits to pack. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The reversed weight permutation tensor, scale permutation list and + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The reversed weight permutation tensor, scale permutation list and scale permutation list for single group. """ perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) @@ -180,8 +189,7 @@ def get_reverse_perms_24(num_bits: int) -> Tuple[torch.Tensor, torch.Tensor, tor # matrix elements into reordered metadata matrix elements (or, # equivalently, for gathering reordered metadata matrix element back # into metadata matrix elements). 
-def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, - device): +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) @@ -189,9 +197,13 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, group_x = 64 group_y = 32 if meta_dtype.itemsize == 2 else 16 - dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + - (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + - ((dst_rows % group_x) // 8) * 4) + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) @@ -204,8 +216,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, interleave = 2 cols_maj = dst_cols // interleave cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + - cols_min).view(-1) + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) # This function converts dense matrix into sparse semi-structured @@ -229,17 +240,18 @@ def sparse_semi_structured_from_dense_cutlass(dense): raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError( - "Invalid number of elements per meta element calculated") + raise RuntimeError("Invalid number of elements per meta element calculated") if meta_dtype == torch.int32: if m % 16 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16") + f"Number of rows of dense matrix {m} must be divisible by 16" + ) else: if m % 32 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32") + f"Number of rows of dense matrix {m} must be divisible by 32" + ) if k % (4 * quadbits_per_meta_elem) != 0: raise RuntimeError( f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 @@ -300,40 +312,39 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: - sparse = dense_2.gather(-1, - idxs0.unsqueeze(-1) // 2).view( - m, - k // 2) # type: ignore[possibly-undefined] + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view( - (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) elif quadbits_per_meta_elem == 8: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] 
<< 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) + m, meta_ncols, meta_dtype, device + ) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) @@ -376,13 +387,14 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: raise RuntimeError( f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix") + "expected according to the number of columns of meta matrix" + ) # Undo meta tensor elements reordering. meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) - meta = torch.gather(meta_reordered.view(-1), 0, - meta_offsets).view(m, meta_ncols) + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) # Unpack sparse tensor back to original dense tensor, using # information provided by meta tensor. Note that torch.float @@ -424,15 +436,16 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): meta_2[:, :, 15] = (meta >> 30) & 0b11 dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( - -1, 1).repeat(1, 2).view(-1) + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: - dense.view(torch.half).scatter_(0, dense_offsets, - sparse.view(torch.half).view(-1)) + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) return dense.view(m, 2 * k) diff --git a/torchao/sparsity/prototype/__init__.py b/torchao/sparsity/prototype/__init__.py index 924b7f409b..821e5049e0 100644 --- a/torchao/sparsity/prototype/__init__.py +++ b/torchao/sparsity/prototype/__init__.py @@ -18,3 +18,16 @@ from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( WeightNormSparsifier, ) + +__all__ = [ + "BaseScheduler", + "CubicSL", + "LambdaSL", + "BaseSparsifier", + "NearlyDiagonalSparsifier", + "FakeSparsity", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "module_to_fqn", + "WeightNormSparsifier", +] diff --git a/torchao/sparsity/prototype/pruner/FPGM_pruner.py b/torchao/sparsity/prototype/pruner/FPGM_pruner.py index 412c395108..3e091bfafb 100644 --- a/torchao/sparsity/prototype/pruner/FPGM_pruner.py +++ b/torchao/sparsity/prototype/pruner/FPGM_pruner.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.pruner.FPGM_pruner import FPGMPruner + +__all__ = ["FPGMPruner"] diff --git a/torchao/sparsity/prototype/pruner/__init__.py 
b/torchao/sparsity/prototype/pruner/__init__.py index 6f017aa9e2..9d7f775389 100644 --- a/torchao/sparsity/prototype/pruner/__init__.py +++ b/torchao/sparsity/prototype/pruner/__init__.py @@ -1,8 +1,17 @@ from .base_structured_sparsifier import BaseStructuredSparsifier +from .FPGM_pruner import FPGMPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner from .parametrization import ( - FakeStructuredSparsity, BiasHook, + FakeStructuredSparsity, ) from .saliency_pruner import SaliencyPruner -from .lstm_saliency_pruner import LSTMSaliencyPruner -from .FPGM_pruner import FPGMPruner + +__all__ = [ + "BaseStructuredSparsifier", + "FPGMPruner", + "LSTMSaliencyPruner", + "BiasHook", + "FakeStructuredSparsity", + "SaliencyPruner", +] diff --git a/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py index 257750d4df..9117662d82 100644 --- a/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py +++ b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py @@ -1,3 +1,7 @@ from torchao.prototype.sparsity.pruner.base_structured_sparsifier import ( BaseStructuredSparsifier, ) + +__all__ = [ + "BaseStructuredSparsifier" +] diff --git a/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py index 9c1656bf47..bc5663303c 100644 --- a/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py +++ b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.pruner.lstm_saliency_pruner import LSTMSaliencyPruner + +__all__ = ["LSTMSaliencyPruner"] diff --git a/torchao/sparsity/prototype/pruner/parametrization.py b/torchao/sparsity/prototype/pruner/parametrization.py index 8603639293..f8c6e6e418 100644 --- a/torchao/sparsity/prototype/pruner/parametrization.py +++ b/torchao/sparsity/prototype/pruner/parametrization.py @@ -2,3 +2,8 @@ BiasHook, FakeStructuredSparsity, ) + +__all__ = [ + "BiasHook", + "FakeStructuredSparsity", +] diff --git a/torchao/sparsity/prototype/pruner/saliency_pruner.py b/torchao/sparsity/prototype/pruner/saliency_pruner.py index 4f43ccf46e..7bd10064f5 100644 --- a/torchao/sparsity/prototype/pruner/saliency_pruner.py +++ b/torchao/sparsity/prototype/pruner/saliency_pruner.py @@ -1 +1,5 @@ from torchao.prototype.sparsity.pruner.saliency_pruner import SaliencyPruner + +__all__ = [ + "SaliencyPruner" +] diff --git a/torchao/sparsity/prototype/scheduler/base_scheduler.py b/torchao/sparsity/prototype/scheduler/base_scheduler.py index 877f419ac1..31169a28ef 100644 --- a/torchao/sparsity/prototype/scheduler/base_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/base_scheduler.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.scheduler.base_scheduler import BaseScheduler + +__all__ = ["BaseScheduler"] diff --git a/torchao/sparsity/prototype/scheduler/cubic_scheduler.py b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py index 9b86f16be6..82847ffef7 100644 --- a/torchao/sparsity/prototype/scheduler/cubic_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.scheduler.cubic_scheduler import CubicSL + +__all__ = ["CubicSL"] diff --git a/torchao/sparsity/prototype/scheduler/lambda_scheduler.py b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py index 0730558195..47fce2a0b3 100644 --- a/torchao/sparsity/prototype/scheduler/lambda_scheduler.py +++ b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py @@ -1 +1,3 
@@ from torchao.prototype.sparsity.scheduler.lambda_scheduler import LambdaSL + +__all__ = ["LambdaSL"] diff --git a/torchao/sparsity/prototype/sparsifier/base_sparsifier.py b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py index 954c06b74f..490c1872c7 100644 --- a/torchao/sparsity/prototype/sparsifier/base_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.sparsifier.base_sparsifier import BaseSparsifier + +__all__ = ["BaseSparsifier"] diff --git a/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py index 640ec667c2..d0ec384367 100644 --- a/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py @@ -1,3 +1,5 @@ from torchao.prototype.sparsity.sparsifier.nearly_diagonal_sparsifier import ( NearlyDiagonalSparsifier, ) + +__all__ = ["NearlyDiagonalSparsifier"] diff --git a/torchao/sparsity/prototype/sparsifier/utils.py b/torchao/sparsity/prototype/sparsifier/utils.py index f410d1b325..a30fe5628b 100644 --- a/torchao/sparsity/prototype/sparsifier/utils.py +++ b/torchao/sparsity/prototype/sparsifier/utils.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.sparsifier.utils import FakeSparsity + +__all__ = ["FakeSparsity"] diff --git a/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py index a490e5f65b..0a33ac5295 100644 --- a/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py +++ b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py @@ -1,3 +1,5 @@ from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( WeightNormSparsifier, ) + +__all__ = ["WeightNormSparsifier"] diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py index a1696c02b7..54510e393a 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse.py +++ b/torchao/sparsity/prototype/superblock/blocksparse.py @@ -1 +1,3 @@ from torchao.prototype.sparsity.superblock.blocksparse import BlockSparseTensor + +__all__ = ["BlockSparseTensor"] diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index 75d5b17651..f502d1f2ad 100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -3,3 +3,9 @@ SupermaskConv2d, SupermaskLinear, ) + +__all__ = [ + "GetSubnet", + "SupermaskConv2d", + "SupermaskLinear", +] diff --git a/torchao/sparsity/prototype/superblock/utils.py b/torchao/sparsity/prototype/superblock/utils.py index e409abffc0..e60e244c80 100644 --- a/torchao/sparsity/prototype/superblock/utils.py +++ b/torchao/sparsity/prototype/superblock/utils.py @@ -8,3 +8,14 @@ RASampler, SmoothedValue, ) + +__all__ = [ + "ClassificationPresetEval", + "ClassificationPresetTrain", + "ExponentialMovingAverage", + "MetricLogger", + "RandomCutmix", + "RandomMixup", + "RASampler", + "SmoothedValue", +] diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 98bfa3b30a..3dd7971525 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -3,11 +3,11 @@ import torch from torch.ao.pruning import WeightNormSparsifier from torch.sparse import to_sparse_semi_structured + from torchao.quantization.quant_api import ( _get_linear_subclass_inserter, 
_is_linear, _replace_with_custom_fn_if_matches_filter, - int8_dynamic_activation_int8_semi_sparse_weight, ) diff --git a/torchao/sparsity/training/__init__.py b/torchao/sparsity/training/__init__.py index 35d5e5436f..3c4212101b 100644 --- a/torchao/sparsity/training/__init__.py +++ b/torchao/sparsity/training/__init__.py @@ -6,13 +6,15 @@ from torchao.sparsity.training.autograd import semi_structured_sparsify from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE - from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 # load pointwise op support, which exists only for CUTLASS if TORCH_VERSION_AT_LEAST_2_3: from torch.sparse import SparseSemiStructuredTensorCUTLASS - SparseSemiStructuredTensorCUTLASS._load_dispatch_table(CUTLASS_POINTWISE_OP_DISPATCH_TABLE) + + SparseSemiStructuredTensorCUTLASS._load_dispatch_table( + CUTLASS_POINTWISE_OP_DISPATCH_TABLE + ) __all__ = [ "SemiSparseLinear", @@ -21,6 +23,7 @@ "swap_semi_sparse_linear_with_linear", ] + class SemiSparseLinear(torch.nn.Linear): """ Replacement nn.Linear that supports runtime weight sparsity @@ -39,7 +42,9 @@ def from_dense(cls, linear): @classmethod def to_dense(cls, semi_sparse_linear): - mod = torch.nn.Linear(semi_sparse_linear.in_features, semi_sparse_linear.out_features) + mod = torch.nn.Linear( + semi_sparse_linear.in_features, semi_sparse_linear.out_features + ) mod.weight = semi_sparse_linear.weight mod.bias = semi_sparse_linear.bias return mod @@ -63,11 +68,14 @@ def from_dense(cls, linear): @classmethod def to_dense(cls, semi_sparse_linear): - mod = torch.nn.Linear(semi_sparse_linear.in_features, semi_sparse_linear.out_features) + mod = torch.nn.Linear( + semi_sparse_linear.in_features, semi_sparse_linear.out_features + ) mod.weight = semi_sparse_linear.weight mod.bias = semi_sparse_linear.bias return mod + def swap_linear_with_semi_sparse_linear(model, config, current=""): """ Public API for replacing nn.Linear with SemiSparseLinear @@ -82,6 +90,7 @@ def swap_linear_with_semi_sparse_linear(model, config, current=""): else: swap_linear_with_semi_sparse_linear(child, config, current=fqn) + def swap_semi_sparse_linear_with_linear(model, current=""): """ Public API for replacing instances of SemiSparseLinear/SemiSparseActivaitonLinear with nn.Linear diff --git a/torchao/sparsity/training/autograd.py b/torchao/sparsity/training/autograd.py index 33f069c5d9..fa61fa3b2c 100644 --- a/torchao/sparsity/training/autograd.py +++ b/torchao/sparsity/training/autograd.py @@ -1,28 +1,37 @@ from enum import Enum + import torch from torch.sparse import SparseSemiStructuredTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 if TORCH_VERSION_AT_LEAST_2_3: - from torch.sparse import SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT + from torch.sparse import ( + SparseSemiStructuredTensorCUSPARSELT, + SparseSemiStructuredTensorCUTLASS, + ) + torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT) torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS) GRADIENT_TYPE = Enum("GRADIENT_TYPE", ["DENSE", "SPARSE", "STE"]) -class _SparsifyFunc(torch.autograd.Function): +class _SparsifyFunc(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor, algo: str, backend: GRADIENT_TYPE): # type: ignore[override] - use_cutlass = (backend == "cutlass") + use_cutlass = backend == "cutlass" if not isinstance(x, SparseSemiStructuredTensor): - (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile( - x, algorithm=algo, 
use_cutlass=use_cutlass + (packed, meta, packed_t, meta_t, bitmask) = ( + torch._sparse_semi_structured_tile( + x, algorithm=algo, use_cutlass=use_cutlass + ) ) cls = ( - SparseSemiStructuredTensorCUTLASS if use_cutlass else SparseSemiStructuredTensorCUSPARSELT + SparseSemiStructuredTensorCUTLASS + if use_cutlass + else SparseSemiStructuredTensorCUSPARSELT ) out = cls( x.shape, @@ -44,10 +53,15 @@ def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] # We just return grad_out, since we just use STE - straight through estimation return grad_out, None, None -class _SparsifyLikeFunc(torch.autograd.Function): +class _SparsifyLikeFunc(torch.autograd.Function): @staticmethod - def forward(ctx, x: torch.Tensor, pattern: SparseSemiStructuredTensor, gradient=GRADIENT_TYPE.SPARSE): # type: ignore[override] + def forward( + ctx, + x: torch.Tensor, + pattern: SparseSemiStructuredTensor, + gradient=GRADIENT_TYPE.SPARSE, + ): # type: ignore[override] assert isinstance(pattern, SparseSemiStructuredTensor) if not isinstance(pattern, SparseSemiStructuredTensorCUTLASS): @@ -59,7 +73,9 @@ def forward(ctx, x: torch.Tensor, pattern: SparseSemiStructuredTensor, gradient= "`sparsify_like(x, pattern)` is not implemented when `bitmask` is transposed" ) - packed, packed_t = torch._sparse_semi_structured_apply(x, pattern.compressed_swizzled_bitmask) + packed, packed_t = torch._sparse_semi_structured_apply( + x, pattern.compressed_swizzled_bitmask + ) # save for backwards ctx.meta = pattern.meta @@ -79,7 +95,9 @@ def forward(ctx, x: torch.Tensor, pattern: SparseSemiStructuredTensor, gradient= @staticmethod def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] - if ctx.gradient == GRADIENT_TYPE.STE or isinstance(grad_out, SparseSemiStructuredTensor): + if ctx.gradient == GRADIENT_TYPE.STE or isinstance( + grad_out, SparseSemiStructuredTensor + ): return grad_out, None, None, None assert not isinstance(grad_out, SparseSemiStructuredTensor) assert grad_out.dtype == ctx.dtype @@ -113,6 +131,7 @@ def backward(ctx, grad_out: torch.Tensor): # type: ignore[override] ) return grad_out, None + @torch._dynamo.allow_in_graph def semi_structured_sparsify( x: torch.Tensor, @@ -124,6 +143,7 @@ def semi_structured_sparsify( """ return _SparsifyFunc.apply(x, algo, backend) + @torch._dynamo.allow_in_graph def semi_structured_sparsify_like( x: torch.Tensor, diff --git a/torchao/sparsity/training/pointwise_ops.py b/torchao/sparsity/training/pointwise_ops.py index 25aa0dbc52..0c50cff43b 100644 --- a/torchao/sparsity/training/pointwise_ops.py +++ b/torchao/sparsity/training/pointwise_ops.py @@ -1,14 +1,19 @@ from functools import partial + import torch -from torch.sparse import SparseSemiStructuredTensor +from torch.sparse import SparseSemiStructuredTensor + from torchao.sparsity.training.autograd import semi_structured_sparsify_like -def _semi_sparse_pointwise_op(func, types, args=(), kwargs=None, sparsify_like_args_list=()): + +def _semi_sparse_pointwise_op( + func, types, args=(), kwargs=None, sparsify_like_args_list=() +): """ - adds pointwise op support for semi-structured tensors. + adds pointwise op support for semi-structured tensors. Assumes that at least one of the arguments in arg is a SparseSemiStructuredTensor. - The last instance of a SparseSemiStructuredTensor is used as the reference mask to sparsify the others tensors passed in args. + The last instance of a SparseSemiStructuredTensor is used as the reference mask to sparsify the others tensors passed in args. 
sparsify_like_args_list is used to specify which arguments to sparsify like the reference tensor. """ reference_sparse_tensor = None @@ -23,22 +28,26 @@ def handle_arg(i, tensor): # if they are specified in `sparsify_like_args_list`. if not isinstance(tensor, SparseSemiStructuredTensor): if i in sparsify_like_args_list: - tensor = semi_structured_sparsify_like(tensor, reference_sparse_tensor) + tensor = semi_structured_sparsify_like( + tensor, reference_sparse_tensor + ) else: raise ValueError( f"Operation {func.__module__}.{func.__name__} on {type(reference_sparse_tensor)} requires all operands to " f"be {type(reference_sparse_tensor)}, but operand {i} is a {type(tensor)}" ) - - # If the tensor is a SparseSemiStructuredTensor, we make sure that the sparsity pattern is the same as the reference tensor. + + # If the tensor is a SparseSemiStructuredTensor, we make sure that the sparsity pattern is the same as the reference tensor. # Pointwise ops on tensors containing two different sparsity patterns is not defined, as in the case of addition, where - # adding two semi-structured sparse tensors yields a result that is not semi-structured sparse. + # adding two semi-structured sparse tensors yields a result that is not semi-structured sparse. else: if ( tensor.compressed_swizzled_bitmask is None or reference_sparse_tensor.compressed_swizzled_bitmask is None - or tensor.compressed_swizzled_bitmask.data_ptr() != reference_sparse_tensor.compressed_swizzled_bitmask.data_ptr() - or tensor.compressed_swizzled_bitmask.stride() != reference_sparse_tensor.compressed_swizzled_bitmask.stride() + or tensor.compressed_swizzled_bitmask.data_ptr() + != reference_sparse_tensor.compressed_swizzled_bitmask.data_ptr() + or tensor.compressed_swizzled_bitmask.stride() + != reference_sparse_tensor.compressed_swizzled_bitmask.stride() ): raise ValueError( f"Operation {func.__module__}.{func.__name__} on {type(reference_sparse_tensor)} requires all operands to be " @@ -46,25 +55,28 @@ def handle_arg(i, tensor): ) return tensor - args_updated = [ handle_arg(i, tensor) for i, tensor in enumerate(args) ] + args_updated = [handle_arg(i, tensor) for i, tensor in enumerate(args)] return reference_sparse_tensor.__class__( reference_sparse_tensor.shape, - func(*[ + func( + *[ x.packed if isinstance(x, SparseSemiStructuredTensor) else x for x in args_updated - ]), + ] + ), reference_sparse_tensor.meta, func( *[ - x.packed_t if isinstance(x, SparseSemiStructuredTensor) - else x for x in args_updated + x.packed_t if isinstance(x, SparseSemiStructuredTensor) else x + for x in args_updated ] ), reference_sparse_tensor.meta_t, reference_sparse_tensor.compressed_swizzled_bitmask, ) + # Add pointwise ops to the dispatch table CUTLASS_POINTWISE_OP_DISPATCH_TABLE = { torch.ops.aten.relu: _semi_sparse_pointwise_op, @@ -79,12 +91,10 @@ def handle_arg(i, tensor): # Note: for these ops, we allow the gradient to come in as a `torch.Tensor` # and we will run the sparsification right before calling the BW aten func torch.ops.aten.gelu_backward: partial( - _semi_sparse_pointwise_op, - sparsify_like_args_list=(0,) + _semi_sparse_pointwise_op, sparsify_like_args_list=(0,) ), torch.ops.aten.silu_backward: partial( - _semi_sparse_pointwise_op, - sparsify_like_args_list=(0, 1) + _semi_sparse_pointwise_op, sparsify_like_args_list=(0, 1) ), torch.ops.aten.threshold_backward: partial( # relu BW _semi_sparse_pointwise_op, diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 4b5164863f..1abd27a80f 100644 --- 
a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -1,4 +1,5 @@ import random + import torch from torch.ao.quantization.observer import UniformQuantizationObserverBase @@ -9,18 +10,20 @@ "mask_creator", ] + def create_block_sparse_tensor(M, N, blocksize, sparsity, dtype): - assert sparsity <= 1.0 and sparsity >= 0.0, \ - "sparsity should be a value between 0 and 1" - A = torch.bernoulli(torch.full((M//blocksize, N//blocksize), - 1 - sparsity, dtype=dtype)) + assert ( + sparsity <= 1.0 and sparsity >= 0.0 + ), "sparsity should be a value between 0 and 1" + A = torch.bernoulli( + torch.full((M // blocksize, N // blocksize), 1 - sparsity, dtype=dtype) + ) A = torch.repeat_interleave(A, blocksize, dim=0) A = torch.repeat_interleave(A, blocksize, dim=1) return A.to(dtype).contiguous().cuda() -def create_semi_structured_tensor( - r, c, dtype -): + +def create_semi_structured_tensor(r, c, dtype): """ This function returns a 1:2 sparse matrix of size (r, c). Note that this means this matrix will also be 2:4 and 4:8 sparse as well. @@ -30,13 +33,12 @@ def create_semi_structured_tensor( mask_entries = [random.choice(choices) for i in range(r * c // 2)] mask = ( - torch.tensor(mask_entries, dtype=torch.int32) - .reshape(r, c) - .contiguous() + torch.tensor(mask_entries, dtype=torch.int32).reshape(r, c).contiguous() ).cuda() sparse_weight = torch.rand(r, c).cuda() * mask return sparse_weight.to(dtype) + # Observers class PerChannelNormObserver(UniformQuantizationObserverBase): """ @@ -52,7 +54,7 @@ def __init__(self, **kwargs) -> None: quant_min=None, quant_max=None, eps=torch.finfo(torch.float32).eps, - **kwargs + **kwargs, ) # set averaging constant so quantization flow knows observer is memoryless. self.averaging_constant = 1.0 @@ -89,14 +91,14 @@ def calculate_qparams(self): def mask_creator( - tensor: torch.Tensor, - N: int = 2, - M: int = 4, - ) -> torch.Tensor: + tensor: torch.Tensor, + N: int = 2, + M: int = 4, +) -> torch.Tensor: """ Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask will correspond to the given tensor. 
:param tensor: The input tensor to create a mask for :param N: The number of weights in a group to keep @@ -107,14 +109,14 @@ def mask_creator( # for i, tensor in enumerate(tensors): if tensor.numel() % M != 0: raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " - f"{M} groups") + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) num_groups = tensor.numel() // M # N:M sparsity for linear layers tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) diff --git a/torchao/sparsity/wanda.py b/torchao/sparsity/wanda.py index 1dbcbebe57..e8aa97d310 100644 --- a/torchao/sparsity/wanda.py +++ b/torchao/sparsity/wanda.py @@ -1,11 +1,10 @@ import warnings - from typing import Dict, List, Optional, Tuple import torch from torch import nn from torch.ao.pruning import BaseSparsifier -from torch.ao.quantization import default_placeholder_observer, QConfig +from torch.ao.quantization import QConfig, default_placeholder_observer from torch.ao.quantization.quantize import _remove_qconfig from .utils import PerChannelNormObserver @@ -101,7 +100,6 @@ def squash_mask( # remove quantization config for config in self.groups: module = config["module"] - tensor_name = config["tensor_name"] _remove_qconfig(module) # remove parameterizations
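The hunks above are formatting-only changes (import sorting, Black-style line wrapping, and explicit `__all__` lists); the sketches that follow are illustrative notes for the APIs they touch and are not part of the patch. First, a hedged usage sketch for the API re-exported from `torchao/sparsity/__init__.py`; the toy model size, the fp16/CUDA requirement, and the exact call pattern are assumptions based on the exported names, not something this diff changes:

```python
import torch

from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_

# Toy model; semi-structured kernels generally expect fp16/bf16 weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).half().cuda()

apply_fake_sparsity(model)              # magnitude-prune each linear weight to 2:4
sparsify_(model, semi_sparse_weight())  # swap pruned weights for semi-structured tensors
```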
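The `inject_24` and `mask_creator` docstrings describe the same pruning rule: within every group of four weights, the two smallest-magnitude entries are zeroed and the two largest are kept. A minimal standalone sketch of that rule (it mirrors the argsort-and-scatter logic in `torchao/sparsity/utils.py`, but the helper name and shapes here are illustrative):

```python
import torch

def make_2_4_mask(w: torch.Tensor, n_keep: int = 2, m: int = 4) -> torch.Tensor:
    """Return a 0/1 mask keeping the n_keep largest-magnitude weights per group of m."""
    assert w.numel() % m == 0, "tensor must divide evenly into groups of m"
    groups = w.detach().abs().reshape(-1, m)
    # the (m - n_keep) smallest-magnitude entries in each group are zeroed
    drop = torch.argsort(groups, dim=1)[:, : m - n_keep]
    mask = torch.ones_like(groups)
    mask.scatter_(dim=1, index=drop, value=0.0)
    return mask.reshape(w.shape)

w = torch.randn(4, 8)
mask = make_2_4_mask(w)
w_24 = w * mask                                    # inject_24 returns (w * mask, mask)
assert mask.reshape(-1, 4).sum(dim=1).eq(2).all()  # exactly two survivors per group of four
```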
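`get_pack_factor` and the `_to_marlin_weights` / `_from_marlin_weights` loops reformatted above pack `32 // num_bits` quantized values into each 32-bit word, using int64 on CPU because `torch.uint32` does not support the shift ops (as the in-code comment notes). A small round-trip sketch of just the packing, with the Marlin tile permutation left out:

```python
import torch

num_bits = 4
pack_factor = 32 // num_bits  # get_pack_factor(4) == 8 values per 32-bit word

q = torch.randint(0, 2**num_bits, (2, 16), dtype=torch.int64)  # toy int4 values
packed = torch.zeros((2, 16 // pack_factor), dtype=torch.int64)
for i in range(pack_factor):
    packed |= q[:, i::pack_factor] << (num_bits * i)

unpacked = torch.zeros_like(q)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & ((1 << num_bits) - 1)

assert torch.equal(unpacked, q)  # pack/unpack is lossless for values < 2**num_bits
```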
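`marlin_24_workspace` sizes its lock buffer as `(out_features // min_thread_n) * max_parallel` zero-initialized int32 entries. The same arithmetic, computed on CPU with placeholder values (the real defaults come from `Marlin24Constants` and the real buffer is allocated on `"cuda"`):

```python
import torch

out_features, min_thread_n, max_parallel = 4096, 128, 16  # illustrative values only
assert out_features % min_thread_n == 0, "out_features must be a multiple of min_thread_n"
workspace = torch.zeros((out_features // min_thread_n) * max_parallel, dtype=torch.int)
print(workspace.numel())  # 512 lock slots used to coordinate thread blocks
```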
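In `sparse_semi_structured_from_dense_cutlass`, every group of four dense elements keeps two survivors whose 2-bit column positions are packed as `idxs0 | (idxs1 << 2)`; four or eight such quads are then OR-ed, 4 bits apart, into one int16 or int32 metadata word (`quadbits_per_meta_elem`). A one-quad worked example:

```python
# Columns 1 and 3 of a four-element group survive 2:4 pruning.
idx0, idx1 = 1, 3
quad = idx0 | (idx1 << 2)  # 0b1101 == 13
assert quad == 0b1101      # four such quads fill one 16-bit meta element
```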
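`create_semi_structured_tensor` in `utils.py` builds its test input by choosing one survivor per pair of elements, so the result is 1:2 sparse and therefore also 2:4 and 4:8 sparse, as its docstring says. A CPU-only sketch of the same construction (the `.cuda()` moves are dropped here purely for illustration):

```python
import random

import torch

def semi_structured_dense(r: int, c: int, dtype=torch.float16) -> torch.Tensor:
    """Return an (r, c) dense tensor in which every adjacent pair has one zero entry."""
    choices = [[0, 1], [1, 0]]
    entries = [random.choice(choices) for _ in range(r * c // 2)]
    mask = torch.tensor(entries, dtype=torch.int32).reshape(r, c)
    return (torch.rand(r, c) * mask).to(dtype)

w = semi_structured_dense(8, 16)
assert (w.reshape(-1, 2) != 0).sum(dim=1).le(1).all()  # at most one nonzero per pair
```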