Skip to content

Commit

Permalink
Merge pull request #1477 from neptune-ai/as/support-retry-after
Browse files Browse the repository at this point in the history
Add support for retry-after mechanism
  • Loading branch information
asledz authored Oct 5, 2023
2 parents 0cb299b + 8c7e289 commit 385a08d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
- Programmatically delete trashed neptune objects ([#1475](https://github.com/neptune-ai/neptune-client/pull/1475))
- Added support for callbacks that stop the synchronization if the lag or lack of progress exceeds a certain threshold ([#1478](https://github.com/neptune-ai/neptune-client/pull/1478))

### Changes
- Add support for `retry-after` header in HTTPTooManyRequests ([#1477](https://github.com/neptune-ai/neptune-client/pull/1477))

### Fixes
- Add newline at the end of generated `.patch` while tracking uncommitted changes ([#1473](https://github.com/neptune-ai/neptune-client/pull/1473))
- Clarify `NeptuneLimitExceedException` error message ([#1480](https://github.com/neptune-ai/neptune-client/pull/1480))
Expand Down
28 changes: 23 additions & 5 deletions src/neptune/common/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
__all__ = ["with_api_exceptions_handler"]
__all__ = ["with_api_exceptions_handler", "get_retry_from_headers_or_default"]

import itertools
import logging
Expand Down Expand Up @@ -52,9 +52,19 @@
_logger = logging.getLogger(__name__)

MAX_RETRY_TIME = 30
MAX_RETRY_MULTIPLIER = 10
retries_timeout = int(os.getenv(NEPTUNE_RETRIES_TIMEOUT_ENV, "60"))


def get_retry_from_headers_or_default(headers, retry_count):
    """Return the number of seconds to sleep before the next retry.

    Honors the server-provided ``retry-after`` header when present;
    otherwise falls back to exponential backoff driven by ``retry_count``,
    capped at ``MAX_RETRY_TIME`` seconds.

    Args:
        headers: Mapping of response headers. Values are header strings
            (or, defensively, a list of strings for multi-valued headers).
        retry_count: Number of retries attempted so far; drives the
            exponential-backoff fallback.

    Returns:
        int: Seconds to wait before retrying. A valid server-provided
        value is honored as-is; the backoff fallback is capped.
    """
    # Capped exponential backoff used whenever the header is absent or malformed.
    backoff = min(2 ** min(MAX_RETRY_MULTIPLIER, retry_count), MAX_RETRY_TIME)
    try:
        if "retry-after" not in headers:
            return backoff
        value = headers["retry-after"]
        # Some header containers expose multi-valued headers as lists; take the
        # first entry then. Never index into a plain string: the previous
        # ``headers["retry-after"][0]`` truncated e.g. "15" to "1".
        if isinstance(value, (list, tuple)):
            value = value[0]
        return int(value)
    except Exception:
        # Malformed header (non-numeric, empty list, ...) - fall back to backoff.
        return backoff


def with_api_exceptions_handler(func):
def wrapper(*args, **kwargs):
ssl_error_occurred = False
Expand Down Expand Up @@ -95,12 +105,16 @@ def wrapper(*args, **kwargs):
HTTPServiceUnavailable,
HTTPGatewayTimeout,
HTTPBadGateway,
HTTPTooManyRequests,
HTTPInternalServerError,
NewConnectionError,
ChunkedEncodingError,
) as e:
time.sleep(min(2 ** min(10, retry), MAX_RETRY_TIME))
time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
last_exception = e
continue
except HTTPTooManyRequests as e:
wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
time.sleep(wait_time)
last_exception = e
continue
except NeptuneAuthTokenExpired:
Expand All @@ -120,10 +134,14 @@ def wrapper(*args, **kwargs):
HTTPBadGateway.status_code,
HTTPServiceUnavailable.status_code,
HTTPGatewayTimeout.status_code,
HTTPTooManyRequests.status_code,
HTTPInternalServerError.status_code,
):
time.sleep(min(2 ** min(10, retry), MAX_RETRY_TIME))
time.sleep(min(2 ** min(MAX_RETRY_MULTIPLIER, retry), MAX_RETRY_TIME))
last_exception = e
continue
elif status_code == HTTPTooManyRequests.status_code:
wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
time.sleep(wait_time)
last_exception = e
continue
elif status_code == HTTPUnauthorized.status_code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from urllib3.exceptions import NewConnectionError

from neptune.common.backends.utils import get_retry_from_headers_or_default
from neptune.legacy.api_exceptions import (
ConnectionLost,
Forbidden,
Expand Down Expand Up @@ -66,7 +67,6 @@ def wrapper(*args, **kwargs):
HTTPRequestTimeout,
HTTPGatewayTimeout,
HTTPBadGateway,
HTTPTooManyRequests,
HTTPInternalServerError,
NewConnectionError,
):
Expand All @@ -75,6 +75,11 @@ def wrapper(*args, **kwargs):
time.sleep(2**retry)
retry += 1
continue
except HTTPTooManyRequests as e:
wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
time.sleep(wait_time)
retry += 1
continue
except HTTPUnauthorized:
raise Unauthorized()
except HTTPForbidden:
Expand All @@ -88,7 +93,6 @@ def wrapper(*args, **kwargs):
HTTPBadGateway.status_code,
HTTPServiceUnavailable.status_code,
HTTPGatewayTimeout.status_code,
HTTPTooManyRequests.status_code,
HTTPInternalServerError.status_code,
):
if retry >= 6:
Expand All @@ -98,6 +102,11 @@ def wrapper(*args, **kwargs):
time.sleep(2**retry)
retry += 1
continue
elif status_code == HTTPTooManyRequests.status_code:
wait_time = get_retry_from_headers_or_default(e.response.headers, retry)
time.sleep(wait_time)
retry += 1
continue
elif status_code >= HTTPInternalServerError.status_code:
raise ServerError()
elif status_code == HTTPUnauthorized.status_code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
#
import socket
import time
import unittest
import uuid
from pathlib import Path
Expand Down Expand Up @@ -226,6 +227,43 @@ def test_execute_operations(self, upload_mock, swagger_client_factory):
result,
)

@pytest.mark.asyncio
@patch("socket.gethostbyname", MagicMock(return_value="1.1.1.1"))
async def test_too_many_requests(self, swagger_client_factory):
    # Verify that a 429 (HTTPTooManyRequests) response makes the backend honor
    # the ``retry-after`` header: the first executeOperations call raises, the
    # second succeeds, and the elapsed time must cover the advertised wait.
    # NOTE(review): ``@pytest.mark.asyncio`` on what appears to be a unittest
    # method, and ``backend.execute_async`` - confirm both exist and are wired
    # up; presumably ``execute_async`` wraps ``execute_operations``.
    # given
    swagger_client = self._get_swagger_client_mock(swagger_client_factory)
    backend = HostedNeptuneBackend(credentials)
    container_uuid = str(uuid.uuid4())

    container_type = ContainerType.RUN

    response = MagicMock()
    response.response().return_value = []

    retry_after_seconds = 5  # example wait time advertised by the server
    too_many_requests_response = HTTPTooManyRequests(MagicMock())
    too_many_requests_response.headers = {"retry-after": str(retry_after_seconds)}

    # First call raises the 429, the retried call returns the success mock.
    # NOTE(review): ``response_mock`` is presumably a helper defined elsewhere
    # in this test module - verify.
    swagger_client.api.executeOperations.side_effect = Mock(
        side_effect=[too_many_requests_response, response_mock()]
    )

    # when
    result_start_time = time.time()
    result = await backend.execute_async(  # await so the coroutine finishes before we stop the clock
        container_id=container_uuid,
        container_type=container_type,
        operations=[
            LogFloats(["images", "img1"], [LogFloats.ValueType(1, 2, 3)]),
        ],
        operation_storage=self.dummy_operation_storage,
    )
    result_end_time = time.time()

    # then
    self.assertEqual(result, (1, []))
    # Elapsed time must be at least the advertised wait, but not pathologically
    # longer (x2 bound keeps the wall-clock assertion from flaking).
    assert retry_after_seconds <= (result_end_time - result_start_time) <= (retry_after_seconds * 2)

@patch("socket.gethostbyname", MagicMock(return_value="1.1.1.1"))
def test_execute_operations_retry_request(self, swagger_client_factory):
# given
Expand Down

0 comments on commit 385a08d

Please sign in to comment.