From 135753bccb9dc305efd00d3d3d3dcb31f86432c0 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:48:25 +0100 Subject: [PATCH] [`bug`] Prevent `to` from being ignored (#2351) * Update _target_device on `to` call + test * Fully replace `_target_device` with `device` But try to preserve backwards compatibility * Update test phrasing --- sentence_transformers/SentenceTransformer.py | 23 +++++++++++++++----- tests/test_sentence_transformer.py | 19 ++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 361e3f434..637babf1b 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -122,7 +122,7 @@ def __init__(self, model_name_or_path: Optional[str] = None, device = get_device_name() logger.info("Use pytorch device_name: {}".format(device)) - self._target_device = torch.device(device) + self.to(device) def encode(self, sentences: Union[str, List[str]], batch_size: int = 32, @@ -167,7 +167,7 @@ def encode(self, sentences: Union[str, List[str]], input_was_string = True if device is None: - device = self._target_device + device = self.device self.to(device) @@ -658,7 +658,7 @@ def fit(self, from torch.cuda.amp import autocast scaler = torch.cuda.amp.GradScaler() - self.to(self._target_device) + self.to(self.device) dataloaders = [dataloader for dataloader, _ in train_objectives] @@ -668,7 +668,7 @@ def fit(self, loss_models = [loss for _, loss in train_objectives] for loss_model in loss_models: - loss_model.to(self._target_device) + loss_model.to(self.device) self.best_score = -9999999 @@ -724,8 +724,8 @@ def fit(self, data = next(data_iterator) features, labels = data - labels = labels.to(self._target_device) - features = list(map(lambda batch: batch_to_device(batch, self._target_device), features)) + labels = labels.to(self.device) + features = list(map(lambda batch: batch_to_device(batch, self.device), features)) if use_amp: with autocast(): @@ -949,3 +949,14 @@ def max_seq_length(self, value): Property to set the maximal input sequence length for the model. Longer inputs will be truncated. """ self._first_module().max_seq_length = value + + @property + def _target_device(self) -> torch.device: + logger.warning( + "`SentenceTransformer._target_device` has been removed, please use `SentenceTransformer.device` instead.", + ) + return self.device + + @_target_device.setter + def _target_device(self, device: Optional[Union[int, str, torch.device]] = None) -> None: + self.to(device) \ No newline at end of file diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index c8913fcb7..e0d3acf7a 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -2,6 +2,7 @@ Tests general behaviour of the SentenceTransformer class """ + from pathlib import Path import tempfile @@ -45,3 +46,21 @@ def test_load_with_safetensors(self): torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)), msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings", ) + + @unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.") + def test_to(self): + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu") + + test_device = torch.device("cuda") + self.assertEqual(model.device.type, "cpu") + self.assertEqual(test_device.type, "cuda") + + model.to(test_device) + self.assertEqual(model.device.type, "cuda", msg="The model device should have updated") + + model.encode("Test sentence") + self.assertEqual(model.device.type, "cuda", msg="Encoding shouldn't change the device") + + self.assertEqual(model._target_device, model.device, msg="Prevent backwards compatibility failure for _target_device") + model._target_device = "cpu" + self.assertEqual(model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.") \ No newline at end of file