diff --git a/cohere/compass/clients/parser.py b/cohere/compass/clients/parser.py index 9fa3bc9..ab55bb9 100644 --- a/cohere/compass/clients/parser.py +++ b/cohere/compass/clients/parser.py @@ -54,6 +54,7 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, num_workers: int = 4, + retries: int = 3, ): """ Initialize the CompassParserClient. @@ -79,8 +80,9 @@ def __init__( self.username = username or os.getenv("COHERE_COMPASS_USERNAME") self.password = password or os.getenv("COHERE_COMPASS_PASSWORD") self.session = requests.Session() - self.thread_pool = ThreadPoolExecutor(num_workers) + self.thread_pool = ThreadPoolExecutor(num_workers * 2) self.num_workers = num_workers + self.retries = retries self.metadata_config = metadata_config logger.info( @@ -181,6 +183,7 @@ def process_file(i: int) -> list[CompassDocument]: process_file, range(len(filenames)), max_queued=self.num_workers, + retries=self.retries, ): yield from results diff --git a/cohere/compass/utils.py b/cohere/compass/utils.py index 371b258..3398896 100644 --- a/cohere/compass/utils.py +++ b/cohere/compass/utils.py @@ -1,6 +1,7 @@ # Python imports import base64 import glob +import logging import os import uuid from collections.abc import Iterable, Iterator @@ -23,9 +24,15 @@ T = TypeVar("T") U = TypeVar("U") +logger = logging.getLogger(__name__) + def imap_queued( - executor: Executor, f: Callable[[T], U], it: Iterable[T], max_queued: int + executor: Executor, + f: Callable[[T], U], + it: Iterable[T], + max_queued: int, + retries: int = 3, ) -> Iterator[U]: """ Similar to Python's `map`, but uses an executor to parallelize the calls. @@ -34,20 +41,43 @@ def imap_queued( :param f: the function to call. :param it: the iterable to map over. :param max_queued: the maximum number of futures to keep in flight. + :param retries: maximum number of retries to make in case of failure :returns: an iterator over the results. """ + + def _execute_with_retry(f: Callable[[T], U], x: T): + """ + Execute a function with retries on failure. + + :param f: the function to call. + :param x: the parameter to pass to function. + """ + for attempt in range(retries): + future = executor.submit(f, x) + try: + return future.result() # Attempt to get the result + except Exception as e: + logger.info(f"Attempt {attempt + 1} failed for input {x}: {e}") + if ( + attempt == retries - 1 + ): # If it's the last attempt, re-raise the exception + logger.error(f"Cannot process file {x} after {retries} attempts") + assert max_queued >= 1 futures_set: set[futures.Future[U]] = set() for x in it: - futures_set.add(executor.submit(f, x)) + futures_set.add(executor.submit(_execute_with_retry, f, x)) while len(futures_set) > max_queued: done, futures_set = futures.wait( futures_set, return_when=futures.FIRST_COMPLETED ) for future in done: - yield future.result() + try: + yield future.result() + except Exception: + logger.error(f"Cannot process file {x} after {retries} attempts") for future in futures.as_completed(futures_set): yield future.result()