Skip to content

Commit

Permalink
chore: no need to assert connection
Browse files Browse the repository at this point in the history
  • Loading branch information
dimaqq committed Nov 22, 2024
1 parent ee6414e commit 99d8fc1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 26 deletions.
24 changes: 9 additions & 15 deletions juju/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .facade_versions import client_facade_versions, known_unsupported_facades

SpecifiedFacades: TypeAlias = "dict[str, dict[Literal['versions'], Sequence[int]]]"
_WebSocket: TypeAlias = "websockets.legacy.client.WebSocketClientProtocol"
_WebSocket: TypeAlias = websockets.WebSocketClientProtocol

LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
log = logging.getLogger("juju.client.connection")
Expand Down Expand Up @@ -291,7 +291,7 @@ def is_using_old_client(self):
def is_open(self):
return self.monitor.status == Monitor.CONNECTED

def _get_ssl(self, cert=None):
def _get_ssl(self, cert: str | None = None) -> ssl.SSLContext:
context = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH, cadata=cert
)
Expand All @@ -305,7 +305,9 @@ def _get_ssl(self, cert=None):
context.check_hostname = False
return context

async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
async def _open(
self, endpoint: str, cacert: str
) -> tuple[_WebSocket, str, str, str]:
if self.is_debug_log_connection:
assert self.uuid
url = f"wss://user-{self.username}:{self.password}@{endpoint}/model/{self.uuid}/log"
Expand All @@ -323,10 +325,6 @@ async def _open(self, endpoint, cacert) -> tuple[_WebSocket, str, str, str]:
sock = self.proxy.socket()
server_hostname = "juju-app"

def _exit_tasks():
for task in jasyncio.all_tasks():
task.cancel()

return (
(
await websockets.connect(
Expand All @@ -342,7 +340,7 @@ def _exit_tasks():
cacert,
)

async def close(self, to_reconnect=False):
async def close(self, to_reconnect: bool = False):
if not self._ws:
return
self.monitor.close_called.set()
Expand Down Expand Up @@ -380,11 +378,7 @@ async def close(self, to_reconnect=False):

async def _recv(self, request_id: int) -> dict[str, Any]:
if not self.is_open:
raise websockets.exceptions.ConnectionClosed(
websockets.frames.Close(
websockets.frames.CloseCode.NORMAL_CLOSURE, "websocket closed"
)
)
raise websockets.exceptions.ConnectionClosedOK(None, None)
try:
return await self.messages.get(request_id)
except GeneratorExit:
Expand Down Expand Up @@ -626,7 +620,7 @@ async def rpc(

return result

def _http_headers(self):
def _http_headers(self) -> dict[str, str]:
"""Return dictionary of http headers necessary for making an http
connection to the endpoint of this Connection.
Expand All @@ -640,7 +634,7 @@ def _http_headers(self):
token = base64.b64encode(creds.encode())
return {"Authorization": f"Basic {token.decode()}"}

def https_connection(self):
def https_connection(self) -> tuple[HTTPSConnection, dict[str, str], str]:
"""Return an https connection to this Connection's endpoint.
Returns a 3-tuple containing::
Expand Down
3 changes: 2 additions & 1 deletion juju/client/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.model_name = None
self.jujudata = jujudata or FileJujuData()

def is_connected(self):
def is_connected(self) -> bool:
"""Report whether there is a currently connected controller or not"""
return self._connection is not None

Expand All @@ -60,6 +60,7 @@ def connection(self) -> Connection:
"""
if not self.is_connected():
raise NoConnectionException("not connected")
assert self._connection
return self._connection

async def connect(self, **kwargs):
Expand Down
17 changes: 7 additions & 10 deletions juju/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,10 +924,7 @@ def add_local_charm(self, charm_file, series="", size=None):
instead.
"""
connection = self.connection()
assert connection

conn, headers, path_prefix = connection.https_connection()
conn, headers, path_prefix = self.connection().https_connection()
path = "%s/charms?series=%s" % (path_prefix, series)
headers["Content-Type"] = "application/zip"
if size:
Expand Down Expand Up @@ -1320,14 +1317,14 @@ async def _all_watcher():
del allwatcher.Id
continue
except websockets.ConnectionClosed:
connection = self.connection()
assert connection
monitor = connection.monitor
if monitor.status == monitor.ERROR:
if self.connection().monitor.status == connection.Monitor.ERROR:
# closed unexpectedly, try to reopen
log.warning("Watcher: connection closed, reopening")
await connection.reconnect()
if monitor.status != monitor.CONNECTED:
await self.connection().reconnect()
if (
self.connection().monitor.status
!= connection.Monitor.CONNECTED
):
# reconnect failed; abort and shutdown
log.error(
"Watcher: automatic reconnect "
Expand Down

0 comments on commit 99d8fc1

Please sign in to comment.