Fixes observer attachment to model based on config for wanda sparsifier (pytorch#1265)

* Fixes observer attachment to model based on config for wanda sparsifier

* Handles the case when no config is specified

* Added a test case in test_wanda.py for a custom config

* Lint fix
agrawal-aka authored Dec 18, 2024
1 parent a03ca99 commit 33d57af
Showing 2 changed files with 55 additions and 4 deletions.
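For context, here is a minimal usage sketch of the behavior this change targets (not part of the diff; the model, shapes, and import path are illustrative assumptions). With a custom config, only the modules named by tensor_fqn should receive the activation-norm observer, while passing config=None keeps the previous whole-model behavior.

import torch
from torch import nn

from torchao.sparsity import WandaSparsifier  # assumed import path

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Sparsify only the first Linear layer; "0.weight" is its tensor FQN inside the Sequential.
sparsifier = WandaSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])

model(torch.randn(8, 16))  # the observer on the configured module records activation norms
sparsifier.step()          # compute Wanda masks from weight magnitude scaled by activation norm
sparsifier.squash_mask()   # fold the masks back into the weights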
33 changes: 33 additions & 0 deletions test/sparsity/test_wanda.py
@@ -113,6 +113,39 @@ def test_two_layer_mlp_unstructured(self):

         sparsifier.squash_mask()
 
+    def test_two_layer_mlp_unstructured_custom_config(self):
+        model = nn.Sequential(
+            nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10)
+        )  # C_in by C_out
+        X1 = torch.randn(100, 128)  # B1 by C_in
+        X2 = torch.randn(50, 128)  # B2 by C_in
+
+        # Define custom config to sparsify only the first Linear layer for testing
+        config = [{"tensor_fqn": "0.weight"}]
+
+        sparsifier = WandaSparsifier(sparsity_level=0.5)
+        sparsifier.prepare(model, config=config)
+
+        model(X1)
+        model(X2)
+        sparsifier.step()
+
+        cnt = 0
+        for m in model.modules():
+            if isinstance(m, nn.Linear):
+                cnt += 1
+                sparsity_level = (m.weight == 0).float().mean()
+                if cnt == 1:  # First Linear layer should have 50% sparsity
+                    assert (
+                        sparsity_level == 0.5
+                    ), f"sparsity for linear layer {cnt} should be 0.5"
+                else:  # Other layers should not be sparsified
+                    assert (
+                        sparsity_level != 0.5
+                    ), f"sparsity for linear layer {cnt} should not be 0.5"
+
+        sparsifier.squash_mask()
+
 
 if __name__ == "__main__":
     unittest.main()
26 changes: 22 additions & 4 deletions torchao/sparsity/wanda.py
@@ -3,7 +3,7 @@

 import torch
 from torch import nn
-from torch.ao.pruning import BaseSparsifier
+from torch.ao.pruning import BaseSparsifier, get_arg_info_from_tensor_fqn
 from torch.ao.quantization import QConfig, default_placeholder_observer
 from torch.ao.quantization.quantize import _remove_qconfig

@@ -47,9 +47,27 @@ def __init__(
     def prepare(self, model: nn.Module, config: List[Dict]) -> None:
         # activation: use PerChannelNormObserver
         # use no-op placeholder weight observer
-        model.qconfig = QConfig(
-            activation=PerChannelNormObserver, weight=default_placeholder_observer
-        )  # type: ignore[assignment]
+        if config is None:
+            # If no config is provided, apply the qconfig to the entire model
+            model.qconfig = QConfig(
+                activation=PerChannelNormObserver, weight=default_placeholder_observer
+            )  # type: ignore[assignment]
+        else:
+            for module_config in config:
+                tensor_fqn = module_config.get("tensor_fqn", None)
+                if tensor_fqn is None:
+                    raise ValueError("Each config must contain a 'tensor_fqn'.")
+
+                # Extract module information from tensor_fqn
+                info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
+                module = info_from_tensor_fqn["module"]
+
+                # Apply the qconfig directly to the module if it exists
+                if module is not None:
+                    module.qconfig = QConfig(
+                        activation=PerChannelNormObserver,
+                        weight=default_placeholder_observer,
+                    )  # type: ignore[assignment]
         torch.ao.quantization.prepare(model, inplace=True)
 
         # call superclass prepare
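For readers unfamiliar with the helper, below is a rough sketch of the lookup that get_arg_info_from_tensor_fqn is assumed to perform. This is an illustration, not the torch.ao.pruning source, and the returned key names other than "module" (the only entry the prepare change reads above) are assumptions.

from torch import nn


def tensor_fqn_lookup_sketch(model: nn.Module, tensor_fqn: str) -> dict:
    """Approximate stand-in for torch.ao.pruning.get_arg_info_from_tensor_fqn (assumed behavior)."""
    # Split "0.weight" into the module FQN "0" and the tensor name "weight".
    module_fqn, _, tensor_name = tensor_fqn.rpartition(".")
    module = model.get_submodule(module_fqn) if module_fqn else model
    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }

The prepare change only consumes the "module" entry, attaching a per-module QConfig so that torch.ao.quantization.prepare inserts PerChannelNormObserver on just the configured modules.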
