diff --git a/.github/wordlist.txt b/.github/wordlist.txt index cd68dcab..36b8b386 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -159,3 +159,4 @@ valkeymodules virtualenv www md +yaml diff --git a/dev_requirements.txt b/dev_requirements.txt index 2b2dbf9d..c56d2483 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,7 +12,6 @@ pytest-asyncio pytest-cov pytest-timeout ujson>=4.2.0 -urllib3<2 uvloop vulture>=2.3.0 wheel>=0.30.0 diff --git a/docs/conf.py b/docs/conf.py index 7a7d306a..447e16fc 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -128,7 +128,6 @@ # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - "display_version": True, "footer_icons": [ { "name": "GitHub", diff --git a/docs/connections.rst b/docs/connections.rst index 8f7758ed..ca7c8255 100644 --- a/docs/connections.rst +++ b/docs/connections.rst @@ -55,7 +55,7 @@ ClusterNode Async Client ************ -See complete example: `here `_ +See complete example: `here `__ This client is used for communicating with Valkey, asynchronously. @@ -88,7 +88,7 @@ ClusterPipeline (Async) Connection ********** -See complete example: `here `_ +See complete example: `here `__ Connection ========== @@ -104,7 +104,7 @@ Connection (Async) Connection Pools **************** -See complete example: `here `_ +See complete example: `here `__ ConnectionPool ============== diff --git a/docs/opentelemetry.rst b/docs/opentelemetry.rst index 05aff88f..790a98ef 100644 --- a/docs/opentelemetry.rst +++ b/docs/opentelemetry.rst @@ -4,7 +4,7 @@ Integrating OpenTelemetry What is OpenTelemetry? ---------------------- -`OpenTelemetry `_ is an open-source observability framework for traces, metrics, and logs. It is a merger of OpenCensus and OpenTracing projects hosted by Cloud Native Computing Foundation. +`OpenTelemetry `__ is an open-source observability framework for traces, metrics, and logs. It is a merger of OpenCensus and OpenTracing projects hosted by Cloud Native Computing Foundation. OpenTelemetry allows developers to collect and export telemetry data in a vendor agnostic way. With OpenTelemetry, you can instrument your application once and then add or change vendors without changing the instrumentation, for example, here is a list of `popular DataDog competitors `_ that support OpenTelemetry. @@ -61,7 +61,7 @@ Once the code is patched, you can use valkey-py as usually: OpenTelemetry API ----------------- -`OpenTelemetry `_ API is a programming interface that you can use to instrument code and collect telemetry data such as traces, metrics, and logs. +`OpenTelemetry API `__ is a programming interface that you can use to instrument code and collect telemetry data such as traces, metrics, and logs. You can use OpenTelemetry API to measure important operations: @@ -125,7 +125,7 @@ Alerting and notifications Uptrace also allows you to monitor `OpenTelemetry metrics `_ using alerting rules. For example, the following monitor uses the group by node expression to create an alert whenever an individual Valkey shard is down: -.. code-block:: python +.. code-block:: yaml monitors: - name: Valkey shard is down @@ -142,7 +142,7 @@ Uptrace also allows you to monitor `OpenTelemetry metrics =4.0.2 +async-timeout>=4.0.3 diff --git a/setup.py b/setup.py index e84a0b4f..bb48accf 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,6 @@ ], extras_require={ "libvalkey": ["libvalkey>=4.0.0b1"], - "ocsp": ["cryptography>=36.0.1", "pyopenssl==20.0.1", "requests>=2.26.0"], + "ocsp": ["cryptography>=36.0.1", "pyopenssl==23.2.1", "requests>=2.31.0"], }, ) diff --git a/tests/conftest.py b/tests/conftest.py index 60d5242e..9c6e1015 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,7 +233,9 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): for j in modules: if module_name == j.get("name"): version = j.get("ver") - mv = int(min_version.replace(".", "")) + mv = int( + "".join(["%02d" % int(segment) for segment in min_version.split(".")]) + ) check = version < mv return pytest.mark.skipif(check, reason="Valkey module version") diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index d9a2dfd1..4c3099d2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -14,7 +14,11 @@ parse_url, ) from valkey.asyncio import ConnectionPool, Valkey -from valkey.asyncio.connection import Connection, UnixDomainSocketConnection +from valkey.asyncio.connection import ( + Connection, + SSLConnection, + UnixDomainSocketConnection, +) from valkey.asyncio.retry import Retry from valkey.backoff import NoBackoff from valkey.exceptions import ConnectionError, InvalidResponse, TimeoutError @@ -494,3 +498,53 @@ async def test_connection_garbage_collection(request): await client.aclose() await pool.aclose() + + +@pytest.mark.parametrize( + "conn, error, expected_message", + [ + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), + ( + SSLConnection(), + OSError(12, "Some Error"), + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(), + "Error connecting to unix:///tmp/valkey.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/valkey.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/valkey.sock. Some Error.", + ), + ], +) +async def test_format_error_message(conn, error, expected_message): + """Test that the _error_message function formats errors correctly""" + error_message = conn._error_message(error) + assert error_message == expected_message + + +async def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + valkey = Valkey(host="127.0.0.1", port=9999) + await valkey.set("a", "b") + assert str(e.value).startswith("Error 111 connecting to 127.0.0.1:9999. Connect") + + +async def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + valkey = Valkey(unix_socket_path="unix:///tmp/a.sock") + await valkey.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + ) diff --git a/tests/test_asyncio/test_lock.py b/tests/test_asyncio/test_lock.py index 49b593c4..264eeb47 100644 --- a/tests/test_asyncio/test_lock.py +++ b/tests/test_asyncio/test_lock.py @@ -104,16 +104,16 @@ async def test_blocking(self, r): lock_2 = self.get_lock(r, "foo") assert lock_2.blocking - async def test_blocking_timeout(self, r, event_loop): + async def test_blocking_timeout(self, r): lock1 = self.get_lock(r, "foo") assert await lock1.acquire(blocking=False) bt = 0.2 sleep = 0.05 lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) - start = event_loop.time() + start = asyncio.get_running_loop().time() assert not await lock2.acquire() # The elapsed duration should be less than the total blocking_timeout - assert bt >= (event_loop.time() - start) > bt - sleep + assert bt >= (asyncio.get_running_loop().time() - start) > bt - sleep await lock1.release() async def test_context_manager(self, r): diff --git a/tests/test_connect.py b/tests/test_connect.py index 57c0a2db..ac91f5a0 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -104,6 +104,23 @@ def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): ) +""" +Addresses bug CAE-333 which uncovered that the init method of the base +class did override the initialization of the socket_timeout parameter. +""" + + +def test_unix_socket_with_timeout(): + conn = UnixDomainSocketConnection(socket_timeout=1000) + + # Check if the base class defaults were taken over. + assert conn.db == 0 + + # Verify if the timeout and the path is set correctly. + assert conn.socket_timeout == 1000 + assert conn.path == "" + + @pytest.mark.ssl @pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") def test_tcp_ssl_version_mismatch(tcp_address): diff --git a/tests/test_connection.py b/tests/test_connection.py index 4354cfd2..545a7d3f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -291,3 +291,53 @@ def mock_disconnect(_): assert called == 1 pool.disconnect() + + +@pytest.mark.parametrize( + "conn, error, expected_message", + [ + (SSLConnection(), OSError(), "Error connecting to localhost:6379."), + (SSLConnection(), OSError(12), "Error 12 connecting to localhost:6379."), + ( + SSLConnection(), + OSError(12, "Some Error"), + "Error 12 connecting to localhost:6379. Some Error.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(), + "Error connecting to unix:///tmp/valkey.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(12), + "Error 12 connecting to unix:///tmp/valkey.sock.", + ), + ( + UnixDomainSocketConnection(path="unix:///tmp/valkey.sock"), + OSError(12, "Some Error"), + "Error 12 connecting to unix:///tmp/valkey.sock. Some Error.", + ), + ], +) +def test_format_error_message(conn, error, expected_message): + """Test that the _error_message function formats errors correctly""" + error_message = conn._error_message(error) + assert error_message == expected_message + + +def test_network_connection_failure(): + with pytest.raises(ConnectionError) as e: + valkey = Valkey(port=9999) + valkey.set("a", "b") + assert str(e.value) == "Error 111 connecting to localhost:9999. Connection refused." + + +def test_unix_socket_connection_failure(): + with pytest.raises(ConnectionError) as e: + valkey = Valkey(unix_socket_path="unix:///tmp/a.sock") + valkey.set("a", "b") + assert ( + str(e.value) + == "Error 2 connecting to unix:///tmp/a.sock. No such file or directory." + ) diff --git a/tests/test_retry.py b/tests/test_retry.py index 3b757822..a6c04f84 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from valkey.backoff import ExponentialBackoff, NoBackoff +from valkey.backoff import AbstractBackoff, ExponentialBackoff, NoBackoff from valkey.client import Valkey from valkey.connection import Connection, UnixDomainSocketConnection from valkey.exceptions import ( @@ -15,7 +15,7 @@ from .conftest import _get_client -class BackoffMock: +class BackoffMock(AbstractBackoff): def __init__(self): self.reset_calls = 0 self.calls = 0 diff --git a/valkey/asyncio/cluster.py b/valkey/asyncio/cluster.py index 56de44f7..c496ae0c 100644 --- a/valkey/asyncio/cluster.py +++ b/valkey/asyncio/cluster.py @@ -1315,6 +1315,8 @@ async def initialize(self) -> None: port = int(primary_node[1]) host, port = self.remap_host_port(host, port) + nodes_for_slot = [] + target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: target_node = ClusterNode( @@ -1322,30 +1324,26 @@ async def initialize(self) -> None: ) # add this node to the nodes cache tmp_nodes_cache[target_node.name] = target_node + nodes_for_slot.append(target_node) + + replica_nodes = slot[3:] + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + host, port = self.remap_host_port(host, port) + + target_replica_node = tmp_nodes_cache.get(get_node_name(host, port)) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + # add this node to the nodes cache + tmp_nodes_cache[target_replica_node.name] = target_replica_node + nodes_for_slot.append(target_replica_node) for i in range(int(slot[0]), int(slot[1]) + 1): if i not in tmp_slots: - tmp_slots[i] = [] - tmp_slots[i].append(target_node) - replica_nodes = [slot[j] for j in range(3, len(slot))] - - for replica_node in replica_nodes: - host = replica_node[0] - port = replica_node[1] - host, port = self.remap_host_port(host, port) - - target_replica_node = tmp_nodes_cache.get( - get_node_name(host, port) - ) - if not target_replica_node: - target_replica_node = ClusterNode( - host, port, REPLICA, **self.connection_kwargs - ) - tmp_slots[i].append(target_replica_node) - # add this node to the nodes cache - tmp_nodes_cache[target_replica_node.name] = ( - target_replica_node - ) + tmp_slots[i] = nodes_for_slot else: # Validate that 2 nodes want to use the same slot cache # setup diff --git a/valkey/asyncio/connection.py b/valkey/asyncio/connection.py index 3f3e2059..c7a18ad9 100644 --- a/valkey/asyncio/connection.py +++ b/valkey/asyncio/connection.py @@ -24,6 +24,8 @@ Union, ) +from ..utils import format_error_message + # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 if sys.version_info >= (3, 11, 3): @@ -342,9 +344,8 @@ async def _connect(self): def _host_error(self) -> str: pass - @abstractmethod def _error_message(self, exception: BaseException) -> str: - pass + return format_error_message(self._host_error(), exception) async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" @@ -796,27 +797,6 @@ async def _connect(self): def _host_error(self) -> str: return f"{self.host}:{self.port}" - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. - # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return f"Error connecting to {host_error}. Connection reset by peer" - elif len(exception.args) == 1: - return f"Error connecting to {host_error}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {host_error}. " - f"{exception.args[0]}." - ) - class SSLConnection(Connection): """Manages SSL connections to and from the Valkey server(s). @@ -968,20 +948,6 @@ async def _connect(self): def _host_error(self) -> str: return self.path - def _error_message(self, exception: BaseException) -> str: - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - class ConnectKwargs(TypedDict, total=False): username: str diff --git a/valkey/cluster.py b/valkey/cluster.py index 0496ccbd..453482b8 100644 --- a/valkey/cluster.py +++ b/valkey/cluster.py @@ -1519,6 +1519,8 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): target_node = ClusterNode(host, port, role) if target_node.server_type != role: target_node.server_type = role + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node return target_node @@ -1582,31 +1584,26 @@ def initialize(self): port = int(primary_node[1]) host, port = self.remap_host_port(host, port) + nodes_for_slot = [] + target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache ) - # add this node to the nodes cache - tmp_nodes_cache[target_node.name] = target_node + nodes_for_slot.append(target_node) + + replica_nodes = slot[3:] + for replica_node in replica_nodes: + host = str_if_bytes(replica_node[0]) + port = int(replica_node[1]) + host, port = self.remap_host_port(host, port) + target_replica_node = self._get_or_create_cluster_node( + host, port, REPLICA, tmp_nodes_cache + ) + nodes_for_slot.append(target_replica_node) for i in range(int(slot[0]), int(slot[1]) + 1): if i not in tmp_slots: - tmp_slots[i] = [] - tmp_slots[i].append(target_node) - replica_nodes = [slot[j] for j in range(3, len(slot))] - - for replica_node in replica_nodes: - host = str_if_bytes(replica_node[0]) - port = replica_node[1] - host, port = self.remap_host_port(host, port) - - target_replica_node = self._get_or_create_cluster_node( - host, port, REPLICA, tmp_nodes_cache - ) - tmp_slots[i].append(target_replica_node) - # add this node to the nodes cache - tmp_nodes_cache[target_replica_node.name] = ( - target_replica_node - ) + tmp_slots[i] = nodes_for_slot else: # Validate that 2 nodes want to use the same slot cache # setup diff --git a/valkey/commands/core.py b/valkey/commands/core.py index 669d3148..57366a3d 100644 --- a/valkey/commands/core.py +++ b/valkey/commands/core.py @@ -79,7 +79,7 @@ def acl_dryrun(self, username, *args, **kwargs): def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ - Delete the ACL for the specified ``username``s + Delete the ACL for the specified ``username``\\s For more information see https://valkey.io/commands/acl-deluser """ @@ -227,9 +227,10 @@ def acl_setuser( must be prefixed with either a '+' to add the command permission or a '-' to remove the command permission. keys: A list of key patterns to grant the user access to. Key patterns allow - '*' to support wildcard matching. For example, '*' grants access to - all keys while 'cache:*' grants access to all keys that are prefixed - with 'cache:'. `keys` should not be prefixed with a '~'. + ``'*'`` to support wildcard matching. For example, ``'*'`` grants + access to all keys while ``'cache:*'`` grants access to all keys that + are prefixed with ``cache:``. + `keys` should not be prefixed with a ``'~'``. reset: Indicates whether the user should be fully reset prior to applying the new ACL. Setting this to `True` will remove all existing passwords, flags, and privileges from the user and then apply the @@ -3363,7 +3364,7 @@ def sintercard( self, numkeys: int, keys: List[str], limit: int = 0 ) -> Union[Awaitable[int], int]: """ - Return the cardinality of the intersect of multiple sets specified by ``keys`. + Return the cardinality of the intersect of multiple sets specified by ``keys``. When LIMIT provided (defaults to 0 and means unlimited), if the intersection cardinality reaches limit partway through the computation, the algorithm will @@ -3494,9 +3495,11 @@ class StreamCommands(CommandsProtocol): def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: """ Acknowledges the successful processing of one or more messages. - name: name of the stream. - groupname: name of the consumer group. - *ids: message ids to acknowledge. + + Args: + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowledge. For more information see https://valkey.io/commands/xack """ @@ -3692,8 +3695,10 @@ def xclaim( def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ Deletes one or more messages from a stream. - name: name of the stream. - *ids: message ids to delete. + + Args: + name: name of the stream. + *ids: message ids to delete. For more information see https://valkey.io/commands/xdel """ @@ -4261,7 +4266,7 @@ def zintercard( ) -> Union[Awaitable[int], int]: """ Return the cardinality of the intersect of multiple sorted sets - specified by ``keys`. + specified by ``keys``. When LIMIT provided (defaults to 0 and means unlimited), if the intersection cardinality reaches limit partway through the computation, the algorithm will exit and yield limit as the cardinality diff --git a/valkey/commands/graph/commands.py b/valkey/commands/graph/commands.py index d50f5950..19ff3b6a 100644 --- a/valkey/commands/graph/commands.py +++ b/valkey/commands/graph/commands.py @@ -155,7 +155,7 @@ def slowlog(self): def config(self, name, value=None, set=False): """ Retrieve or update a RedisGraph configuration. - For more information see `https://valkey.io/commands/graph.config-get/>`_. # noqa + For more information see ``__. Args: diff --git a/valkey/commands/search/commands.py b/valkey/commands/search/commands.py index e16fc9d7..ab5719a7 100644 --- a/valkey/commands/search/commands.py +++ b/valkey/commands/search/commands.py @@ -335,30 +335,30 @@ def add_document( """ Add a single document to the index. - ### Parameters + Args: - - **doc_id**: the id of the saved document. - - **nosave**: if set to true, we just index the document, and don't + doc_id: the id of the saved document. + nosave: if set to true, we just index the document, and don't save a copy of it. This means that searches will just return ids. - - **score**: the document ranking, between 0.0 and 1.0 - - **payload**: optional inner-index payload we can save for fast - i access in scoring functions - - **replace**: if True, and the document already is in the index, - we perform an update and reindex the document - - **partial**: if True, the fields specified will be added to the + score: the document ranking, between 0.0 and 1.0 + payload: optional inner-index payload we can save for fast + access in scoring functions + replace: if True, and the document already is in the index, + we perform an update and reindex the document + partial: if True, the fields specified will be added to the existing document. This has the added benefit that any fields specified with `no_index` will not be reindexed again. Implies `replace` - - **language**: Specify the language used for document tokenization. - - **no_create**: if True, the document is only updated and reindexed + language: Specify the language used for document tokenization. + no_create: if True, the document is only updated and reindexed if it already exists. If the document does not exist, an error will be returned. Implies `replace` - - **fields** kwargs dictionary of the document fields to be saved - and/or indexed. - NOTE: Geo points shoule be encoded as strings of "lon,lat" + fields: kwargs dictionary of the document fields to be saved + and/or indexed. + NOTE: Geo points shoule be encoded as strings of "lon,lat" """ # noqa return self._add_document( doc_id, @@ -620,13 +620,13 @@ def spellcheck(self, query, distance=None, include=None, exclude=None): """ Issue a spellcheck query - ### Parameters + Args: - **query**: search query. - **distance***: the maximal Levenshtein distance for spelling + query: search query. + distance: the maximal Levenshtein distance for spelling suggestions (default: 1, max: 4). - **include**: specifies an inclusion custom dictionary. - **exclude**: specifies an exclusion custom dictionary. + include: specifies an inclusion custom dictionary. + exclude: specifies an exclusion custom dictionary. For more information see `FT.SPELLCHECK `_. """ # noqa diff --git a/valkey/connection.py b/valkey/connection.py index 07e004e1..a85b3db2 100644 --- a/valkey/connection.py +++ b/valkey/connection.py @@ -37,6 +37,7 @@ CRYPTOGRAPHY_AVAILABLE, LIBVALKEY_AVAILABLE, SSL_AVAILABLE, + format_error_message, get_lib_version, str_if_bytes, ) @@ -336,9 +337,8 @@ def _connect(self): def _host_error(self): pass - @abstractmethod def _error_message(self, exception): - pass + return format_error_message(self._host_error(), exception) def on_connect(self): "Initialize the connection, authenticate and select a database" @@ -731,27 +731,6 @@ def _connect(self): def _host_error(self): return f"{self.host}:{self.port}" - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - - host_error = self._host_error() - - if len(exception.args) == 1: - try: - return f"Error connecting to {host_error}. \ - {exception.args[0]}." - except AttributeError: - return f"Connection Error: {exception.args[0]}" - else: - try: - return ( - f"Error {exception.args[0]} connecting to " - f"{host_error}. {exception.args[1]}." - ) - except AttributeError: - return f"Connection Error: {exception.args[0]}" - class SSLConnection(Connection): """Manages SSL connections to and from the Valkey server(s). @@ -832,8 +811,24 @@ def __init__( super().__init__(**kwargs) def _connect(self): - "Wrap the socket with SSL support" + "Wrap the socket with SSL support, handling potential errors." sock = super()._connect() + try: + return self._wrap_socket_with_ssl(sock) + except (OSError, ValkeyError): + sock.close() + raise + + def _wrap_socket_with_ssl(self, sock): + """ + Wraps the socket with SSL support. + + Args: + sock: The plain socket to wrap with SSL. + + Returns: + An SSL wrapped socket. + """ context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs @@ -851,11 +846,12 @@ def _connect(self): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data ) - if self.ssl_min_version is not None: + if self.ssl_min_version is None: + context.minimum_version = ssl.TLSVersion.TLSv1_2 + else: context.minimum_version = self.ssl_min_version if self.ssl_ciphers: context.set_ciphers(self.ssl_ciphers) - sslsock = context.wrap_socket(sock, server_hostname=self.host) if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: raise ValkeyError("cryptography is not installed.") @@ -865,6 +861,8 @@ def _connect(self): "- not both." ) + sslsock = context.wrap_socket(sock, server_hostname=self.host) + # validation for the stapled case if self.ssl_validate_ocsp_stapled: import OpenSSL @@ -907,9 +905,9 @@ class UnixDomainSocketConnection(AbstractConnection): "Manages UDS communication to and from a Valkey server" def __init__(self, path="", socket_timeout=None, **kwargs): + super().__init__(**kwargs) self.path = path self.socket_timeout = socket_timeout - super().__init__(**kwargs) def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] @@ -921,27 +919,18 @@ def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(self.socket_connect_timeout) - sock.connect(self.path) + try: + sock.connect(self.path) + except OSError: + # Prevent ResourceWarnings for unclosed sockets. + sock.close() + raise sock.settimeout(self.socket_timeout) return sock def _host_error(self): return self.path - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - host_error = self._host_error() - if len(exception.args) == 1: - return ( - f"Error connecting to unix socket: {host_error}. {exception.args[0]}." - ) - else: - return ( - f"Error {exception.args[0]} connecting to unix socket: " - f"{host_error}. {exception.args[1]}." - ) - FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") diff --git a/valkey/retry.py b/valkey/retry.py index 02962bd9..e40a8331 100644 --- a/valkey/retry.py +++ b/valkey/retry.py @@ -1,17 +1,27 @@ import socket from time import sleep +from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar from valkey.exceptions import ConnectionError, TimeoutError +T = TypeVar("T") + +if TYPE_CHECKING: + from redis.backoff import AbstractBackoff + class Retry: """Retry a specific number of times after a failure""" def __init__( self, - backoff, - retries, - supported_errors=(ConnectionError, TimeoutError, socket.timeout), + backoff: "AbstractBackoff", + retries: int, + supported_errors: Tuple[Type[Exception], ...] = ( + ConnectionError, + TimeoutError, + socket.timeout, + ), ): """ Initialize a `Retry` object with a `Backoff` object @@ -24,7 +34,9 @@ def __init__( self._retries = retries self._supported_errors = supported_errors - def update_supported_errors(self, specified_errors: list): + def update_supported_errors( + self, specified_errors: Iterable[Type[Exception]] + ) -> None: """ Updates the supported errors with the specified error types """ @@ -32,7 +44,11 @@ def update_supported_errors(self, specified_errors: list): set(self._supported_errors + tuple(specified_errors)) ) - def call_with_retry(self, do, fail): + def call_with_retry( + self, + do: Callable[[], T], + fail: Callable[[Exception], Any], + ) -> T: """ Execute an operation that might fail and returns its result, or raise the exception that was thrown depending on the `Backoff` object. diff --git a/valkey/utils.py b/valkey/utils.py index adc40a8c..e6ce6213 100644 --- a/valkey/utils.py +++ b/valkey/utils.py @@ -139,3 +139,15 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver + + +def format_error_message(host_error: str, exception: BaseException) -> str: + if not exception.args: + return f"Error connecting to {host_error}." + elif len(exception.args) == 1: + return f"Error {exception.args[0]} connecting to {host_error}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[1]}." + )