diff --git a/juju/client/connection.py b/juju/client/connection.py index 88c31c2ab..03eb50961 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -92,7 +92,7 @@ def status(self): and connection._receiver_task.cancelled() ) - if stopped or not connection._ws.open: + if stopped or connection._ws.state is not websockets.protocol.State.OPEN: return self.ERROR # everything is fine! @@ -357,8 +357,7 @@ async def close(self, to_reconnect: bool = False): tasks_need_to_be_gathered.append(self._debug_log_task) self._debug_log_task.cancel() - if self._ws and not self._ws.closed: - await self._ws.close() + await self._ws.close() if not to_reconnect: try: diff --git a/pyproject.toml b/pyproject.toml index f3c1118ec..eeb170742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "macaroonbakery>=1.1,<2.0", "pyRFC3339>=1.0,<2.0", "pyyaml>=5.1.2", - "websockets>=13.0.1,<14.0", + "websockets>=13.0.1", "paramiko>=2.4.0", "pyasn1>=0.4.4", "toposort>=1.5,<2", diff --git a/setup.py b/setup.py index 5ee1fe8c1..93825375b 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "macaroonbakery>=1.1,<2.0", "pyRFC3339>=1.0,<2.0", "pyyaml>=5.1.2", - "websockets>=13.0.1,<14.0", + "websockets>=13.0.1", "paramiko>=2.4.0", "pyasn1>=0.4.4", "toposort>=1.5,<2", diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index a7cc808cb..9ed876466 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -7,6 +7,7 @@ from unittest import mock import pytest +import websockets from websockets.exceptions import ConnectionClosed from juju.client.connection import Connection @@ -17,8 +18,7 @@ class WebsocketMock: def __init__(self, responses): super().__init__() self.responses = deque(responses) - self.open = True - self.closed = False + self.state = websockets.protocol.State.OPEN async def send(self, message): pass @@ -30,8 +30,7 @@ async def recv(self): return json.dumps(self.responses.popleft()) async def close(self): - self.open = False - self.closed = True + pass async def test_out_of_order():