From 99d8fc11fcc7687a91dcadb5d527df33278f22cb Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Fri, 22 Nov 2024 12:23:39 +0900 Subject: [PATCH] chore: no need to `assert connection` --- juju/client/connection.py | 24 +++++++++--------------- juju/client/connector.py | 3 ++- juju/model.py | 17 +++++++---------- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/juju/client/connection.py b/juju/client/connection.py index e79f2ea7..88c31c2a 100644 --- a/juju/client/connection.py +++ b/juju/client/connection.py @@ -27,7 +27,7 @@ from .facade_versions import client_facade_versions, known_unsupported_facades SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]" -_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol" +_WebSocket: TypeAlias = websockets.WebSocketClientProtocol LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"] log = logging.getLogger("juju.client.connection") @@ -291,7 +291,7 @@ def is_using_old_client(self): def is_open(self): return self.monitor.status == Monitor.CONNECTED - def _get_ssl(self, cert=None): + def _get_ssl(self, cert: str | None = None) -> ssl.SSLContext: context = ssl.create_default_context( purpose=ssl.Purpose.SERVER_AUTH, cadata=cert ) @@ -305,7 +305,9 @@ def _get_ssl(self, cert=None): context.check_hostname = False return context - async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]: + async def _open( + self, endpoint: str, cacert: str + ) -> tuple[_WebSocket, str, str, str]: if self.is_debug_log_connection: assert self.uuid url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log" @@ -323,10 +325,6 @@ async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]: sock = self.proxy.socket() server_hostname = "juju-app" - def _exit_tasks(): - for task in jasyncio.all_tasks(): - task.cancel() - return ( ( await websockets.connect( @@ -342,7 +340,7 @@ def _exit_tasks(): cacert, ) - async def close(self, to_reconnect=False): + async def close(self, to_reconnect: bool = False): if not self._ws: return self.monitor.close_called.set() @@ -380,11 +378,7 @@ async def close(self, to_reconnect=False): async def _recv(self, request_id: int) -> dict[str, Any]: if not self.is_open: - raise websockets.exceptions.ConnectionClosed( - websockets.frames.Close( - websockets.frames.CloseCode.NORMAL_CLOSURE, "websocket closed" - ) - ) + raise websockets.exceptions.ConnectionClosedOK(None, None) try: return await self.messages.get(request_id) except GeneratorExit: @@ -626,7 +620,7 @@ async def rpc( return result - def _http_headers(self): + def _http_headers(self) -> dict[str, str]: """Return dictionary of http headers necessary for making an http connection to the endpoint of this Connection. @@ -640,7 +634,7 @@ def _http_headers(self): token = base64.b64encode(creds.encode()) return {"Authorization": f"Basic {token.decode()}"} - def https_connection(self): + def https_connection(self) -> tuple[HTTPSConnection, dict[str, str], str]: """Return an https connection to this Connection's endpoint. Returns a 3-tuple containing:: diff --git a/juju/client/connector.py b/juju/client/connector.py index c9be0cda..cc307902 100644 --- a/juju/client/connector.py +++ b/juju/client/connector.py @@ -50,7 +50,7 @@ def __init__( self.model_name = None self.jujudata = jujudata or FileJujuData() - def is_connected(self): + def is_connected(self) -> bool: """Report whether there is a currently connected controller or not""" return self._connection is not None @@ -60,6 +60,7 @@ def connection(self) -> Connection: """ if not self.is_connected(): raise NoConnectionException("not connected") + assert self._connection return self._connection async def connect(self, **kwargs): diff --git a/juju/model.py b/juju/model.py index 81648a89..ccd5cfbe 100644 --- a/juju/model.py +++ b/juju/model.py @@ -924,10 +924,7 @@ def add_local_charm(self, charm_file, series="", size=None): instead. """ - connection = self.connection() - assert connection - - conn, headers, path_prefix = connection.https_connection() + conn, headers, path_prefix = self.connection().https_connection() path = "%s/charms?series=%s" % (path_prefix, series) headers["Content-Type"] = "application/zip" if size: @@ -1320,14 +1317,14 @@ async def _all_watcher(): del allwatcher.Id continue except websockets.ConnectionClosed: - connection = self.connection() - assert connection - monitor = connection.monitor - if monitor.status == monitor.ERROR: + if self.connection().monitor.status == connection.Monitor.ERROR: # closed unexpectedly, try to reopen log.warning("Watcher: connection closed, reopening") - await connection.reconnect() - if monitor.status != monitor.CONNECTED: + await self.connection().reconnect() + if ( + self.connection().monitor.status + != connection.Monitor.CONNECTED + ): # reconnect failed; abort and shutdown log.error( "Watcher: automatic reconnect "