Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-GPU support for mine_hard_negatives #2967

Merged
merged 9 commits on Dec 4, 2024
42 changes: 29 additions & 13 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def mine_hard_negatives(
batch_size: int = 32,
faiss_batch_size: int = 16384,
use_faiss: bool = False,
use_multi_process: list[str] | bool = False,
verbose: bool = True,
) -> Dataset:
"""
Expand Down Expand Up @@ -643,6 +644,9 @@ def mine_hard_negatives(
batch_size (int): Batch size for encoding the dataset. Defaults to 32.
faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384.
use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.
use_multi_process (bool | list[str], optional): Whether to encode with multiple processes (multi-GPU/CPU). If True,
all GPUs are used when CUDA is available; otherwise 4 CPU processes are used. You can also pass a list of PyTorch
devices, e.g. ["cuda:0", "cuda:1", ...] or ["cpu", "cpu", "cpu", "cpu"]. Defaults to False.
verbose (bool): Whether to print statistics and logging. Defaults to True.

Returns:
Expand Down Expand Up @@ -717,6 +721,30 @@ def mine_hard_negatives(
avg_positives_per_query = np.mean(positives_per_query)
print(f"Found an average of {avg_positives_per_query:.3f} positives per query.")

# Embed the corpus and the queries
if use_multi_process:
pool = model.start_multi_process_pool(
target_devices=None if isinstance(use_multi_process, bool) else use_multi_process
)
corpus_embeddings = model.encode_multi_process(
corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
query_embeddings = model.encode_multi_process(
queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
model.stop_multi_process_pool(pool)
else:
corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries,
batch_size=batch_size,
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=True,
)

if use_faiss:
import faiss

Expand All @@ -730,12 +758,6 @@ def mine_hard_negatives(
except Exception:
pass

corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
index.add(corpus_embeddings)

scores_list = []
Expand All @@ -750,13 +772,7 @@ def mine_hard_negatives(
indices = torch.from_numpy(np.concatenate(indices_list, axis=0)).to(device)

else:
# Embed the corpus and the queries
corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
# Compute the similarity scores between the queries and the corpus
scores = model.similarity(query_embeddings, corpus_embeddings).to(device)

# Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair
Expand Down