Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync devel with main #647

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 75 additions & 64 deletions ansible_rulebook/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ansible_rulebook import rules_parser as rules_parser
from ansible_rulebook.common import StartupArgs
from ansible_rulebook.conf import settings
from ansible_rulebook.exception import ShutdownException
from ansible_rulebook.token import renew_token

logger = logging.getLogger(__name__)
Expand All @@ -42,9 +41,41 @@
BACKOFF_INITIAL = 5


async def _wait_before_retry(backoff_delay: float) -> float:
# Sleep and retry implemention duplicated from
# websockets.lagacy.client.Client

# Add a random initial delay between 0 and 5 seconds.
# See 7.2.3. Recovering from Abnormal Closure in RFC 6544.
if backoff_delay == BACKOFF_MIN:
initial_delay = random.random() * BACKOFF_INITIAL
logger.info(
"! websocket connect failed; reconnecting in %.1f seconds",
initial_delay,
exc_info=False,
)
await asyncio.sleep(initial_delay)
else:
logger.info(
"! websocket connect failed again; retrying in %d seconds",
int(backoff_delay),
exc_info=False,
)
await asyncio.sleep(int(backoff_delay))
# Increase delay with truncated exponential backoff.
backoff_delay = backoff_delay * BACKOFF_FACTOR
backoff_delay = min(backoff_delay, BACKOFF_MAX)
return backoff_delay


async def _update_authorization_header(headers: dict) -> None:
new_token = await renew_token()
headers["Authorization"] = f"Bearer {new_token}"


async def _connect_websocket(
handler: tp.Callable[[WebSocketClientProtocol], tp.Awaitable],
retry: bool,
retry_on_close: bool,
**kwargs: list,
) -> tp.Any:
logger.info("websocket %s connecting", settings.websocket_url)
Expand All @@ -55,79 +86,57 @@ async def _connect_websocket(
else:
extra_headers = {}

result = None
refresh = True
refresh_token = True

backoff_delay = BACKOFF_MIN
while True:
backoff_delay = BACKOFF_MIN
try:
async with websockets.connect(
settings.websocket_url,
ssl=_sslcontext(),
extra_headers=extra_headers,
) as websocket:
result = await handler(websocket, **kwargs)
if not retry:
break
# Connection succeeded - reset backoff delay and refresh_token
backoff_delay = BACKOFF_MIN
refresh_token = True
return await handler(websocket, **kwargs)
except asyncio.CancelledError: # pragma: no cover
raise
except ShutdownException:
break
except Exception as e:
status403_legacy = (
isinstance(e, websockets.exceptions.InvalidStatusCode)
and e.status_code == 403
)
status403 = (
isinstance(e, websockets.exceptions.InvalidStatus)
and e.response.status_code == 403
)
if status403_legacy or status403:
if refresh and settings.websocket_refresh_token:
new_token = await renew_token()
extra_headers["Authorization"] = f"Bearer {new_token}"
# Only attempt to refresh token once. If a new token cannot
# establish the connection, something else must cause 403
refresh = False
else:
raise
elif isinstance(e, OSError) and "[Errno 61]" in str(e):
# Sleep and retry implemention duplicated from
# websockets.lagacy.client.Client

# Add a random initial delay between 0 and 5 seconds.
# See 7.2.3. Recovering from Abnormal Closure in RFC 6544.
if backoff_delay == BACKOFF_MIN:
initial_delay = random.random() * BACKOFF_INITIAL
logger.info(
"! connect failed; reconnecting in %.1f seconds",
initial_delay,
exc_info=True,
)
await asyncio.sleep(initial_delay)
else:
logger.info(
"! connect failed again; retrying in %d seconds",
int(backoff_delay),
exc_info=True,
)
await asyncio.sleep(int(backoff_delay))
# Increase delay with truncated exponential backoff.
backoff_delay = backoff_delay * BACKOFF_FACTOR
backoff_delay = min(backoff_delay, BACKOFF_MAX)
continue
else:
# Connection succeeded - reset backoff delay
backoff_delay = BACKOFF_MIN
refresh = True

return result
except websockets.exceptions.InvalidStatusCode as e:
if refresh_token and e.status_code == 403:
await _update_authorization_header(extra_headers)
# Only attempt to refresh token once. If a new token cannot
# establish the connection, something else must have caused 403
refresh_token = False
else:
raise # abort
except websockets.exceptions.InvalidStatus as e:
if refresh_token and e.response.status_code == 403:
await _update_authorization_header(extra_headers)
refresh_token = False
else:
raise # abort
except OSError as e:
if "[Errno 61]" in str(e):
# if connection cannot be established, retry later
backoff_delay = await _wait_before_retry(backoff_delay)
else:
raise # abort
except websockets.exceptions.ConnectionClosedError as e:
if e.code == 1011:
# unexpected error raised from server
raise # abort
if retry_on_close:
backoff_delay = await _wait_before_retry(backoff_delay)
except websockets.exceptions.ConnectionClosedOK:
if retry_on_close:
backoff_delay = await _wait_before_retry(backoff_delay)


async def request_workload(activation_instance_id: str) -> StartupArgs:
return await _connect_websocket(
handler=_handle_request_workload,
retry=False,
retry_on_close=False,
activation_instance_id=activation_instance_id,
)

Expand Down Expand Up @@ -193,7 +202,7 @@ async def send_event_log_to_websocket(event_log: asyncio.Queue):

return await _connect_websocket(
handler=_handle_send_event_log,
retry=True,
retry_on_close=True,
logs=logs,
)

Expand All @@ -206,7 +215,8 @@ async def _handle_send_event_log(

if logs.event:
logger.info("Resending last event...")
await websocket.send(json.dumps(logs.event))
json_str = json.dumps(logs.event)
await websocket.send(json_str)
logs.event = None

while True:
Expand All @@ -215,10 +225,11 @@ async def _handle_send_event_log(

if event == dict(type="Exit"):
logger.info("Exiting feedback websocket task")
raise ShutdownException(shutdown=None)
break

logs.event = event
await websocket.send(json.dumps(event))
json_str = json.dumps(event)
await websocket.send(json_str)
logs.event = None


Expand Down
14 changes: 10 additions & 4 deletions tests/e2e/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

@pytest.mark.e2e
@pytest.mark.asyncio
async def test_websocket_messages():
@pytest.mark.parametrize("expect_failure", [True, False])
async def test_websocket_messages(expect_failure):
"""
Verify that ansible-rulebook can correctly
send event messages to a websocket server
Expand All @@ -25,7 +26,7 @@ async def test_websocket_messages():
host = "127.0.0.1"
endpoint = "/api/ws2"
proc_id = "42"
port = 31415
port = 31415 if expect_failure else 31414
rulebook = (
utils.BASE_DATA_PATH / "rulebooks/websockets/test_websocket_range.yml"
)
Expand All @@ -39,7 +40,7 @@ async def test_websocket_messages():

# run server and ansible-rulebook
queue = asyncio.Queue()
handler = partial(utils.msg_handler, queue=queue)
handler = partial(utils.msg_handler, queue=queue, failed=expect_failure)
async with ws_server.serve(handler, host, port):
LOGGER.info(f"Running command: {cmd}")
proc = await asyncio.create_subprocess_shell(
Expand All @@ -50,7 +51,12 @@ async def test_websocket_messages():
)

await asyncio.wait_for(proc.wait(), timeout=DEFAULT_TIMEOUT)
assert proc.returncode == 0
if expect_failure:
assert proc.returncode == 1
assert queue.qsize() == 2
return
else:
assert proc.returncode == 0

# Verify data
assert not queue.empty()
Expand Down
11 changes: 10 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,22 @@ def assert_playbook_output(result: CompletedProcess) -> List[dict]:


async def msg_handler(
websocket: ws_server.WebSocketServerProtocol, queue: asyncio.Queue
websocket: ws_server.WebSocketServerProtocol,
queue: asyncio.Queue,
failed: bool = False,
):
"""
Handler for a websocket server that passes json messages
from ansible-rulebook in the given queue
"""
i = 0
async for message in websocket:
payload = json.loads(message)
data = {"path": websocket.path, "payload": payload}
await queue.put(data)
if i == 1:
if failed:
print(data["bad"]) # force a coding error
else:
await websocket.close() # should be auto reconnected
i += 1
38 changes: 36 additions & 2 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,16 @@ def my_func(data):


@pytest.mark.asyncio
@pytest.mark.parametrize(
"exception_class",
[
websockets.exceptions.ConnectionClosedError,
websockets.exceptions.ConnectionClosedError,
],
)
@mock.patch("ansible_rulebook.websocket.websockets.connect")
async def test_send_event_log_to_websocket_with_exception(
socket_mock: AsyncMock,
socket_mock: AsyncMock, exception_class
):
prepare_settings()
queue = asyncio.Queue()
Expand All @@ -154,10 +161,37 @@ async def test_send_event_log_to_websocket_with_exception(
socket_mock.return_value.__aiter__.side_effect = [mock_object]

socket_mock.return_value.send.side_effect = [
websockets.exceptions.ConnectionClosed(rcvd=None, sent=None),
exception_class(rcvd=None, sent=None),
data_sent.append({"a": 1}),
data_sent.append({"b": 2}),
]

await send_event_log_to_websocket(queue)
assert data_sent == [{"a": 1}, {"b": 2}]


@pytest.mark.asyncio
@mock.patch("ansible_rulebook.websocket.websockets.connect")
async def test_send_event_log_to_websocket_with_non_recoverable_exception(
socket_mock: AsyncMock,
):
prepare_settings()
queue = asyncio.Queue()
queue.put_nowait({"a": 1})
queue.put_nowait({"b": 2})
queue.put_nowait(dict(type="Exit"))

mock_object = AsyncMock()
socket_mock.return_value = mock_object
socket_mock.return_value.__aenter__.return_value = mock_object
socket_mock.return_value.__anext__.return_value = mock_object
socket_mock.return_value.__aiter__.side_effect = [mock_object]

rcvd = mock.Mock()
rcvd.code = 1011
socket_mock.return_value.send.side_effect = (
websockets.exceptions.ConnectionClosedError(rcvd=rcvd, sent=None)
)

with pytest.raises(websockets.exceptions.ConnectionClosedError):
await send_event_log_to_websocket(queue)