[bug] Prevent `to` from being ignored, but try to preserve backwards compatibility #2351

Conversation
Hey, this change seems to have broken my application somehow. I am getting the following error:

```
  File "/home/winston/projects/Examplify/.venv/lib/python3.11/site-packages/sentence_transformers/SentenceTransformer.py", line 1123, in device
    first_tuple = next(gen)
                  ^^^^^^^^^
StopIteration
```

This is frustrating because all I am doing is calling the `encode` function.
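For context on where this error comes from: Sentence Transformers infers the model's device from its parameters, along the lines of the following paraphrase (a sketch matching the traceback above, not the verbatim library code):

```python
import torch
from torch import nn

def get_parameter_device(module: nn.Module) -> torch.device:
    # Walk the (name, parameter) tuples and take the first parameter's device.
    gen = module.named_parameters()
    first_tuple = next(gen)  # raises StopIteration if the module exposes no parameters
    return first_tuple[1].device
```

If the model has no registered parameters at all, `next(gen)` raises the `StopIteration` seen in the report.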
Hello! That is indeed frustrating, my apologies. Is your Sentence Transformer model based on …
Hey. I am using `BAAI/bge-base-en-v1.5`, …
That certainly is. That model should definitely work. This is how I normally use that model:

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

model = SentenceTransformer("BAAI/bge-base-en-v1.5")

queries = ["How to bake a chocolate cake", "Symptoms of the flu"]
queries = [
    "Represent this sentence for searching relevant passages: " + query
    for query in queries
]
passages = [
    "To bake a delicious chocolate cake, you'll need the following ingredients: ...",
    "The flu, or influenza, is an illness caused by influenza viruses. Common ...",
]

embeddings = model.encode(queries + passages)
scores = cos_sim(embeddings[:2], embeddings[2:])
print(scores.tolist())
# [[0.8629187345504761, 0.2619859576225281], [0.33079883456230164, 0.8224020004272461]]
```

Do you think you could try and produce a small snippet that fails for you?
Sure.

```python
from typing import Iterator, TypedDict

from ctranslate2 import Encoder, StorageView
from huggingface_hub import snapshot_download  # added: used below but missing from the original snippet
from numpy import array
from sentence_transformers import SentenceTransformer  # added: used below but missing from the original snippet
from torch import Tensor, as_tensor, float32, int32
from torch.nn import Module, Sequential

from server.types import ComputeTypes


class Features(TypedDict, total=False):
    # Assumed shape of the features dict passed between SentenceTransformer modules;
    # the original snippet referenced `Features` without defining it.
    input_ids: Tensor
    attention_mask: Tensor
    token_embeddings: Tensor


class FlagEmbedding(Module):
    def __init__(
        self,
        transformer: Sequential,
        model_path: str,
        compute_type: ComputeTypes,
    ):
        super().__init__()
        self.compute_type: ComputeTypes = compute_type
        self.encoder: Encoder | None = None
        self.tokenize = transformer.tokenize
        self.model_path = model_path

    def children(self) -> Iterator[Module]:
        # Report no child modules to torch.
        return iter([])

    def forward(self, features: Features) -> Features:
        device = features["input_ids"].device

        # Lazily create the CTranslate2 encoder on the same device as the inputs.
        if not self.encoder:
            self.encoder = Encoder(
                self.model_path,
                device=device.type,
                device_index=device.index or 0,
                compute_type=self.compute_type,
            )

        input_indices = features['input_ids'].to(int32)
        length = features['attention_mask'].sum(1, dtype=int32)

        if device.type == 'cpu':
            input_indices = input_indices.numpy()
            length = length.numpy()

        input_indices = StorageView.from_array(input_indices)
        length = StorageView.from_array(length)
        outputs = self.encoder.forward_batch(input_indices, length)
        last_hidden_state = outputs.last_hidden_state

        if device.type == 'cpu':
            last_hidden_state = array(last_hidden_state)

        features['token_embeddings'] = as_tensor(last_hidden_state, device=device).to(float32)
        return features


class Embedding(SentenceTransformer):
    def __init__(self, *, force_download: bool = False):
        model_name = 'bge-base-en-v1.5'
        super().__init__(f'BAAI/{model_name}')
        model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download)
        self[0] = FlagEmbedding(self[0], model_path, 'auto')


Embedding(force_download=True).encode("hello")
```

I am still using the …
Thanks for the snippet! That helped a bunch. The long and short is this:

```python
model = Embedding(force_download=True)
print(len(list(model.parameters())))
# => 0
```

Sentence Transformers v2.2.2 already used the parameters to infer the device, e.g. this is a snippet from v2.2.2: sentence-transformers/sentence_transformers/SentenceTransformer.py, lines 868 to 884 at f38e91e.
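Why is `parameters()` empty here? In the snippet above, the wrapper holds its weights inside a CTranslate2 `Encoder`, which is not a torch module, so no `nn.Parameter` is ever registered on the torch side. A minimal illustration (hypothetical class, standing in for `FlagEmbedding`):

```python
from torch import nn

class CT2Wrapper(nn.Module):
    """Hypothetical stand-in for FlagEmbedding: the actual weights live in a
    non-torch engine object, so torch sees no parameters at all."""
    def __init__(self):
        super().__init__()
        self.engine = object()  # e.g. a ctranslate2.Encoder; not an nn.Module

print(len(list(CT2Wrapper().parameters())))  # => 0
```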
In other words, running …. Perhaps the best solution is to store the initial device & override the `device` property, e.g.:

```python
from typing import Optional

import torch
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer


class Embedding(SentenceTransformer):  # FlagEmbedding as defined in the snippet above
    def __init__(self, *, force_download: bool = False, device: Optional[str] = None):
        model_name = 'bge-base-en-v1.5'
        super().__init__(f'BAAI/{model_name}', device=device)
        self._device = self.device
        model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download)
        self[0] = FlagEmbedding(self[0], model_path, 'auto')

    @property
    def device(self) -> torch.device:
        return self._device
```
I see. So my code was depending on a broken … `self._device = super().device` …
Indeed! And I think both should work, …
Just for fun, I tried both and only `self._device = super().device` worked.
Oh, of course - I totally forgot that I had overridden the `device` property.
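To illustrate the pitfall being discussed (a constructed example, not code from the thread): once `device` is overridden as a property backed by `_device`, reading `self.device` inside `__init__` before `_device` is assigned fails, while `super().device` still uses the base class's parameter-based lookup:

```python
import torch
from torch import nn

class Base(nn.Module):
    @property
    def device(self) -> torch.device:
        # Parameter-based lookup, like SentenceTransformer's device property.
        return next(self.parameters()).device

class Child(Base):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # ensure there is at least one parameter
        # self._device = self.device   # Child.device reads self._device, not set yet -> AttributeError
        self._device = super().device  # Base.device still infers from parameters, so this works

    @property
    def device(self) -> torch.device:
        return self._device

print(Child().device)  # e.g. device(type='cpu')
```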
Hello!

Pull Request overview

- Replace `._target_device` with the already existing `.device`
- Add a `._target_device` property pointing to `device` to preserve backwards compatibility.

Details
The Bug

This was originally reported for SetFit, and then hackishly fixed here. In short, if you load a model on device X, then move it to device Y using `model.to(Y)`, and then perform any training or inference, the model and the data will be moved back to the original device X. This is unexpected. See for example:

…

Results in:

…
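The example script and its output above were not captured in this page; a minimal sketch of the kind of script that demonstrates the reported behaviour (model name and devices chosen for illustration) would be:

```python
from sentence_transformers import SentenceTransformer

# Load on one device, then explicitly move to another.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
model.to("cuda")

# Before the fix, encode() would move the model and data back to the
# original device ("cpu" here), silently ignoring the to("cuda") call.
embeddings = model.encode(["An example sentence"])
print(next(model.parameters()).device)
```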
The fix

After applying the fix, the above script returns:

…

In essence, there was never a need to track a `_target_device`. All `torch` `Module` instances already have a `device`, and we can just move the model to the correct `device` in `__init__`, and then move all training/inference data to `model.device`. Whenever the model is then moved via `to(...)`, `cuda()`, `cpu()`, etc., this automatically works.
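A sketch of that pattern (an illustrative helper, not the PR's actual diff; Sentence Transformers ships a similar `batch_to_device` utility in `sentence_transformers.util`):

```python
import torch
from torch import nn

def batch_to_model_device(model: nn.Module, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Resolve the device the model currently lives on
    # (reflects any earlier to()/cuda()/cpu() call)...
    device = next(model.parameters()).device
    # ...and move every tensor in the batch there before the forward pass.
    return {name: tensor.to(device) for name, tensor in batch.items()}
```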
Backwards compatibility

Although `_target_device` was a private property, it was somewhat commonly used, e.g. by SetFit, because of this bug. Applying the fix itself (i.e. replacing all `self._target_device` with `self.device`) would cause those third parties to break. As a result, I turned `_target_device` into a property with a setter pointing to `self.to` and a getter pointing to `self.device`. The getter has a `logger.warn`. In a later stage we can fully deprecate this, e.g. for a v3.0.0.

I added a test to ensure that this all works as expected now, including the BC preservation.
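A sketch of the kind of backwards-compatible shim described above (reconstructed from the description, not copied from the actual diff; the class name is hypothetical):

```python
import logging

import torch
from torch import nn

logger = logging.getLogger(__name__)

class SentenceTransformerLike(nn.Module):
    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def _target_device(self) -> torch.device:
        # Getter kept for third parties (e.g. SetFit) that still read _target_device.
        logger.warning("`_target_device` is deprecated, use `device` instead.")
        return self.device

    @_target_device.setter
    def _target_device(self, device) -> None:
        # Setter kept for third parties that still assign _target_device; delegate to to().
        self.to(device)
```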