[feat + fix] Add normalize_embeddings support to multi-process encoding; fix multi-process encoding on CUDA devices (#2377)

* added normalize option to multiprocess encode

* Rename some variables

* Add missing docstring

* Update logger text

* Moving to CPU is required, otherwise all weights become 0.0 (see the sketch below, before the diff)

* Update test_encode_multi_process tests

---------

Co-authored-by: teisnp <[email protected]>
tomaarsen and TeisNP authored Dec 13, 2023
1 parent 135753b commit 8af4744
Showing 3 changed files with 53 additions and 28 deletions.
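
The "Moving to CPU is required" fix concerns how the pool shares model weights with spawned workers: parameters are placed in shared memory before the 'spawn' workers are started, and per the commit message this only behaved correctly for CPU tensors. A minimal sketch of that pattern in plain PyTorch (toy model, illustrative only, not part of this diff):

import torch
import torch.multiprocessing as mp

def worker(model: torch.nn.Module) -> None:
    # The worker sees the same shared-memory weights as the parent process.
    print(model.weight.sum().item())

if __name__ == "__main__":
    model = torch.nn.Linear(4, 4)
    model.to("cpu")       # move weights to CPU first; per the commit message,
                          # sharing from a CUDA device left workers with 0.0 weights
    model.share_memory()  # place parameters in shared memory for zero-copy access
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker, args=(model,), daemon=True)
    p.start()
    p.join()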
38 changes: 28 additions & 10 deletions sentence_transformers/SentenceTransformer.py
@@ -235,18 +235,20 @@ def start_multi_process_pool(self, target_devices: List[str] = None):
             if torch.cuda.is_available():
                 target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
             else:
-                logger.info("CUDA is not available. Start 4 CPU worker")
+                logger.info("CUDA is not available. Starting 4 CPU workers")
                 target_devices = ['cpu']*4

         logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

+        self.to("cpu")
+        self.share_memory()
         ctx = mp.get_context('spawn')
         input_queue = ctx.Queue()
         output_queue = ctx.Queue()
         processes = []

-        for cuda_id in target_devices:
-            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(cuda_id, self, input_queue, output_queue), daemon=True)
+        for device_id in target_devices:
+            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(device_id, self, input_queue, output_queue), daemon=True)
             p.start()
             processes.append(p)
@@ -269,7 +271,13 @@ def stop_multi_process_pool(pool):
         pool['output'].close()


-    def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
+    def encode_multi_process(
+        self,
+        sentences: List[str],
+        pool: Dict[str, object],
+        batch_size: int = 32,
+        chunk_size: int = None,
+        normalize_embeddings: bool = False):
         """
         This method allows to run encode() on multiple GPUs. The sentences are chunked into smaller packages
         and sent to individual processes, which encode these on the different GPUs. This method is only suitable
@@ -279,6 +287,8 @@ def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
         :param pool: A pool of workers started with SentenceTransformer.start_multi_process_pool
         :param batch_size: Encode sentences with batch size
         :param chunk_size: Sentences are chunked and sent to the individual processes. If none, it determine a sensible size.
+        :param normalize_embeddings: Whether to normalize returned vectors to have length 1. In that case,
+            the faster dot-product (util.dot_score) instead of cosine similarity can be used.
         :return: Numpy matrix with all embeddings
         """

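The docstring's dot-product remark can be checked directly: on unit-length vectors, dot product and cosine similarity coincide, so the cheaper util.dot_score can replace util.cos_sim. A small sketch, using the tiny test model from this PR's conftest.py (any model works):

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
emb = model.encode(["first sentence", "second sentence"], normalize_embeddings=True)

# With normalize_embeddings=True the vectors have length 1,
# so dot product and cosine similarity agree to numerical precision.
assert abs((util.dot_score(emb[0], emb[1]) - util.cos_sim(emb[0], emb[1])).item()) < 1e-6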
@@ -294,12 +304,12 @@ def encode_multi_process(self, sentences: List[str], pool: Dict[str, object], batch_size: int = 32, chunk_size: int = None):
         for sentence in sentences:
             chunk.append(sentence)
             if len(chunk) >= chunk_size:
-                input_queue.put([last_chunk_id, batch_size, chunk])
+                input_queue.put([last_chunk_id, batch_size, chunk, normalize_embeddings])
                 last_chunk_id += 1
                 chunk = []

         if len(chunk) > 0:
-            input_queue.put([last_chunk_id, batch_size, chunk])
+            input_queue.put([last_chunk_id, batch_size, chunk, normalize_embeddings])
             last_chunk_id += 1

         output_queue = pool['output']
@@ -314,9 +324,17 @@ def _encode_multi_process_worker(target_device: str, model, input_queue, results_queue):
         """
         while True:
             try:
-                id, batch_size, sentences = input_queue.get()
-                embeddings = model.encode(sentences, device=target_device, show_progress_bar=False, convert_to_numpy=True, batch_size=batch_size)
-                results_queue.put([id, embeddings])
+                chunk_id, batch_size, sentences, normalize_embeddings = input_queue.get()
+                embeddings = model.encode(
+                    sentences,
+                    device=target_device,
+                    show_progress_bar=False,
+                    convert_to_numpy=True,
+                    batch_size=batch_size,
+                    normalize_embeddings=normalize_embeddings,
+                )
+
+                results_queue.put([chunk_id, embeddings])
             except queue.Empty:
                 break

@@ -959,4 +977,4 @@ def _target_device(self) -> torch.device:

     @_target_device.setter
     def _target_device(self, device: Optional[Union[int, str, torch.device]] = None) -> None:
-        self.to(device)
+        self.to(device)
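
Putting the pieces together: normalize_embeddings flows from encode_multi_process through the input queue into each worker's encode() call. A minimal end-to-end usage sketch of the updated API, mirroring the updated test below (the tiny test model from conftest.py is used for illustration):

from sentence_transformers import SentenceTransformer

if __name__ == "__main__":  # guard needed when run as a script, since the pool uses the 'spawn' start method
    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
    sentences = ["This is sentence {}".format(i) for i in range(40)]

    # Two CPU workers here; pass e.g. ['cuda:0', 'cuda:1'] to spread the work over GPUs instead.
    pool = model.start_multi_process_pool(['cpu', 'cpu'])
    emb = model.encode_multi_process(sentences, pool, chunk_size=10, normalize_embeddings=True)
    model.stop_multi_process_pool(pool)

    print(emb.shape)  # (40, 128) for this tiny test model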
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,8 @@
+
+from sentence_transformers import SentenceTransformer
+import pytest
+
+
+@pytest.fixture()
+def model() -> SentenceTransformer:
+    return SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
35 changes: 17 additions & 18 deletions tests/test_multi_process.py
@@ -3,31 +3,30 @@
 """


-import unittest
 from sentence_transformers import SentenceTransformer
 import numpy as np
+import pytest

-class ComputeMultiProcessTest(unittest.TestCase):
-    def setUp(self):
-        self.model = SentenceTransformer('paraphrase-distilroberta-base-v1')
+@pytest.mark.parametrize("normalize_embeddings", (False, True))
+def test_encode_multi_process(model: SentenceTransformer, normalize_embeddings: bool) -> None:
+    sentences = ["This is sentence {}".format(i) for i in range(40)]

-    def test_multi_gpu_encode(self):
-        # Start the multi-process pool on all available CUDA devices
-        pool = self.model.start_multi_process_pool(['cpu', 'cpu'])
+    # Start the multi-process pool on e.g. two CPU devices & compute the embeddings using the pool
+    pool = model.start_multi_process_pool(['cpu', 'cpu'])
+    emb = model.encode_multi_process(sentences, pool, chunk_size=10, normalize_embeddings=normalize_embeddings)
+    model.stop_multi_process_pool(pool)
+    assert emb.shape == (len(sentences), 128)

-        sentences = ["This is sentence {}".format(i) for i in range(1000)]
+    # Make sure the embeddings aren't just all 0
+    assert emb.sum() != 0.0

-        # Compute the embeddings using the multi-process pool
-        emb = self.model.encode_multi_process(sentences, pool, chunk_size=50)
-        self.model.stop_multi_process_pool(pool)
-        assert emb.shape == (len(sentences), 768)
-
-        emb_normal = self.model.encode(sentences)
-
-
-        diff = np.max(np.abs(emb - emb_normal))
-        print("Max multi proc diff", diff)
-        assert diff < 0.001
+    # Compare against normal embeddings
+    emb_normal = model.encode(sentences, normalize_embeddings=normalize_embeddings)
+    diff = np.max(np.abs(emb - emb_normal))
+    assert diff < 0.001
+
+    # Ensure that after normalizing, the means are all almost 0, and otherwise not
+    assert np.all(np.abs(emb.mean(1)) < 0.01) == normalize_embeddings
