Skip to content

Commit

Permalink
Rename use_multiple_gpus to use_multi_process
Browse files Browse the repository at this point in the history
Renamed because this feature also supports multiple CPUs, and `multi_process` matches the naming of the underlying methods (`start_multi_process_pool` / `encode_multi_process`).
Also stop the multi-process pool after encoding finishes,
and remove the duplicated encoding code across the FAISS and non-FAISS branches.
  • Loading branch information
tomaarsen committed Nov 26, 2024
1 parent 6e024bb commit b161e5e
Showing 1 changed file with 29 additions and 45 deletions.
74 changes: 29 additions & 45 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def mine_hard_negatives(
batch_size: int = 32,
faiss_batch_size: int = 16384,
use_faiss: bool = False,
use_multi_process: list[str] | bool = False,
verbose: bool = True,
use_multiple_gpus=False,
) -> Dataset:
"""
Add hard negatives to a dataset of (anchor, positive) pairs to create (anchor, positive, negative) triplets or
Expand Down Expand Up @@ -644,6 +644,9 @@ def mine_hard_negatives(
batch_size (int): Batch size for encoding the dataset. Defaults to 32.
faiss_batch_size (int): Batch size for FAISS top-k search. Defaults to 16384.
use_faiss (bool): Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.
use_multi_process (bool | List[str], optional): Whether to use multi-GPU/CPU processing. If True, uses all GPUs if CUDA
is available, and 4 CPU processes if it's not available. You can also pass a list of PyTorch devices like
["cuda:0", "cuda:1", ...] or ["cpu", "cpu", "cpu", "cpu"].
verbose (bool): Whether to print statistics and logging. Defaults to True.
Returns:
Expand Down Expand Up @@ -718,6 +721,30 @@ def mine_hard_negatives(
avg_positives_per_query = np.mean(positives_per_query)
print(f"Found an average of {avg_positives_per_query:.3f} positives per query.")

# Embed the corpus and the queries
if use_multi_process:
pool = model.start_multi_process_pool(
target_devices=None if isinstance(use_multi_process, bool) else use_multi_process
)
corpus_embeddings = model.encode_multi_process(
corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
query_embeddings = model.encode_multi_process(
queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
model.stop_multi_process_pool(pool)
else:
corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries,
batch_size=batch_size,
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=True,
)

if use_faiss:
import faiss

Expand All @@ -731,27 +758,6 @@ def mine_hard_negatives(
except Exception:
pass

if use_multiple_gpus:
pool = model.start_multi_process_pool()

corpus_embeddings = model.encode_multi_process(
corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)

query_embeddings = model.encode_multi_process(
queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
else:
corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries,
batch_size=batch_size,
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=True,
)
index.add(corpus_embeddings)

scores_list = []
Expand All @@ -766,29 +772,7 @@ def mine_hard_negatives(
indices = torch.from_numpy(np.concatenate(indices_list, axis=0)).to(device)

else:
# Embed the corpus and the queries

if use_multiple_gpus:
pool = model.start_multi_process_pool()

corpus_embeddings = model.encode_multi_process(
corpus, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)

query_embeddings = model.encode_multi_process(
queries, pool, batch_size=batch_size, normalize_embeddings=True, show_progress_bar=True
)
else:
corpus_embeddings = model.encode(
corpus, batch_size=batch_size, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True
)
query_embeddings = model.encode(
queries,
batch_size=batch_size,
normalize_embeddings=True,
convert_to_numpy=True,
show_progress_bar=True,
)
# Compute the similarity scores between the queries and the corpus
scores = model.similarity(query_embeddings, corpus_embeddings).to(device)

# Keep only the range_max + max_positives highest scores. We offset by 1 to potentially include the positive pair
Expand Down

0 comments on commit b161e5e

Please sign in to comment.