[bug] Prevent `to` from being ignored #2351

Merged: 4 commits merged into UKPLab:master on Dec 12, 2023
Conversation

tomaarsen
Collaborator

Hello!

Pull Request overview

  • Replace ._target_device with already existing .device.
  • Add _target_device property pointing to device to preserve backwards compatibility.
  • Add test to show that this works now.

Details

The Bug

This was originally reported for SetFit, and then hackishly fixed there. In short, if you load a model on device X, then move it to device Y using model.to(Y), and then perform any training or inference, the model and the data will be moved back to the original device X.

This is unexpected. See for example:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")
model.cuda()
print(model.device)
model.encode("Hello!")
print(model.device)

Results in:

cuda:0
cpu

The fix

After applying the fix, the above script returns:

cuda:0
cuda:0

In essence, there was never a need to track a _target_device. All torch Module instances already have a device (derivable from their parameters), and we can simply move the model to the correct device in __init__, and then move all training/inference data to model.device. Whenever the model is then moved via to(...), cuda(), cpu(), etc., this automatically works.
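
As an illustration, here is a minimal sketch (not the library code) of why deriving device from the module's parameters keeps it in sync with wherever the model was last moved:

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def device(self) -> torch.device:
        # The parameters travel with the module on .to()/.cuda()/.cpu(),
        # so their device is the module's device; no separate state to track.
        return next(self.parameters()).device

model = TinyModel()
print(model.device)  # cpu
if torch.cuda.is_available():
    model.cuda()
    print(model.device)  # cuda:0, automatically in sync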

Backwards compatibility

Although _target_device was a private property, it was somewhat commonly used, e.g. by SetFit, because of this bug. Applying the fix itself (i.e. replacing all self._target_device with self.device) would cause those third parties to break. As a result, I turned _target_device into a property with a setter pointing to self.to and a getter pointing to self.device. The getter has a logger.warn. In a later stage we can fully deprecate this, e.g. for a v3.0.0.
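
A minimal sketch of that shim (not the exact PR diff; the class name and warning text are illustrative):

import logging

import torch
from torch import nn

logger = logging.getLogger(__name__)

class SentenceTransformerSketch(nn.Sequential):
    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def _target_device(self) -> torch.device:
        # Old readers keep working, but are nudged towards .device:
        logger.warning("_target_device is deprecated, use .device instead")
        return self.device

    @_target_device.setter
    def _target_device(self, device) -> None:
        # Old writers (e.g. `model._target_device = "cpu"`) now actually move the model:
        self.to(device)

model = SentenceTransformerSketch(nn.Linear(2, 2))
model._target_device = "cpu"  # routed through .to("cpu")
print(model._target_device)   # warns, then returns device(type='cpu')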

I added a test to ensure that this all works as expected now, including the BC preservation.

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 135753b into UKPLab:master Dec 12, 2023
8 checks passed
@tomaarsen tomaarsen deleted the hotfix/to branch December 12, 2023 12:48
@winstxnhdw

winstxnhdw commented Jan 30, 2024

Hey, this change seems to have broken my application somehow. I am getting the following error.

  File "/home/winston/projects/Examplify/.venv/lib/python3.11/site-packages/sentence_transformers/SentenceTransformer.py", line 1123, in device
    first_tuple = next(gen)
                  ^^^^^^^^^
StopIteration

This is frustrating because all I am doing is calling the encode method.

@tomaarsen
Collaborator Author

Hello!

That is indeed frustrating, my apologies. Is your Sentence Transformer model based on torch? It seems like it cannot find any tensors in the entire model. Feel free to respond here or open a new issue & I'll respond to you there.

  • Tom Aarsen

@winstxnhdw

Hey. I am using bge-base-en-v1.5, and I am pretty sure it's based on torch.

@tomaarsen
Collaborator Author

That certainly is. That model should definitely work. This is how I normally use that model:

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("BAAI/bge-base-en-v1.5")

queries = ["How to bake a chocolate cake", "Symptoms of the flu"]
queries = [
    "Represent this sentence for searching relevant passages: " + query
    for query in queries
]
passages = [
    "To bake a delicious chocolate cake, you'll need the following ingredients: ...",
    "The flu, or influenza, is an illness caused by influenza viruses. Common ...",
]

embeddings = model.encode(queries + passages)
scores = cos_sim(embeddings[:2], embeddings[2:])
print(scores.tolist())
# [[0.8629187345504761, 0.2619859576225281], [0.33079883456230164, 0.8224020004272461]]

Do you think you could try and produce a small snippet that fails for you?

  • Tom Aarsen

@winstxnhdw

winstxnhdw commented Jan 30, 2024

Sure.

from typing import Iterator, TypedDict

from ctranslate2 import Encoder, StorageView
from huggingface_hub import snapshot_download
from numpy import array
from sentence_transformers import SentenceTransformer
from torch import Tensor, as_tensor, float32, int32
from torch.nn import Module, Sequential

from server.types import ComputeTypes

class Features(TypedDict):
    # Assumed shape of the feature dict that SentenceTransformer passes between modules.
    input_ids: Tensor
    attention_mask: Tensor
    token_embeddings: Tensor

class FlagEmbedding(Module):
    def __init__(
        self,
        transformer: Sequential,
        model_path: str,
        compute_type: ComputeTypes
    ):
        super().__init__()

        self.compute_type: ComputeTypes = compute_type
        self.encoder: Encoder | None = None
        self.tokenize = transformer.tokenize
        self.model_path = model_path

    def children(self) -> Iterator[Module]:
        # Report no child modules: this wrapper holds no torch parameters or submodules.
        return iter([])

    def forward(self, features: Features) -> Features:
        device = features["input_ids"].device

        if not self.encoder:
            # Lazily construct the CTranslate2 encoder on the device of the incoming batch.
            self.encoder = Encoder(
                self.model_path,
                device=device.type,
                device_index=device.index or 0,
                compute_type=self.compute_type,
            )

        input_indices = features['input_ids'].to(int32)
        length = features['attention_mask'].sum(1, dtype=int32)

        if device.type == 'cpu':
            input_indices = input_indices.numpy()
            length = length.numpy()

        input_indices = StorageView.from_array(input_indices)
        length = StorageView.from_array(length)

        outputs = self.encoder.forward_batch(input_indices, length)
        last_hidden_state = outputs.last_hidden_state

        if device.type == 'cpu':
            last_hidden_state = array(last_hidden_state)

        features['token_embeddings'] = as_tensor(last_hidden_state, device=device).to(float32)

        return features

class Embedding(SentenceTransformer):
    def __init__(self, *, force_download: bool = False):
        model_name = 'bge-base-en-v1.5'

        super().__init__(f'BAAI/{model_name}')

        model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download)
        self[0] = FlagEmbedding(self[0], model_path, 'auto')

Embedding(force_download=True).encode("hello")

I am still using the torch model for bge-base-en-v1.5; as you can see, I am passing it via super().__init__(f'BAAI/{model_name}'). I am just overriding one of its modules, which should not have been affected by what you are doing in this PR.

@tomaarsen
Collaborator Author

Thanks for the snippet! That helped a bunch. The long and short is this:

...
model = Embedding(force_download=True)
print(len(list(model.parameters())))
# => 0

Sentence Transformers v2.2.2 already used the parameters to infer the device, e.g. this is a snippet from v2.2.2:

@property
def device(self) -> device:
    """
    Get torch.device from module, assuming that the whole module has one device.
    """
    try:
        return next(self.parameters()).device
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device

In other words, running model.device on your code would already cause a failure. The main difference in this PR is that the model now actually listens to its own device rather than only the initially defined _target_device. I think the error is actually fairly reasonable: it shouldn't be possible for the Embedding class to infer the device based on its modules once the Transformer instance has been removed.
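
A minimal standalone reproduction of that failure mode (hypothetical class, for illustration):

from torch import nn

class ParameterlessModule(nn.Module):
    # Mirrors FlagEmbedding: no nn.Parameter and no tensor attributes.
    def forward(self, x):
        return x

model = nn.Sequential(ParameterlessModule())
print(len(list(model.parameters())))  # => 0

try:
    next(model.parameters())
except StopIteration:
    print("No parameters to infer a device from")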

Perhaps the best solution is to store the initial device & override the device property of Embedding?

e.g.

class Embedding(SentenceTransformer):
    def __init__(self, *, force_download: bool = False, device: Optional[str] = None):

        model_name = 'bge-base-en-v1.5'

        super().__init__(f'BAAI/{model_name}', device=device)
        self._device = self.device

        model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download)
        self[0] = FlagEmbedding(self[0], model_path, 'auto')

    @property
    def device(self) -> torch.device:
        return self._device

  • Tom Aarsen

@winstxnhdw

winstxnhdw commented Jan 30, 2024

I see. So my code was depending on a broken device field, and the only reason it worked was that _target_device was never updated. Thanks for the help. I believe you meant to write this instead:

self._device = super().device

@tomaarsen
Collaborator Author

Indeed! And I think both should work: self.device should return the device as expected, since at that point the first module hasn't been overridden yet.

  • Tom Aarsen

@winstxnhdw

winstxnhdw commented Jan 31, 2024

Just for fun, I tried both, and only super() works. I think this is because of how Python's MRO works.

@tomaarsen
Collaborator Author

Oh, of course - I totally forgot that I had overridden the device property. Oops!
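
For readers following along, a minimal illustration (hypothetical classes) of why self.device fails inside __init__ while super().device works once the property is overridden:

import torch
from torch import nn

class Base(nn.Module):
    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

class Sub(Base):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # gives the parent getter a parameter to inspect
        # self.device would resolve to Sub's overriding property below, which
        # reads self._device before it has been assigned -> AttributeError.
        # super().device resolves the property on Base instead.
        self._device = super().device

    @property
    def device(self) -> torch.device:
        return self._device

print(Sub().device)  # cpu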
