diff --git a/src/aiovantage/connection.py b/src/aiovantage/connection.py index 8c84395..c71d3c2 100644 --- a/src/aiovantage/connection.py +++ b/src/aiovantage/connection.py @@ -1,14 +1,28 @@ """Wrapper for an asyncio connection to a Vantage controller.""" import asyncio +from collections.abc import Callable from ssl import CERT_NONE, SSLContext, create_default_context +from typing import ClassVar from .errors import ClientConnectionError, ClientTimeoutError +def _get_default_context() -> SSLContext: + """Create a default SSL context.""" + # We don't have a local issuer certificate to check against, and we'll be + # connecting to an IP address so we can't check the hostname + ctx = create_default_context() + ctx.check_hostname = False + ctx.verify_mode = CERT_NONE + return ctx + + class BaseConnection: """Wrapper for an asyncio connection to a Vantage controller.""" + ssl_context_factory: ClassVar[Callable[[], SSLContext]] = _get_default_context + default_port: int default_ssl_port: int buffer_limit: int = 2**16 @@ -30,11 +44,7 @@ def __init__( # Set up the SSL context self._ssl: SSLContext | None if ssl is True: - # We don't have a local issuer certificate to check against, and we'll be - # connecting to an IP address so we can't check the hostname - self._ssl = create_default_context() - self._ssl.check_hostname = False - self._ssl.verify_mode = CERT_NONE + self._ssl = self.ssl_context_factory() elif isinstance(ssl, SSLContext): self._ssl = ssl else: