Skip to content

Commit

Permalink
fix: handle unexpected error raised from websocket server (#644)
Browse files Browse the repository at this point in the history
Enhance error handling for websocket communication. Differentiate errors
raised at connection time from those raised while sending or receiving.
Do not retry if an unexpected error (code 1011) is raised by the server.

fix AAP-19535: Ansible-rulebook keeps on retrying requesting workload
through websocket
  • Loading branch information
bzwei authored Jan 18, 2024
1 parent 927e589 commit 37344da
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 71 deletions.
139 changes: 75 additions & 64 deletions ansible_rulebook/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ansible_rulebook import rules_parser as rules_parser
from ansible_rulebook.common import StartupArgs
from ansible_rulebook.conf import settings
from ansible_rulebook.exception import ShutdownException
from ansible_rulebook.token import renew_token

logger = logging.getLogger(__name__)
Expand All @@ -42,9 +41,41 @@
BACKOFF_INITIAL = 5


async def _wait_before_retry(backoff_delay: float) -> float:
    """Sleep before the next websocket attempt and return the next delay.

    The sleep-and-retry scheme is duplicated from the legacy ``websockets``
    client (``websockets.legacy.client``).

    Args:
        backoff_delay: Delay, in seconds, scheduled for the current retry.

    Returns:
        The delay to pass to the following retry: ``backoff_delay``
        multiplied by ``BACKOFF_FACTOR`` and capped at ``BACKOFF_MAX``.
    """
    if backoff_delay == BACKOFF_MIN:
        # First retry: wait a random fraction of BACKOFF_INITIAL seconds.
        # See section 7.2.3, "Recovering from Abnormal Closure", of RFC 6455.
        pause = random.random() * BACKOFF_INITIAL
        logger.info(
            "! websocket connect failed; reconnecting in %.1f seconds",
            pause,
            exc_info=False,
        )
    else:
        # Subsequent retries: wait the whole number of seconds scheduled.
        pause = int(backoff_delay)
        logger.info(
            "! websocket connect failed again; retrying in %d seconds",
            pause,
            exc_info=False,
        )
    await asyncio.sleep(pause)

    # Truncated exponential backoff for the next attempt.
    return min(backoff_delay * BACKOFF_FACTOR, BACKOFF_MAX)


async def _update_authorization_header(headers: dict) -> None:
    """Renew the token and store it in ``headers`` as a Bearer credential.

    Args:
        headers: Header mapping whose ``"Authorization"`` entry is
            overwritten in place with the renewed token.
    """
    fresh_token = await renew_token()
    headers["Authorization"] = f"Bearer {fresh_token}"


async def _connect_websocket(
handler: tp.Callable[[WebSocketClientProtocol], tp.Awaitable],
retry: bool,
retry_on_close: bool,
**kwargs: list,
) -> tp.Any:
logger.info("websocket %s connecting", settings.websocket_url)
Expand All @@ -55,79 +86,57 @@ async def _connect_websocket(
else:
extra_headers = {}

result = None
refresh = True
refresh_token = True

backoff_delay = BACKOFF_MIN
while True:
backoff_delay = BACKOFF_MIN
try:
async with websockets.connect(
settings.websocket_url,
ssl=_sslcontext(),
extra_headers=extra_headers,
) as websocket:
result = await handler(websocket, **kwargs)
if not retry:
break
# Connection succeeded - reset backoff delay and refresh_token
backoff_delay = BACKOFF_MIN
refresh_token = True
return await handler(websocket, **kwargs)
except asyncio.CancelledError: # pragma: no cover
raise
except ShutdownException:
break
except Exception as e:
status403_legacy = (
isinstance(e, websockets.exceptions.InvalidStatusCode)
and e.status_code == 403
)
status403 = (
isinstance(e, websockets.exceptions.InvalidStatus)
and e.response.status_code == 403
)
if status403_legacy or status403:
if refresh and settings.websocket_refresh_token:
new_token = await renew_token()
extra_headers["Authorization"] = f"Bearer {new_token}"
# Only attempt to refresh token once. If a new token cannot
# establish the connection, something else must cause 403
refresh = False
else:
raise
elif isinstance(e, OSError) and "[Errno 61]" in str(e):
# Sleep and retry implemention duplicated from
# websockets.lagacy.client.Client

# Add a random initial delay between 0 and 5 seconds.
# See 7.2.3. Recovering from Abnormal Closure in RFC 6544.
if backoff_delay == BACKOFF_MIN:
initial_delay = random.random() * BACKOFF_INITIAL
logger.info(
"! connect failed; reconnecting in %.1f seconds",
initial_delay,
exc_info=True,
)
await asyncio.sleep(initial_delay)
else:
logger.info(
"! connect failed again; retrying in %d seconds",
int(backoff_delay),
exc_info=True,
)
await asyncio.sleep(int(backoff_delay))
# Increase delay with truncated exponential backoff.
backoff_delay = backoff_delay * BACKOFF_FACTOR
backoff_delay = min(backoff_delay, BACKOFF_MAX)
continue
else:
# Connection succeeded - reset backoff delay
backoff_delay = BACKOFF_MIN
refresh = True

return result
except websockets.exceptions.InvalidStatusCode as e:
if refresh_token and e.status_code == 403:
await _update_authorization_header(extra_headers)
# Only attempt to refresh token once. If a new token cannot
# establish the connection, something else must have caused 403
refresh_token = False
else:
raise # abort
except websockets.exceptions.InvalidStatus as e:
if refresh_token and e.response.status_code == 403:
await _update_authorization_header(extra_headers)
refresh_token = False
else:
raise # abort
except OSError as e:
if "[Errno 61]" in str(e):
# if connection cannot be established, retry later
backoff_delay = await _wait_before_retry(backoff_delay)
else:
raise # abort
except websockets.exceptions.ConnectionClosedError as e:
if e.code == 1011:
# unexpected error raised from server
raise # abort
if retry_on_close:
backoff_delay = await _wait_before_retry(backoff_delay)
except websockets.exceptions.ConnectionClosedOK:
if retry_on_close:
backoff_delay = await _wait_before_retry(backoff_delay)


async def request_workload(activation_instance_id: str) -> StartupArgs:
return await _connect_websocket(
handler=_handle_request_workload,
retry=False,
retry_on_close=False,
activation_instance_id=activation_instance_id,
)

Expand Down Expand Up @@ -193,7 +202,7 @@ async def send_event_log_to_websocket(event_log: asyncio.Queue):

return await _connect_websocket(
handler=_handle_send_event_log,
retry=True,
retry_on_close=True,
logs=logs,
)

Expand All @@ -206,7 +215,8 @@ async def _handle_send_event_log(

if logs.event:
logger.info("Resending last event...")
await websocket.send(json.dumps(logs.event))
json_str = json.dumps(logs.event)
await websocket.send(json_str)
logs.event = None

while True:
Expand All @@ -215,10 +225,11 @@ async def _handle_send_event_log(

if event == dict(type="Exit"):
logger.info("Exiting feedback websocket task")
raise ShutdownException(shutdown=None)
break

logs.event = event
await websocket.send(json.dumps(event))
json_str = json.dumps(event)
await websocket.send(json_str)
logs.event = None


Expand Down
14 changes: 10 additions & 4 deletions tests/e2e/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

@pytest.mark.e2e
@pytest.mark.asyncio
async def test_websocket_messages():
@pytest.mark.parametrize("expect_failure", [True, False])
async def test_websocket_messages(expect_failure):
"""
Verify that ansible-rulebook can correctly
send event messages to a websocket server
Expand All @@ -25,7 +26,7 @@ async def test_websocket_messages():
host = "127.0.0.1"
endpoint = "/api/ws2"
proc_id = "42"
port = 31415
port = 31415 if expect_failure else 31414
rulebook = (
utils.BASE_DATA_PATH / "rulebooks/websockets/test_websocket_range.yml"
)
Expand All @@ -39,7 +40,7 @@ async def test_websocket_messages():

# run server and ansible-rulebook
queue = asyncio.Queue()
handler = partial(utils.msg_handler, queue=queue)
handler = partial(utils.msg_handler, queue=queue, failed=expect_failure)
async with ws_server.serve(handler, host, port):
LOGGER.info(f"Running command: {cmd}")
proc = await asyncio.create_subprocess_shell(
Expand All @@ -50,7 +51,12 @@ async def test_websocket_messages():
)

await asyncio.wait_for(proc.wait(), timeout=DEFAULT_TIMEOUT)
assert proc.returncode == 0
if expect_failure:
assert proc.returncode == 1
assert queue.qsize() == 2
return
else:
assert proc.returncode == 0

# Verify data
assert not queue.empty()
Expand Down
11 changes: 10 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,22 @@ def assert_playbook_output(result: CompletedProcess) -> List[dict]:


async def msg_handler(
websocket: ws_server.WebSocketServerProtocol, queue: asyncio.Queue
websocket: ws_server.WebSocketServerProtocol,
queue: asyncio.Queue,
failed: bool = False,
):
"""
Handler for a websocket server that passes json messages
from ansible-rulebook in the given queue
"""
i = 0
async for message in websocket:
payload = json.loads(message)
data = {"path": websocket.path, "payload": payload}
await queue.put(data)
if i == 1:
if failed:
print(data["bad"]) # force a coding error
else:
await websocket.close() # should be auto reconnected
i += 1
38 changes: 36 additions & 2 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,16 @@ def my_func(data):


@pytest.mark.asyncio
@pytest.mark.parametrize(
"exception_class",
[
websockets.exceptions.ConnectionClosedError,
websockets.exceptions.ConnectionClosedError,
],
)
@mock.patch("ansible_rulebook.websocket.websockets.connect")
async def test_send_event_log_to_websocket_with_exception(
socket_mock: AsyncMock,
socket_mock: AsyncMock, exception_class
):
prepare_settings()
queue = asyncio.Queue()
Expand All @@ -154,10 +161,37 @@ async def test_send_event_log_to_websocket_with_exception(
socket_mock.return_value.__aiter__.side_effect = [mock_object]

socket_mock.return_value.send.side_effect = [
websockets.exceptions.ConnectionClosed(rcvd=None, sent=None),
exception_class(rcvd=None, sent=None),
data_sent.append({"a": 1}),
data_sent.append({"b": 2}),
]

await send_event_log_to_websocket(queue)
assert data_sent == [{"a": 1}, {"b": 2}]


@pytest.mark.asyncio
@mock.patch("ansible_rulebook.websocket.websockets.connect")
async def test_send_event_log_to_websocket_with_non_recoverable_exception(
    socket_mock: AsyncMock,
):
    """A ConnectionClosedError with code 1011 on send reaches the caller."""
    prepare_settings()

    queue = asyncio.Queue()
    for entry in ({"a": 1}, {"b": 2}, dict(type="Exit")):
        queue.put_nowait(entry)

    # One AsyncMock stands in for both the connect() result and the
    # websocket yielded by its async context manager.
    connection = AsyncMock()
    socket_mock.return_value = connection
    connection.__aenter__.return_value = connection
    connection.__anext__.return_value = connection
    connection.__aiter__.side_effect = [connection]

    # Received close frame carrying code 1011.
    rcvd = mock.Mock()
    rcvd.code = 1011
    connection.send.side_effect = websockets.exceptions.ConnectionClosedError(
        rcvd=rcvd, sent=None
    )

    with pytest.raises(websockets.exceptions.ConnectionClosedError):
        await send_event_log_to_websocket(queue)

0 comments on commit 37344da

Please sign in to comment.