-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: replace unittest with pytest (#2407)
* test: replace unittest with pytest * test: add back hf hub workaround of pull request * style: fix formatting in tests * Re-add missing tests * Fix wrong slow mark * Remove dead code (inaccessible return) * Move fixtures to conftest.py, add type hints --------- Co-authored-by: Tom Aarsen <[email protected]>
- Loading branch information
Showing
9 changed files
with
619 additions
and
494 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,32 @@ | ||
from sentence_transformers import SentenceTransformer | ||
import pytest | ||
|
||
from sentence_transformers import SentenceTransformer, CrossEncoder | ||
from sentence_transformers.models import Transformer, Pooling | ||
|
||
|
||
@pytest.fixture() | ||
def model() -> SentenceTransformer: | ||
def stsb_bert_tiny_model() -> SentenceTransformer: | ||
return SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") | ||
|
||
|
||
@pytest.fixture() | ||
def paraphrase_distilroberta_base_v1_model() -> SentenceTransformer: | ||
return SentenceTransformer("paraphrase-distilroberta-base-v1") | ||
|
||
|
||
@pytest.fixture() | ||
def distilroberta_base_ce_model() -> CrossEncoder: | ||
return CrossEncoder("distilroberta-base", num_labels=1) | ||
|
||
|
||
@pytest.fixture() | ||
def clip_vit_b_32_model() -> SentenceTransformer: | ||
return SentenceTransformer("clip-ViT-B-32") | ||
|
||
|
||
@pytest.fixture() | ||
def distilbert_base_uncased_model() -> SentenceTransformer: | ||
word_embedding_model = Transformer("distilbert-base-uncased") | ||
pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension()) | ||
model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) | ||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.