diff --git a/src/intelligence_layer/connectors/limited_concurrency_client.py b/src/intelligence_layer/connectors/limited_concurrency_client.py index 0495fb073..9498e5a7f 100644 --- a/src/intelligence_layer/connectors/limited_concurrency_client.py +++ b/src/intelligence_layer/connectors/limited_concurrency_client.py @@ -1,6 +1,8 @@ from functools import lru_cache from os import getenv from threading import Semaphore +from time import sleep +import time from typing import Any, Mapping, Optional, Protocol, Sequence from aleph_alpha_client import ( @@ -106,10 +108,10 @@ class LimitedConcurrencyClient: """ def __init__( - self, client: AlephAlphaClientProtocol, max_concurrency: int = 20 - ) -> None: + self, client: AlephAlphaClientProtocol, max_concurrency: int = 20, max_retry_time: int = 600 ) -> None: self._client = client self._concurrency_limit_semaphore = Semaphore(max_concurrency) + self._max_retry_time = max_retry_time @classmethod @lru_cache(maxsize=1) @@ -143,7 +145,19 @@ def complete( model: str, ) -> CompletionResponse: with self._concurrency_limit_semaphore: - return self._client.complete(request, model) + retries = 0 + start_time = time.time() + while time.time() - start_time < self._max_retry_time: + try: + return self._client.complete(request, model) + except Exception as e: + if e.args[0] == 503: + sleep(min(2**retries,self._max_retry_time - (time.time() - start_time))) + retries += 1 + continue + else: + raise e + raise e def get_version(self) -> str: with self._concurrency_limit_semaphore: diff --git a/tests/connectors/test_limited_concurrency_client.py b/tests/connectors/test_limited_concurrency_client.py index a89e0ef53..d6f0a5c8b 100644 --- a/tests/connectors/test_limited_concurrency_client.py +++ b/tests/connectors/test_limited_concurrency_client.py @@ -1,68 +1,116 @@ -from concurrent.futures import ThreadPoolExecutor -from threading import Lock -from time import sleep -from typing import cast - -from aleph_alpha_client import CompletionRequest, CompletionResponse, Prompt -from pytest import fixture - -from intelligence_layer.connectors.limited_concurrency_client import ( - AlephAlphaClientProtocol, - LimitedConcurrencyClient, -) - - -class ConcurrencyCountingClient: - max_concurrency_counter: int = 0 - concurrency_counter: int = 0 - - def __init__(self) -> None: - self.lock = Lock() - - def complete(self, request: CompletionRequest, model: str) -> CompletionResponse: - with self.lock: - self.concurrency_counter += 1 - self.max_concurrency_counter = max( - self.max_concurrency_counter, self.concurrency_counter - ) - sleep(0.01) - with self.lock: - self.concurrency_counter -= 1 - return CompletionResponse( - model_version="model-version", - completions=[], - optimized_prompt=None, - num_tokens_generated=0, - num_tokens_prompt_total=0, - ) - - -TEST_MAX_CONCURRENCY = 3 - - -@fixture -def concurrency_counting_client() -> ConcurrencyCountingClient: - return ConcurrencyCountingClient() - - -@fixture -def limited_concurrency_client( - concurrency_counting_client: ConcurrencyCountingClient, -) -> LimitedConcurrencyClient: - return LimitedConcurrencyClient( - cast(AlephAlphaClientProtocol, concurrency_counting_client), - TEST_MAX_CONCURRENCY, - ) - - -def test_methods_concurrency_is_limited( - limited_concurrency_client: LimitedConcurrencyClient, - concurrency_counting_client: ConcurrencyCountingClient, -) -> None: - with ThreadPoolExecutor(max_workers=TEST_MAX_CONCURRENCY * 10) as executor: - executor.map( - limited_concurrency_client.complete, - [CompletionRequest(prompt=Prompt(""))] * TEST_MAX_CONCURRENCY * 10, - ["model"] * TEST_MAX_CONCURRENCY * 10, - ) - assert concurrency_counting_client.max_concurrency_counter == TEST_MAX_CONCURRENCY +from concurrent.futures import ThreadPoolExecutor +from threading import Lock +from time import sleep +from typing import cast + +from aleph_alpha_client import CompletionRequest, CompletionResponse, Prompt +from pytest import fixture +import pytest + +from intelligence_layer.connectors.limited_concurrency_client import ( + AlephAlphaClientProtocol, + LimitedConcurrencyClient, +) + + +class ConcurrencyCountingClient: + max_concurrency_counter: int = 0 + concurrency_counter: int = 0 + + def __init__(self) -> None: + self.lock = Lock() + + def complete(self, request: CompletionRequest, model: str) -> CompletionResponse: + with self.lock: + self.concurrency_counter += 1 + self.max_concurrency_counter = max( + self.max_concurrency_counter, self.concurrency_counter + ) + sleep(0.01) + with self.lock: + self.concurrency_counter -= 1 + return CompletionResponse( + model_version="model-version", + completions=[], + optimized_prompt=None, + num_tokens_generated=0, + num_tokens_prompt_total=0, + ) + + +class BusyClient: + def __init__(self, raise_exception_after_retries: bool) -> None: + self.number_of_retries: int = 0 + self.raise_exception = raise_exception_after_retries + + def complete(self, request: CompletionRequest, model: str) -> CompletionResponse: + self.number_of_retries += 1 + if self.number_of_retries < 2: + raise Exception(503) + else: + if self.raise_exception: + raise Exception(404) + else: + return CompletionResponse( + model_version="model-version", + completions=[], + optimized_prompt=None, + num_tokens_generated=0, + num_tokens_prompt_total=0, + ) + + +TEST_MAX_CONCURRENCY = 3 + + +@fixture +def concurrency_counting_client() -> ConcurrencyCountingClient: + return ConcurrencyCountingClient() + + +@fixture +def limited_concurrency_client( + concurrency_counting_client: ConcurrencyCountingClient, +) -> LimitedConcurrencyClient: + return LimitedConcurrencyClient( + cast(AlephAlphaClientProtocol, concurrency_counting_client), + TEST_MAX_CONCURRENCY, + ) + + +def test_methods_concurrency_is_limited( + limited_concurrency_client: LimitedConcurrencyClient, + concurrency_counting_client: ConcurrencyCountingClient, +) -> None: + with ThreadPoolExecutor(max_workers=TEST_MAX_CONCURRENCY * 10) as executor: + executor.map( + limited_concurrency_client.complete, + [CompletionRequest(prompt=Prompt(""))] * TEST_MAX_CONCURRENCY * 10, + ["model"] * TEST_MAX_CONCURRENCY * 10, + ) + assert concurrency_counting_client.max_concurrency_counter == TEST_MAX_CONCURRENCY + + +def test_limited_concurrency_client_retries() -> None: + busy_client = BusyClient(raise_exception_after_retries=False) + limited_concurrency_client = LimitedConcurrencyClient(busy_client) + completion = limited_concurrency_client.complete( + [CompletionRequest(prompt=Prompt(""))], ["model"] + ) + assert completion == CompletionResponse( + model_version="model-version", + completions=[], + optimized_prompt=None, + num_tokens_generated=0, + num_tokens_prompt_total=0, + ) + + +def test_limited_concurrency_client_throws_exception() -> None: + busy_client = BusyClient(raise_exception_after_retries=True) + limited_concurrency_client = LimitedConcurrencyClient(busy_client) + with pytest.raises(Exception) as e: + limited_concurrency_client.complete( + [CompletionRequest(prompt=Prompt(""))], ["model"] + ) + assert e.value.args[0] == 404