Skip to content

Commit

Permalink
feature: Add retry logic in LimitedConcurrencyClient in case of `BusyError`
Browse files Browse the repository at this point in the history

TASK: IL-431
  • Loading branch information
FlorianSchepersAA committed Apr 25, 2024
1 parent 1c4f153 commit 4702354
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 71 deletions.
20 changes: 17 additions & 3 deletions src/intelligence_layer/connectors/limited_concurrency_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from functools import lru_cache
from os import getenv
from threading import Semaphore
from time import sleep
import time
from typing import Any, Mapping, Optional, Protocol, Sequence

from aleph_alpha_client import (
Expand Down Expand Up @@ -106,10 +108,10 @@ class LimitedConcurrencyClient:
"""

def __init__(
    self,
    client: AlephAlphaClientProtocol,
    max_concurrency: int = 20,
    max_retry_time: int = 600,
) -> None:
    """Wrap `client`, limiting concurrent requests and bounding busy-retry time.

    Args:
        client: The underlying Aleph Alpha client all calls are delegated to.
        max_concurrency: Maximum number of requests allowed in flight at once.
        max_retry_time: Upper bound, in seconds, spent retrying a single call
            while the API reports it is busy (HTTP 503).
    """
    self._client = client
    # Bounds how many threads may call the wrapped client simultaneously.
    self._concurrency_limit_semaphore = Semaphore(max_concurrency)
    self._max_retry_time = max_retry_time

@classmethod
@lru_cache(maxsize=1)
def complete(
    self,
    request: "CompletionRequest",
    model: str,
) -> "CompletionResponse":
    """Send a completion request, retrying with exponential backoff while the API is busy.

    Args:
        request: The completion request forwarded to the wrapped client.
        model: Name of the model that should serve the request.

    Returns:
        The completion response from the wrapped client.

    Raises:
        Exception: The most recent busy (503) error if `max_retry_time` elapses
            without success, or any non-busy error immediately, unchanged.
        TimeoutError: If `max_retry_time` is non-positive, so no attempt is made.
    """
    with self._concurrency_limit_semaphore:
        retries = 0
        start_time = time.time()
        last_busy_error: Optional[Exception] = None
        while time.time() - start_time < self._max_retry_time:
            try:
                return self._client.complete(request, model)
            except Exception as error:
                # Only HTTP 503 ("busy") is retryable; everything else
                # propagates unchanged. Guarding on `error.args` avoids an
                # IndexError for exceptions constructed with no arguments.
                if not (error.args and error.args[0] == 503):
                    raise
                last_busy_error = error
                # Exponential backoff (1s, 2s, 4s, ...) clamped so we never
                # sleep past the remaining budget, and never a negative time.
                remaining = self._max_retry_time - (time.time() - start_time)
                sleep(max(0.0, min(2**retries, remaining)))
                retries += 1
        # Budget exhausted: re-raise the last busy error. (The original code
        # did `raise e` here, a NameError whenever the loop body never bound
        # `e` — e.g. when the budget elapsed or was non-positive.)
        if last_busy_error is not None:
            raise last_busy_error
        raise TimeoutError("max_retry_time elapsed before any attempt was made")

def get_version(self) -> str:
with self._concurrency_limit_semaphore:
Expand Down
185 changes: 117 additions & 68 deletions tests/connectors/test_limited_concurrency_client.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,117 @@
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from time import sleep
from typing import cast

from aleph_alpha_client import CompletionRequest, CompletionResponse, Prompt
from pytest import fixture

from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
LimitedConcurrencyClient,
)


class ConcurrencyCountingClient:
max_concurrency_counter: int = 0
concurrency_counter: int = 0

def __init__(self) -> None:
self.lock = Lock()

def complete(self, request: CompletionRequest, model: str) -> CompletionResponse:
with self.lock:
self.concurrency_counter += 1
self.max_concurrency_counter = max(
self.max_concurrency_counter, self.concurrency_counter
)
sleep(0.01)
with self.lock:
self.concurrency_counter -= 1
return CompletionResponse(
model_version="model-version",
completions=[],
optimized_prompt=None,
num_tokens_generated=0,
num_tokens_prompt_total=0,
)


TEST_MAX_CONCURRENCY = 3


@fixture
def concurrency_counting_client() -> ConcurrencyCountingClient:
return ConcurrencyCountingClient()


@fixture
def limited_concurrency_client(
concurrency_counting_client: ConcurrencyCountingClient,
) -> LimitedConcurrencyClient:
return LimitedConcurrencyClient(
cast(AlephAlphaClientProtocol, concurrency_counting_client),
TEST_MAX_CONCURRENCY,
)


def test_methods_concurrency_is_limited(
limited_concurrency_client: LimitedConcurrencyClient,
concurrency_counting_client: ConcurrencyCountingClient,
) -> None:
with ThreadPoolExecutor(max_workers=TEST_MAX_CONCURRENCY * 10) as executor:
executor.map(
limited_concurrency_client.complete,
[CompletionRequest(prompt=Prompt(""))] * TEST_MAX_CONCURRENCY * 10,
["model"] * TEST_MAX_CONCURRENCY * 10,
)
assert concurrency_counting_client.max_concurrency_counter == TEST_MAX_CONCURRENCY
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from time import sleep
from typing import cast

from aleph_alpha_client import CompletionRequest, CompletionResponse, Prompt
from pytest import fixture
import pytest

from intelligence_layer.connectors.limited_concurrency_client import (
AlephAlphaClientProtocol,
LimitedConcurrencyClient,
)


class ConcurrencyCountingClient:
    """Test double that tracks the peak number of overlapping `complete` calls."""

    max_concurrency_counter: int = 0
    concurrency_counter: int = 0

    def __init__(self) -> None:
        self.lock = Lock()

    def complete(self, request: CompletionRequest, model: str) -> CompletionResponse:
        # Entering: bump the live-call count and record a new peak if reached.
        with self.lock:
            self.concurrency_counter += 1
            if self.concurrency_counter > self.max_concurrency_counter:
                self.max_concurrency_counter = self.concurrency_counter
        # Simulate request latency so concurrent calls genuinely overlap.
        sleep(0.01)
        with self.lock:
            self.concurrency_counter -= 1
        return CompletionResponse(
            model_version="model-version",
            completions=[],
            optimized_prompt=None,
            num_tokens_generated=0,
            num_tokens_prompt_total=0,
        )


class BusyClient:
    """Test double whose first `complete` call fails busy (503).

    The second and later calls either raise `return_value` (when it is an
    exception) or return it as the completion response.
    """

    def __init__(
        self,
        return_value: CompletionResponse | Exception,
    ) -> None:
        self.number_of_retries: int = 0
        self.return_value = return_value

    def complete(self, request: CompletionRequest, model: str) -> CompletionResponse:
        self.number_of_retries += 1
        # First call always reports the API as busy.
        if self.number_of_retries < 2:
            raise Exception(503)
        # Subsequent calls surface the configured outcome.
        if isinstance(self.return_value, Exception):
            raise self.return_value
        return self.return_value


TEST_MAX_CONCURRENCY = 3


@fixture
def concurrency_counting_client() -> ConcurrencyCountingClient:
    # Fresh peak-concurrency-tracking fake client for each test.
    return ConcurrencyCountingClient()


@fixture
def limited_concurrency_client(
    concurrency_counting_client: ConcurrencyCountingClient,
) -> LimitedConcurrencyClient:
    # Client under test, wrapping the counting fake with the test's
    # concurrency cap so the peak counter can be asserted against it.
    return LimitedConcurrencyClient(
        cast(AlephAlphaClientProtocol, concurrency_counting_client),
        TEST_MAX_CONCURRENCY,
    )


def test_methods_concurrency_is_limited(
    limited_concurrency_client: LimitedConcurrencyClient,
    concurrency_counting_client: ConcurrencyCountingClient,
) -> None:
    """Overlapping `complete` calls never exceed the configured concurrency cap."""
    total_calls = TEST_MAX_CONCURRENCY * 10
    requests = [CompletionRequest(prompt=Prompt(""))] * total_calls
    models = ["model"] * total_calls
    # Far more workers than the cap, so the semaphore is what limits overlap.
    with ThreadPoolExecutor(max_workers=total_calls) as pool:
        pool.map(limited_concurrency_client.complete, requests, models)
    assert concurrency_counting_client.max_concurrency_counter == TEST_MAX_CONCURRENCY


def test_limited_concurrency_client_retries() -> None:
    """`complete` retries after a transient 503 and returns the eventual response."""
    expected_completion = CompletionResponse(
        model_version="model-version",
        completions=[],
        optimized_prompt=None,
        num_tokens_generated=0,
        num_tokens_prompt_total=0,
    )
    busy_client = BusyClient(return_value=expected_completion)
    # Cast matches the protocol expected by the wrapper (consistent with the
    # `limited_concurrency_client` fixture above).
    limited_concurrency_client = LimitedConcurrencyClient(
        cast(AlephAlphaClientProtocol, busy_client)
    )
    # `complete` takes a single request and a model name — the original test
    # passed lists, which only worked because BusyClient ignores its arguments.
    completion = limited_concurrency_client.complete(
        CompletionRequest(prompt=Prompt("")), "model"
    )
    assert completion == expected_completion


def test_limited_concurrency_client_throws_exception() -> None:
    """Non-busy errors (anything but 503) propagate unchanged, without retrying."""
    expected_exception = Exception(404)
    busy_client = BusyClient(return_value=expected_exception)
    # Cast matches the protocol expected by the wrapper (consistent with the
    # `limited_concurrency_client` fixture above).
    limited_concurrency_client = LimitedConcurrencyClient(
        cast(AlephAlphaClientProtocol, busy_client)
    )
    with pytest.raises(Exception) as e:
        # Single request and model name, not lists (see `complete`'s signature).
        limited_concurrency_client.complete(
            CompletionRequest(prompt=Prompt("")), "model"
        )
    # The assert must sit OUTSIDE the `raises` block: inside it, the raised
    # exception would abort the block before the assert ever ran.
    assert e.value == expected_exception

0 comments on commit 4702354

Please sign in to comment.