Commit

FIX [quantization / ESM] Fix ESM 8bit / 4bit with bitsandbytes (#29329)

* fix ESM 8bit

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* fixup

---------

Co-authored-by: Arthur <[email protected]>
younesbelkada and ArthurZucker authored Mar 1, 2024
1 parent 2858d6c commit 50db7ca
Showing 4 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/esm/modeling_esm.py
@@ -377,7 +377,7 @@ def forward(
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
 
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
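
The one-line cast above is the whole fix on the modeling side: under bitsandbytes 8-bit/4-bit inference the softmax output can remain float32 while value_layer is half precision, and torch.matmul does not promote mixed dtypes. A minimal standalone sketch of the failure mode (shapes are illustrative, not ESM's):

import torch

# attention probs stay float32 after softmax; values run in half precision
attention_probs = torch.rand(1, 2, 4, 4, dtype=torch.float32)
value_layer = torch.rand(1, 2, 4, 8, dtype=torch.float16)

# torch.matmul(attention_probs, value_layer)  # raises a dtype-mismatch RuntimeError
context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)
print(context_layer.dtype)  # torch.float16
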
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -121,7 +121,7 @@ def check_quantized_param(
         import bitsandbytes as bnb
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
             # Add here check for loaded components' dtypes once serialization is implemented
             return True
         elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
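
The lookup change above is defensive: module._parameters is a dict-like mapping, so indexing a name that was never registered raises KeyError, whereas .get(tensor_name, None) returns None and the isinstance check simply evaluates to False. A small sketch with a plain nn.Linear standing in for a bitsandbytes module:

import torch.nn as nn

module = nn.Linear(4, 4)

print(module._parameters.get("weight"))        # the registered weight Parameter
print(module._parameters.get("missing", None)) # None; isinstance(None, ...) is False
# module._parameters["missing"]                # would raise KeyError: 'missing'
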
2 changes: 1 addition & 1 deletion src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -139,7 +139,7 @@ def check_quantized_param(
         import bitsandbytes as bnb
 
         module, tensor_name = get_module_from_name(model, param_name)
-        if isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+        if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
             if self.pre_quantized:
                 if param_name.replace("weight", "SCB") not in state_dict.keys():
                     raise ValueError("Missing quantization component `SCB`")
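
The 8-bit path additionally relies on a naming convention: for pre-quantized checkpoints, each quantized weight's scale tensor is expected in the state dict under the same key with "weight" replaced by "SCB". A sketch of that convention (the parameter name below is illustrative):

param_name = "esm.encoder.layer.0.attention.self.value.weight"
scb_name = param_name.replace("weight", "SCB")
print(scb_name)  # esm.encoder.layer.0.attention.self.value.SCB

state_dict = {param_name: ..., scb_name: ...}
assert scb_name in state_dict  # otherwise: ValueError("Missing quantization component `SCB`")
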
20 changes: 17 additions & 3 deletions tests/models/esm/test_modeling_esm.py
@@ -18,7 +18,7 @@
 import unittest
 
 from transformers import EsmConfig, is_torch_available
-from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device
+from transformers.testing_utils import TestCasePlus, require_bitsandbytes, require_torch, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -303,9 +303,9 @@ def test_resize_tokens_embeddings(self):
     pass
 
 
+@slow
 @require_torch
 class EsmModelIntegrationTest(TestCasePlus):
-    @slow
     def test_inference_masked_lm(self):
         with torch.no_grad():
             model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -323,7 +323,6 @@ def test_inference_masked_lm(self):
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
 
-    @slow
     def test_inference_no_head(self):
         with torch.no_grad():
             model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
@@ -336,3 +335,18 @@
             [[[0.1444, 0.5413, 0.3248], [0.3034, 0.0053, 0.3108], [0.3228, -0.2499, 0.3415]]]
         )
         self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @require_bitsandbytes
+    def test_inference_bitsandbytes(self):
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_8bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        with torch.no_grad():
+            _ = model(input_ids)[0]
+
+        model = EsmForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D", load_in_4bit=True)
+
+        input_ids = torch.tensor([[0, 6, 4, 13, 5, 4, 16, 12, 11, 7, 2]])
+        # Just test if inference works
+        _ = model(input_ids)[0]
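
The new test exercises exactly the user-facing path this commit repairs. A rough equivalent outside the test harness, using the explicit BitsAndBytesConfig API (a sketch: assumes bitsandbytes and a CUDA device are available; the protein sequence is illustrative):

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, EsmForMaskedLM

# Equivalent to passing load_in_4bit=True directly to from_pretrained
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = EsmForMaskedLM.from_pretrained(
    "facebook/esm2_t36_3B_UR50D", quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
inputs = tokenizer("MKTAYIAKQR", return_tensors="pt").to(model.device)

with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # (1, sequence_length, vocab_size)
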
