FSDP utils cleanup #854

Open

Wants to merge 38 commits into base: main from kylesayrs/fsdp-wrapper-name

Changes from all commits (38 commits)
8a9e89d
use existing constant, clarify comment
kylesayrs Oct 19, 2024
202ac48
fix fwd func call (#845)
dsikka Oct 19, 2024
12e2436
always import from accelerate since it is a dependency
kylesayrs Nov 4, 2024
04df6ef
[Bugfix] Use weight parameter of linear layer (#836)
kylesayrs Oct 24, 2024
9107190
[Bugfix] Rename files to remove colons (#846)
kylesayrs Oct 24, 2024
4b8cfcb
cover all 3.9-3.12 in commit testing (#864)
dhuangnm Oct 24, 2024
58cba1b
Add marlin-24 recipe/configs for e2e testing (#866)
dsikka Oct 25, 2024
885fd74
[Bugfix] onload during sparsity calculation (#862)
kylesayrs Oct 28, 2024
f967197
Fix HFTrainer overloads (#869)
kylesayrs Oct 28, 2024
d9ad3c2
Support Model Offloading Tied Tensors Patch (#872)
kylesayrs Oct 29, 2024
5d9b7eb
add advice about dealing with non-invertable hessians (#875)
kylesayrs Oct 30, 2024
4a9d3dd
seed commit workflow (#877)
andy-neuma Oct 31, 2024
ef9472c
[Observer Restructure]: Add Observers; Add `calibration` and `frozen`…
dsikka Oct 31, 2024
28eff47
Bugfix get observer from name (#883)
rahul-tuli Oct 31, 2024
59779f5
BugFix: Fix Sparsity Reload Testing (#882)
dsikka Nov 1, 2024
44e196f
Use custom unique test names for e2e tests (#892)
dbarbuzzi Nov 4, 2024
dba61ec
Revert "Use custom unique test names for e2e tests (#892)" (#893)
dsikka Nov 4, 2024
4e126b3
Move config["testconfig_path"] assignment (#895)
dbarbuzzi Nov 4, 2024
e962e33
cap accelerate version to avoid bug (#897)
kylesayrs Nov 4, 2024
f8777c7
Fix observing offloaded weight (#896)
kylesayrs Nov 5, 2024
9c168f7
Update image in README.md (#861)
mgoin Nov 5, 2024
5f6f568
update accelerate version (#899)
kylesayrs Nov 7, 2024
fa61cf6
[GPTQ] Iterative Parameter Updating (#863)
kylesayrs Nov 7, 2024
7142da0
Small fixes for release (#901)
dsikka Nov 8, 2024
c82b22c
use smaller portion of dataset (#902)
dsikka Nov 9, 2024
f45c29e
Update example to not fail hessian inversion (#904)
dsikka Nov 9, 2024
8ef26cc
bump version (#907)
dsikka Nov 12, 2024
65d9db2
add default mappings (#906)
kylesayrs Nov 15, 2024
4b4c52e
[SparseAutoModelForCausalLM Deprecation] Feature change (#881)
horheynm Nov 18, 2024
4c75089
correct typo (#888)
kylesayrs Nov 18, 2024
f47b5f5
Explicit defaults for QuantizationModifier targets (#889)
kylesayrs Nov 19, 2024
372957e
[SparseAutoModelForCausalLM Deprecation] Update examples (#880)
horheynm Nov 20, 2024
cdb6231
Support pack_quantized format for nonuniform mixed-precision (#913)
mgoin Nov 20, 2024
50e881f
actually make the test useful (#920)
dsikka Nov 21, 2024
10dc0fe
revert summon_full_params_context
kylesayrs Nov 21, 2024
5416d6f
Merge remote-tracking branch 'origin' into kylesayrs/fsdp-wrapper-name
kylesayrs Nov 21, 2024
99dbb65
Merge branch 'main' into kylesayrs/fsdp-wrapper-name
kylesayrs Nov 24, 2024
e3739c2
Merge branch 'main' into kylesayrs/fsdp-wrapper-name
kylesayrs Nov 27, 2024
18 changes: 11 additions & 7 deletions src/llmcompressor/utils/fsdp/context.py
@@ -1,10 +1,13 @@
 try:
     from accelerate import Accelerator
+except ImportError:
+    Accelerator = None
+
+try:
     from torch.distributed.fsdp import FullyShardedDataParallel
-    from torch.distributed.fsdp._common_utils import TrainingState
+    from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState
 except ImportError:
     FullyShardedDataParallel = None
-    Accelerator = None
 
 from contextlib import nullcontext
 

@@ -14,8 +17,6 @@
     "fix_fsdp_module_name",
 ]
 
-FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
-
 
 def summon_full_params_context(model, offload_to_cpu: bool = False):
     if FullyShardedDataParallel is not None:
@@ -46,12 +47,15 @@ def main_process_first_context():
 def fix_fsdp_module_name(name: str) -> str:
     """
     Remove FSDP wrapper prefixes from a module name.
-    Accounts for scenario where FSDP_WRAPPER_NAME is
+    Accounts for scenario where FSDP_WRAPPED_MODULE is
     at the end of the name, as well as in the middle.
 
     :param name: name to strip
     :return: stripped name
     """
-    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
-        "." + FSDP_WRAPPER_NAME, ""
+    if FullyShardedDataParallel is None:
+        return name
+
+    return name.replace(FSDP_WRAPPED_MODULE + ".", "").replace(
+        "." + FSDP_WRAPPED_MODULE, ""
     )
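
For reference, a minimal usage sketch (not part of the PR) of how the reworked helper is expected to behave, assuming torch's FSDP imports successfully and that FSDP_WRAPPED_MODULE resolves to the same "_fsdp_wrapped_module" string held by the removed FSDP_WRAPPER_NAME constant; the module names below are illustrative only. If the FSDP import fails, the new guard makes the helper a no-op and the name is returned unchanged.

# Illustrative sketch only; module names are hypothetical.
from llmcompressor.utils.fsdp.context import fix_fsdp_module_name

# A wrapper segment in the middle of the name is stripped ...
assert (
    fix_fsdp_module_name("model._fsdp_wrapped_module.layers.0.self_attn.q_proj")
    == "model.layers.0.self_attn.q_proj"
)

# ... and so is a wrapper segment at the end of the name.
assert fix_fsdp_module_name("model.layers.0._fsdp_wrapped_module") == "model.layers.0"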