Skip to content

Commit

Permalink
make sure to disable gradients for integer tensor (#32943)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Nov 18, 2024
1 parent 1c471fc commit 36759f3
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,10 @@ def _load_state_dict_into_meta_model(
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
value = type(value)(value.data.to(param_to), **value.__dict__)
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value)
# TODO: consider removing used param_parts from state_dict before return

Expand Down

0 comments on commit 36759f3

Please sign in to comment.