diff --git a/sentence_transformers/util.py b/sentence_transformers/util.py index 0186c53e1..0cee7f0dc 100644 --- a/sentence_transformers/util.py +++ b/sentence_transformers/util.py @@ -535,8 +535,8 @@ def mine_hard_negatives( batch_size: int = 32, faiss_batch_size: int = 16384, use_faiss: bool = False, + use_multi_process: list[str] | bool = False, verbose: bool = True, - use_multiple_gpus=False, ) -> Dataset: """ Add hard negatives to a dataset of (anchor, positive) pairs to create (anchor, positive, negative) triplets or @@ -644,6 +644,9 @@ def mine_hard_negatives( batch_size (int): Batch size for encoding the dataset. Defaults to 32. faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384. use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False. + use_multi_process (bool | List[str], optional): Whether to use multi-GPU/CPU processing. If True, uses all GPUs if CUDA + is available, and 4 CPU processes if it's not available. You can also pass a list of PyTorch devices like + ["cuda:0", "cuda:1", ...] or ["cpu", "cpu", "cpu", "cpu"]. verbose (bool): Whether to print statistics and logging. Defaults to True. Returns: @@ -718,6 +721,30 @@ def mine_hard_negatives( avg_positives_per_query = np.mean(positives_per_query) print(f"Found an average of {avg_positives_per_query:.3f} positives per query.") + # Embed the corpus and the queries + if use_multi_process: + pool = model.start_multi_process_pool( + target_devices=None if isinstance(use_multi_process, bool) else use_multi_process + ) + corpus_embeddings = model.encode_multi_process( + corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True + ) + query_embeddings = model.encode_multi_process( + queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True + ) + model.stop_multi_process_pool(pool) + else: + corpus_embeddings = model.encode( + corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True + ) + query_embeddings = model.encode( + queries, + batch_size=batch_size, + normalize_embeddings=True, + convert_to_numpy=True, + show_progress_bar=True, + ) + if use_faiss: import faiss @@ -731,27 +758,6 @@ def mine_hard_negatives( except Exception: pass - if use_multiple_gpus: - pool = model.start_multi_process_pool() - - corpus_embeddings = model.encode_multi_process( - corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True - ) - - query_embeddings = model.encode_multi_process( - queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True - ) - else: - corpus_embeddings = model.encode( - corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True - ) - query_embeddings = model.encode( - queries, - batch_size=batch_size, - normalize_embeddings=True, - convert_to_numpy=True, - show_progress_bar=True, - ) index.add(corpus_embeddings) scores_list = [] @@ -766,29 +772,7 @@ def mine_hard_negatives( indices = torch.from_numpy(np.concatenate(indices_list, axis=0)).to(device) else: - # Embed the corpus and the queries - - if use_multiple_gpus: - pool = model.start_multi_process_pool() - - corpus_embeddings = model.encode_multi_process( - corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True - ) - - query_embeddings = model.encode_multi_process( - queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True - ) - else: - corpus_embeddings = model.encode( - corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True - ) - query_embeddings = model.encode( - queries, - batch_size=batch_size, - normalize_embeddings=True, - convert_to_numpy=True, - show_progress_bar=True, - ) + # Compute the similarity scores between the queries and the corpus scores = model.similarity(query_embeddings, corpus_embeddings).to(device) # Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair