Fix cross encoder device issue #3104

Merged
tomaarsen merged 6 commits into UKPLab:master on Dec 2, 2024

Conversation

susnato
Contributor

@susnato susnato commented Nov 29, 2024

Fixes #3078

  • Fixed the previous issue: the model is now pushed to the specified device when the object is created.
  • Added a new method, to(), which pushes the model directly to the specified device, removing the need to use model.model.to.
  • Added 2 tests to make sure it works as expected.
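
The behaviour described in the list above can be sketched roughly as follows. This is a toy illustration, not the code from this PR: ToyInnerModel and ToyCrossEncoder are hypothetical stand-ins, and devices are plain strings so the snippet runs without torch or a GPU.

```python
from __future__ import annotations


class ToyInnerModel:
    """Hypothetical stand-in for the wrapped model; it only tracks a device string."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "ToyInnerModel":
        self.device = device
        return self


class ToyCrossEncoder:
    """Hypothetical wrapper mimicking the behaviour described in the PR summary."""

    def __init__(self, device: str | None = None) -> None:
        self.model = ToyInnerModel()
        # Move the inner model as soon as the wrapper is created,
        # instead of waiting for the first predict() call.
        self.model.to(device or "cpu")

    def to(self, device: str) -> None:
        # Delegate to the inner model so callers no longer need model.model.to(...).
        self.model.to(device)


encoder = ToyCrossEncoder(device="cuda")
print(encoder.model.device)  # the device requested at construction time

encoder.to("cpu")
print(encoder.model.device)  # the device after calling the wrapper's own to()
```

With the real class, the equivalent check would presumably compare the wrapped model's device right after construction, which is what the linked issue reported as not working.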

@susnato
Contributor Author

susnato commented Nov 29, 2024

cc: @tomaarsen

@susnato
Contributor Author

susnato commented Nov 29, 2024

Tested on a CUDA-enabled device, and the tests pass:

=================================================================================================================== test session starts ====================================================================================================================
platform linux -- Python 3.10.14, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/susnato/workspace/sentence-transformers
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-4.4.0, httpbin-2.0.0
collected 541 items / 332 deselected / 1 skipped / 209 selected                                                                                                                                                                                            

tests/evaluation/test_binary_classification_evaluator.py ..                                                                                                                                                                                          [  0%]
tests/evaluation/test_information_retrieval_evaluator.py ..                                                                                                                                                                                          [  1%]
tests/evaluation/test_label_accuracy_evaluator.py .                                                                                                                                                                                                  [  2%]
tests/evaluation/test_paraphrase_mining_evaluator.py .                                                                                                                                                                                               [  2%]
tests/evaluation/test_triplet_evaluator.py .                                                                                                                                                                                                         [  3%]
tests/models/test_static_embedding.py .....ss                                                                                                                                                                                                        [  6%]
tests/samplers/test_group_by_label_batch_sampler.py ...                                                                                                                                                                                              [  8%]
tests/samplers/test_no_duplicates_batch_sampler.py ...                                                                                                                                                                                               [  9%]
tests/samplers/test_round_robin_batch_sampler.py ..                                                                                                                                                                                                  [ 10%]
tests/test_cmnrl.py .....                                                                                                                                                                                                                            [ 12%]
tests/test_compute_embeddings.py ....                                                                                                                                                                                                                [ 14%]
tests/test_cross_encoder.py .............                                                                                                                                                                                                            [ 21%]
tests/test_image_embeddings.py .                                                                                                                                                                                                                     [ 21%]
tests/test_model_card.py .....                                                                                                                                                                                                                       [ 23%]
tests/test_model_card_data.py .....                                                                                                                                                                                                                  [ 26%]
tests/test_multi_process.py ssss                                                                                                                                                                                                                     [ 28%]
tests/test_pretrained_stsb.py .......................                                                                                                                                                                                                [ 39%]
tests/test_sentence_transformer.py ...........................................................................................                                                                                                                       [ 82%]
tests/test_train_stsb.py ..                                                                                                                                                                                                                          [ 83%]
tests/test_trainer.py .......................                                                                                                                                                                                                        [ 94%]
tests/test_util.py ...........                                                                                                                                                                                                                       [100%]

===================================================================================================================== warnings summary =====================================================================================================================
../../.local/lib/python3.10/site-packages/accelerate/utils/other.py:220
  /home/susnato/.local/lib/python3.10/site-packages/accelerate/utils/other.py:220: DeprecationWarning: numpy.core is deprecated and has been renamed to numpy._core. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.multiarray.
    np.core.multiarray._reconstruct,

tests/test_sentence_transformer.py::test_load_checkpoint_with_peft_and_lora
  /home/susnato/.local/lib/python3.10/site-packages/transformers/integrations/peft.py:418: FutureWarning: The `active_adapter` method is deprecated and will be removed in a future version.
    warnings.warn(

tests/test_sentence_transformer.py::test_similarity_score_save
  /home/susnato/workspace/sentence-transformers/tests/test_sentence_transformer.py:630: DeprecationWarning: __array_wrap__ must accept context and return_scalar arguments (positionally) in the future. (Deprecated NumPy 2.0)
    assert np.not_equal(cosine_scores, dot_scores).all()

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================== 203 passed, 7 skipped, 332 deselected, 3 warnings in 262.22s (0:04:22) ==========================================================================================

@tomaarsen
Collaborator

tomaarsen commented Dec 2, 2024

Hello!

Well done here @susnato! I made some tiny adjustments, primarily surrounding the removal of _target_device. Despite it being a "private attribute", I imagine that some users set model._target_device to move the model between devices, so I've added a property + setter to ensure that this is still possible (although it's not recommended).
We have something similar for the SentenceTransformer class.

What do you think @susnato? Oh, I'm also going to try the Copilot reviewer on this.

  • Tom Aarsen
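
A minimal sketch of the property + setter approach mentioned in the comment above, assuming the setter simply forwards to to(). The class names are hypothetical and devices are plain strings; the merged implementation may differ in its details.

```python
from __future__ import annotations


class ToyInnerModel:
    """Hypothetical stand-in for the wrapped model."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "ToyInnerModel":
        self.device = device
        return self


class ToyCrossEncoder:
    """Hypothetical wrapper that keeps a legacy private attribute working."""

    def __init__(self) -> None:
        self.model = ToyInnerModel()

    def to(self, device: str) -> None:
        self.model.to(device)

    @property
    def _target_device(self) -> str:
        # Read-through: report wherever the inner model currently lives.
        return self.model.device

    @_target_device.setter
    def _target_device(self, device: str) -> None:
        # Older code that assigns to the attribute still moves the model.
        self.to(device)


encoder = ToyCrossEncoder()
encoder._target_device = "cuda"  # legacy style, kept for backward compatibility
print(encoder.model.device)
```

The point of this design is that the attribute stops being separate state: reading it reflects the model's actual device, and writing it triggers a real move, so the two cannot drift apart.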

@tomaarsen tomaarsen requested a review from Copilot December 2, 2024 10:06

Copilot reviewed 2 out of 2 changed files in this pull request and generated no suggestions.

Comments skipped due to low confidence (1)

sentence_transformers/cross_encoder/CrossEncoder.py:605

  • The to method should explicitly return None to match its signature.
def to(self, device: int | str | torch.device | None = None) -> None:
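
For context on the skipped suggestion above: a Python function that finishes without a return statement already returns None, so an explicit return is a style choice rather than a requirement of the -> None annotation. A small sketch, using a hypothetical move helper that is not part of the library:

```python
def move(state: dict, device: str) -> None:
    # Mutates its argument and falls off the end without a return statement.
    state["device"] = device


state = {"device": "cpu"}
result = move(state, "cuda")
print(result is None)  # prints True
```
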
@susnato
Contributor Author

susnato commented Dec 2, 2024

Hi @tomaarsen, it seems good to me to keep _target_device.

I'm also going to try the Copilot reviewer on this.

lol, I didn't know this even existed!

@tomaarsen
Collaborator

lol, I didn't know this even existed!

It's very new! I only got access a few days ago, and this is the second time that I've used it. Seems like it didn't help much this time, haha.

I'll merge this now as I think we're good to go. Thanks for the help!

  • Tom Aarsen

@tomaarsen tomaarsen merged commit a49ffc5 into UKPLab:master Dec 2, 2024
9 checks passed
Development

Successfully merging this pull request may close these issues.

CrossEncoder is not pushed to cuda until predict is called, even if cuda is specified as device.
2 participants