diff --git a/test/sparsity/test_wanda.py b/test/sparsity/test_wanda.py
index e02ea9822a..5347e61e03 100644
--- a/test/sparsity/test_wanda.py
+++ b/test/sparsity/test_wanda.py
@@ -113,6 +113,39 @@ def test_two_layer_mlp_unstructured(self):
 
         sparsifier.squash_mask()
 
+    def test_two_layer_mlp_unstructured_custom_config(self):
+        model = nn.Sequential(
+            nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
+        )  # C_in by C_out
+        X1 = torch.randn(100, 128)  # B1 by C_in
+        X2 = torch.randn(50, 128)  # B2 by C_in
+
+        # Define custom config to sparsify only the first Linear layer for testing
+        config = [{"tensor_fqn": "0.weight"}]
+
+        sparsifier = WandaSparsifier(sparsity_level=0.5)
+        sparsifier.prepare(model, config=config)
+
+        model(X1)
+        model(X2)
+        sparsifier.step()
+
+        cnt = 0
+        for m in model.modules():
+            if isinstance(m, nn.Linear):
+                cnt += 1
+                sparsity_level = (m.weight == 0).float().mean()
+                if cnt == 1:  # First Linear layer should have 50% sparsity
+                    assert (
+                        sparsity_level == 0.5
+                    ), f"sparsity for linear layer {cnt} should be 0.5"
+                else:  # Other layers should not be sparsified
+                    assert (
+                        sparsity_level != 0.5
+                    ), f"sparsity for linear layer {cnt} should not be 0.5"
+
+        sparsifier.squash_mask()
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/sparsity/wanda.py b/torchao/sparsity/wanda.py
index e8aa97d310..50cf7835ce 100644
--- a/torchao/sparsity/wanda.py
+++ b/torchao/sparsity/wanda.py
@@ -3,7 +3,7 @@
 
 import torch
 from torch import nn
-from torch.ao.pruning import BaseSparsifier
+from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
 from torch.ao.quantization import QConfig, default_placeholder_observer
 from torch.ao.quantization.quantize import _remove_qconfig
 
@@ -47,9 +47,27 @@ def __init__(
     def prepare(self, model: nn.Module, config: List[Dict]) -> None:
         # activation: use PerChannelNormObserver
         # use no-op placeholder weight observer
-        model.qconfig = QConfig(
-            activation=PerChannelNormObserver, weight=default_placeholder_observer
-        )  # type: ignore[assignment]
+        if config is None:
+            # If no config is provided, apply the qconfig to the entire model
+            model.qconfig = QConfig(
+                activation=PerChannelNormObserver, weight=default_placeholder_observer
+            )  # type: ignore[assignment]
+        else:
+            for module_config in config:
+                tensor_fqn = module_config.get("tensor_fqn", None)
+                if tensor_fqn is None:
+                    raise ValueError("Each config must contain a 'tensor_fqn'.")
+
+                # Extract module information from tensor_fqn
+                info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+                module = info_from_tensor_fqn["module"]
+
+                # Apply the qconfig directly to the module if it exists
+                if module is not None:
+                    module.qconfig = QConfig(
+                        activation=PerChannelNormObserver,
+                        weight=default_placeholder_observer,
+                    )  # type: ignore[assignment]
         torch.ao.quantization.prepare(model, inplace=True)
 
         # call superclass prepare
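For context, a minimal usage sketch of the per-layer config path introduced above, assuming the `torchao.sparsity` import location for `WandaSparsifier` and the `tensor_fqn` config format exercised in the new test:

    import torch
    from torch import nn
    from torchao.sparsity import WandaSparsifier  # import path assumed

    # Two-layer MLP; only the first Linear layer should be sparsified.
    model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))

    # Each config entry names a parameter by its fully qualified tensor name.
    config = [{"tensor_fqn": "0.weight"}]

    sparsifier = WandaSparsifier(sparsity_level=0.5)
    sparsifier.prepare(model, config=config)  # attaches observers only to module "0"

    model(torch.randn(100, 128))  # calibrate the activation-norm observers
    sparsifier.step()             # compute and apply the Wanda masks
    sparsifier.squash_mask()      # fold masks into the weights

Compared with the previous behavior of attaching a single model-wide qconfig, the config branch in `prepare` attaches the observer qconfig only to the modules resolved from each `tensor_fqn`, so layers outside the config are left dense.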