Skip to content

Commit

Permalink
[enh] Add Support for multiple adapters on Transformers-based models (
Browse files Browse the repository at this point in the history
#3046)

* Support adapters on SentenceTransformer

* Add transformer model check + add testing

* Update ValueError text slightly; use self[0]; slight format changes

* Upload stsb-bert-tiny-lora and extend tests

---------

Co-authored-by: Carles Onielfa <[email protected]>
  • Loading branch information
tomaarsen and carlesonielfa authored Nov 8, 2024
1 parent 6baee57 commit 7ede83b
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 4 deletions.
3 changes: 2 additions & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .evaluation import SentenceEvaluator
from .fit_mixin import FitMixin
from .models import Normalize, Pooling, Transformer
from .peft_mixin import PeftAdapterMixin
from .quantization import quantize_embeddings
from .util import (
batch_to_device,
Expand All @@ -52,7 +53,7 @@
logger = logging.getLogger(__name__)


class SentenceTransformer(nn.Sequential, FitMixin):
class SentenceTransformer(nn.Sequential, FitMixin, PeftAdapterMixin):
"""
Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.
Expand Down
143 changes: 143 additions & 0 deletions sentence_transformers/peft_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

from functools import wraps

from transformers.integrations.peft import PeftAdapterMixin as PeftAdapterMixinTransformers

from .models import Transformer


def peft_wrapper(func):
"""Wrapper to call the method on the auto_model with a check for PEFT compatibility."""

@wraps(func)
def wrapper(self, *args, **kwargs):
self.check_peft_compatible_model()
method = getattr(self[0].auto_model, func.__name__)
return method(*args, **kwargs)

return wrapper


class PeftAdapterMixin:
"""
Wrapper Mixin that adds the functionality to easily load and use adapters on the model. For
more details about adapters check out the documentation of PEFT
library: https://huggingface.co/docs/peft/index
Currently supported PEFT methods follow those supported by transformers library,
you can find more information on:
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin
"""

def has_peft_compatible_model(self) -> bool:
return isinstance(self[0], Transformer) and isinstance(self[0].auto_model, PeftAdapterMixinTransformers)

def check_peft_compatible_model(self) -> None:
if not self.has_peft_compatible_model():
raise ValueError(
"PEFT methods are only supported for Sentence Transformer models that use the Transformer module."
)

@peft_wrapper
def load_adapter(self, *args, **kwargs) -> None:
"""
Load adapter weights from file or remote Hub folder." If you are not familiar with adapters and PEFT methods, we
invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft
Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT.
Args:
*args:
Positional arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter
**kwargs:
Keyword arguments to pass to the underlying AutoModel `load_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter
"""
... # Implementation handled by the wrapper

@peft_wrapper
def add_adapter(self, *args, **kwargs) -> None:
"""
Adds a fresh new adapter to the current model for training purposes. If no adapter name is passed, a default
name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
default adapter name).
Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT.
Args:
*args:
Positional arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter
**kwargs:
Keyword arguments to pass to the underlying AutoModel `add_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter
"""
... # Implementation handled by the wrapper

@peft_wrapper
def set_adapter(self, *args, **kwargs) -> None:
"""
Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
Args:
*args:
Positional arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter
**kwargs:
Keyword arguments to pass to the underlying AutoModel `set_adapter` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter
"""
... # Implementation handled by the wrapper

@peft_wrapper
def disable_adapters(self) -> None:
"""
Disable all adapters that are attached to the model. This leads to inferring with the base model only.
"""
... # Implementation handled by the wrapper

@peft_wrapper
def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model will use `self.active_adapter()`
"""
... # Implementation handled by the wrapper

@peft_wrapper
def active_adapters(self) -> list[str]:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
for inference) returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
a single string.
"""
... # Implementation handled by the wrapper

@peft_wrapper
def active_adapter(self) -> str: ... # Implementation handled by the wrapper

@peft_wrapper
def get_adapter_state_dict(self, *args, **kwargs) -> dict:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter.
If no adapter_name is passed, the active adapter is used.
Args:
*args:
Positional arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict
**kwargs:
Keyword arguments to pass to the underlying AutoModel `get_adapter_state_dict` function. More information can be found in the transformers documentation
https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict
"""
... # Implementation handled by the wrapper
72 changes: 69 additions & 3 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import pytest
import torch
from huggingface_hub import CommitInfo, HfApi, RepoUrl
from peft import PeftModel
from torch import nn
from transformers.utils import is_peft_available

from sentence_transformers import SentenceTransformer, util
from sentence_transformers.models import (
Expand Down Expand Up @@ -417,8 +417,9 @@ def transformers_init(*args, **kwargs):
assert transformer_kwargs["model_args"]["attn_implementation"] == "eager"


@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test PEFT support.")
def test_load_checkpoint_with_peft_and_lora() -> None:
from peft import LoraConfig, TaskType
from peft import LoraConfig, PeftModel, TaskType

peft_config = LoraConfig(
target_modules=["query", "key", "value"],
Expand All @@ -431,7 +432,7 @@ def test_load_checkpoint_with_peft_and_lora() -> None:

with SafeTemporaryDirectory() as tmp_folder:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
model._modules["0"].auto_model.add_adapter(peft_config)
model.add_adapter(peft_config)
model.save(tmp_folder)
expecteds = model.encode(["Hello there!", "How are you?"], convert_to_tensor=True)

Expand Down Expand Up @@ -715,3 +716,68 @@ def test_empty_encode(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model
embeddings = model.encode([])
assert embeddings.shape == (0,)


@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test adapter methods.")
def test_multiple_adapters() -> None:
text = "Hello, World!"
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
vec_initial = model.encode(text)
from peft import LoraConfig, TaskType, get_model_status

# Adding a fresh adapter
peft_config = LoraConfig(
target_modules=["query", "key", "value"],
task_type=TaskType.FEATURE_EXTRACTION,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
init_lora_weights=False, # Random initialization to test the adapter
)
model.add_adapter(peft_config)

# Load an adapter from the hub
model.load_adapter("sentence-transformers-testing/stsb-bert-tiny-lora", "hub_adapter")

# Adding another one with a different name
peft_config = LoraConfig(
target_modules=["value"],
task_type=TaskType.FEATURE_EXTRACTION,
inference_mode=False,
r=2,
lora_alpha=16,
lora_dropout=0.1,
init_lora_weights=False, # Random initialization to test the adapter
)
model.add_adapter(peft_config, "my_adapter")

# Check that peft recognizes the adapters while we compute vectors for later comparison
status = get_model_status(model)
assert status.available_adapters == ["default", "hub_adapter", "my_adapter"]
assert status.enabled
assert status.active_adapters == ["my_adapter"]
assert status.active_adapters == model.active_adapters()
vec_my_adapter = model.encode(text)

model.set_adapter("default")
status = get_model_status(model)
assert status.active_adapters == ["default"]
vec_default_adapter = model.encode(text)

model.disable_adapters()
status = get_model_status(model)
assert not status.enabled
vec_no_adapter = model.encode(text)

# Check that each vector is different
assert not np.allclose(vec_my_adapter, vec_default_adapter)
assert not np.allclose(vec_my_adapter, vec_no_adapter)
assert not np.allclose(vec_default_adapter, vec_no_adapter)
# Check that the vectors from the original model match
assert np.allclose(vec_initial, vec_no_adapter)

# Check that for non Transformer-based models we have an error
model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
with pytest.raises(ValueError, match="PEFT methods are only supported"):
model.add_adapter(peft_config)

0 comments on commit 7ede83b

Please sign in to comment.