Skip to content

Commit

Permalink
Clarify that target_devices can be a list of CPUs too
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Dec 18, 2023
1 parent aae5d40 commit 94a71d9
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,18 +232,21 @@ def encode(self, sentences: Union[str, List[str]],
def start_multi_process_pool(self, target_devices: List[str] = None):
"""
Starts multi process to process the encoding with several, independent processes.
This method is recommended if you want to encode on multiple GPUs. It is advised
This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
to start only one process per GPU. This method works together with encode_multi_process
and stop_multi_process_pool.
:param target_devices: PyTorch target devices, e.g. cuda:0, cuda:1... If None, all available CUDA devices will be used
:param target_devices: PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...] or ["cpu", "cpu", "cpu", "cpu"].
If target_devices is None and CUDA is available, then all available CUDA devices will be used. If
target_devices is None and CUDA is not available, then 4 CPU devices will be used.
:return: Returns a dict with the target processes, an input queue and and output queue.
"""
if target_devices is None:
if torch.cuda.is_available():
target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
else:
logger.info("CUDA is not available. Starting 4 CPU workers")
target_devices = ['cpu']*4
target_devices = ['cpu'] * 4

logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))

Expand Down

0 comments on commit 94a71d9

Please sign in to comment.