diff --git a/valkey/_parsers/__init__.py b/valkey/_parsers/__init__.py index 6cc32e3c..c601be37 100644 --- a/valkey/_parsers/__init__.py +++ b/valkey/_parsers/__init__.py @@ -4,7 +4,7 @@ from .hiredis import _AsyncHiredisParser, _HiredisParser from .resp2 import _AsyncRESP2Parser, _RESP2Parser from .resp3 import _AsyncRESP3Parser, _RESP3Parser - +from url_parser import parse_url __all__ = [ "AsyncCommandsParser", "_AsyncHiredisParser", @@ -17,4 +17,5 @@ "_HiredisParser", "_RESP2Parser", "_RESP3Parser", + "parse_url", ] diff --git a/valkey/_parsers/url_parser.py b/valkey/_parsers/url_parser.py new file mode 100644 index 00000000..a8138f1b --- /dev/null +++ b/valkey/_parsers/url_parser.py @@ -0,0 +1,85 @@ +from valkey.asyncio.connection import ConnectKwargs, UnixDomainSocketConnection, SSLConnection +from urllib.parse import ParseResult, parse_qs, unquote, urlparse +from types import MappingProxyType +from typing import ( + Callable, + Mapping, + Optional, +) + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + +URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( + { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, + "timeout": float, + } +) + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value_list in parse_qs(parsed.query).items(): + if value_list and len(value_list) > 0: + value = unquote(value_list[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError(f"Invalid value for `{name}` in connection URL.") + else: + kwargs[name] = value + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support valkey://, valkeys:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("valkey", "valkeys"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "valkeys": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "valkey://, valkeys://, unix://" + raise ValueError( + f"Valkey URL must specify one of the following schemes ({valid_schemes})" + ) + + return kwargs diff --git a/valkey/asyncio/connection.py b/valkey/asyncio/connection.py index 77b329a8..7ad5777a 100644 --- a/valkey/asyncio/connection.py +++ b/valkey/asyncio/connection.py @@ -9,10 +9,9 @@ import weakref from abc import abstractmethod from itertools import chain -from types import MappingProxyType +from .._parsers.url_parser import parse_url from typing import ( Any, - Callable, Iterable, List, Mapping, @@ -25,7 +24,6 @@ TypeVar, Union, ) -from urllib.parse import ParseResult, parse_qs, unquote, urlparse # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 @@ -986,31 +984,8 @@ def _error_message(self, exception: BaseException) -> str: ) -FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") -def to_bool(value) -> Optional[bool]: - if value is None or value == "": - return None - if isinstance(value, str) and value.upper() in FALSE_STRINGS: - return False - return bool(value) - - -URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyType( - { - "db": int, - "socket_timeout": float, - "socket_connect_timeout": float, - "socket_keepalive": to_bool, - "retry_on_timeout": to_bool, - "max_connections": int, - "health_check_interval": int, - "ssl_check_hostname": to_bool, - "timeout": float, - } -) - class ConnectKwargs(TypedDict, total=False): username: str @@ -1022,56 +997,6 @@ class ConnectKwargs(TypedDict, total=False): path: str -def parse_url(url: str) -> ConnectKwargs: - parsed: ParseResult = urlparse(url) - kwargs: ConnectKwargs = {} - - for name, value_list in parse_qs(parsed.query).items(): - if value_list and len(value_list) > 0: - value = unquote(value_list[0]) - parser = URL_QUERY_ARGUMENT_PARSERS.get(name) - if parser: - try: - kwargs[name] = parser(value) - except (TypeError, ValueError): - raise ValueError(f"Invalid value for `{name}` in connection URL.") - else: - kwargs[name] = value - - if parsed.username: - kwargs["username"] = unquote(parsed.username) - if parsed.password: - kwargs["password"] = unquote(parsed.password) - - # We only support valkey://, valkeys:// and unix:// schemes. - if parsed.scheme == "unix": - if parsed.path: - kwargs["path"] = unquote(parsed.path) - kwargs["connection_class"] = UnixDomainSocketConnection - - elif parsed.scheme in ("valkey", "valkeys"): - if parsed.hostname: - kwargs["host"] = unquote(parsed.hostname) - if parsed.port: - kwargs["port"] = int(parsed.port) - - # If there's a path argument, use it as the db argument if a - # querystring value wasn't specified - if parsed.path and "db" not in kwargs: - try: - kwargs["db"] = int(unquote(parsed.path).replace("/", "")) - except (AttributeError, ValueError): - pass - - if parsed.scheme == "valkeys": - kwargs["connection_class"] = SSLConnection - else: - valid_schemes = "valkey://, valkeys://, unix://" - raise ValueError( - f"Valkey URL must specify one of the following schemes ({valid_schemes})" - ) - - return kwargs _CP = TypeVar("_CP", bound="ConnectionPool") diff --git a/valkey/connection.py b/valkey/connection.py index 29d3fbb0..584b0f23 100644 --- a/valkey/connection.py +++ b/valkey/connection.py @@ -10,7 +10,7 @@ from queue import Empty, Full, LifoQueue from time import time from typing import Any, Callable, List, Optional, Sequence, Type, Union -from urllib.parse import parse_qs, unquote, urlparse +from _parsers.url_parser import parse_url from ._cache import ( DEFAULT_ALLOW_LIST, @@ -106,9 +106,9 @@ def pack(self, *args): # output list if we're sending large values or memoryviews arg_length = len(arg) if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) ): buff = SYM_EMPTY.join( (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) @@ -135,35 +135,35 @@ class AbstractConnection: "Manages communication to and from a Valkey server" def __init__( - self, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - retry_on_timeout: bool = False, - retry_on_error=SENTINEL, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - parser_class=DefaultParser, - socket_read_size: int = 65536, - health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "valkey-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - retry: Union[Any, None] = None, - valkey_connect_func: Optional[Callable[[], None]] = None, - credential_provider: Optional[CredentialProvider] = None, - protocol: Optional[int] = 2, - command_packer: Optional[Callable[[], None]] = None, - cache_enabled: bool = False, - client_cache: Optional[AbstractCache] = None, - cache_max_size: int = 10000, - cache_ttl: int = 0, - cache_policy: str = DEFAULT_EVICTION_POLICY, - cache_deny_list: List[str] = DEFAULT_DENY_LIST, - cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, + self, + db: int = 0, + password: Optional[str] = None, + socket_timeout: Optional[float] = None, + socket_connect_timeout: Optional[float] = None, + retry_on_timeout: bool = False, + retry_on_error=SENTINEL, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class=DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: Optional[str] = None, + lib_name: Optional[str] = "valkey-py", + lib_version: Optional[str] = get_lib_version(), + username: Optional[str] = None, + retry: Union[Any, None] = None, + valkey_connect_func: Optional[Callable[[], None]] = None, + credential_provider: Optional[CredentialProvider] = None, + protocol: Optional[int] = 2, + command_packer: Optional[Callable[[], None]] = None, + cache_enabled: bool = False, + client_cache: Optional[AbstractCache] = None, + cache_max_size: int = 10000, + cache_ttl: int = 0, + cache_policy: str = DEFAULT_EVICTION_POLICY, + cache_deny_list: List[str] = DEFAULT_DENY_LIST, + cache_allow_list: List[str] = DEFAULT_ALLOW_LIST, ): """ Initialize a new Connection. @@ -351,8 +351,8 @@ def on_connect(self): # if credential provider or username and/or password are set, authenticate if self.credential_provider or (self.username or self.password): cred_provider = ( - self.credential_provider - or UsernamePasswordCredentialProvider(self.username, self.password) + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() @@ -400,8 +400,8 @@ def on_connect(self): self.send_command("HELLO", self.protocol) response = self.read_response() if ( - response.get(b"proto") != self.protocol - and response.get("proto") != self.protocol + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") @@ -529,11 +529,11 @@ def can_read(self, timeout=0): raise ConnectionError(f"Error while reading from {host_error}: {e.args}") def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, + self, + disable_decoding=False, + *, + disconnect_on_error=True, + push_request=False, ): """Read the response from a previously sent command""" @@ -589,9 +589,9 @@ def pack_commands(self, commands): for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( - buffer_length > buffer_cutoff - or chunklen > buffer_cutoff - or isinstance(chunk, memoryview) + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) ): if pieces: output.append(SYM_EMPTY.join(pieces)) @@ -609,7 +609,7 @@ def pack_commands(self, commands): return output def _cache_invalidation_process( - self, data: List[Union[str, Optional[List[str]]]] + self, data: List[Union[str, Optional[List[str]]]] ) -> None: """ Invalidate (delete) all valkey commands associated with a specific key. @@ -628,9 +628,9 @@ def _get_from_local_cache(self, command: Sequence[str]): If the command is in the local cache, return the response """ if ( - self.client_cache is None - or command[0] in self.cache_deny_list - or command[0] not in self.cache_allow_list + self.client_cache is None + or command[0] in self.cache_deny_list + or command[0] not in self.cache_allow_list ): return None while self.can_read(): @@ -638,16 +638,16 @@ def _get_from_local_cache(self, command: Sequence[str]): return self.client_cache.get(command) def _add_to_local_cache( - self, command: Sequence[str], response: ResponseT, keys: List[KeysT] + self, command: Sequence[str], response: ResponseT, keys: List[KeysT] ): """ Add the command and response to the local cache if the command is allowed to be cached """ if ( - self.client_cache is not None - and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) - and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) + self.client_cache is not None + and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list) + and (self.cache_allow_list == [] or command[0] in self.cache_allow_list) ): self.client_cache.set(command, response, keys) @@ -668,13 +668,13 @@ class Connection(AbstractConnection): "Manages TCP communication to and from a Valkey server" def __init__( - self, - host="localhost", - port=6379, - socket_keepalive=False, - socket_keepalive_options=None, - socket_type=0, - **kwargs, + self, + host="localhost", + port=6379, + socket_keepalive=False, + socket_keepalive_options=None, + socket_type=0, + **kwargs, ): self.host = host self.port = int(port) @@ -696,7 +696,7 @@ def _connect(self): # socket.connect() err = None for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -762,22 +762,22 @@ class SSLConnection(Connection): """ # noqa def __init__( - self, - ssl_keyfile=None, - ssl_certfile=None, - ssl_cert_reqs="required", - ssl_ca_certs=None, - ssl_ca_data=None, - ssl_check_hostname=False, - ssl_ca_path=None, - ssl_password=None, - ssl_validate_ocsp=False, - ssl_validate_ocsp_stapled=False, - ssl_ocsp_context=None, - ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, - **kwargs, + self, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_ca_path=None, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + ssl_min_version=None, + ssl_ciphers=None, + **kwargs, ): """Constructor @@ -846,9 +846,9 @@ def _connect(self): password=self.certificate_password, ) if ( - self.ca_certs is not None - or self.ca_path is not None - or self.ca_data is not None + self.ca_certs is not None + or self.ca_path is not None + or self.ca_data is not None ): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data @@ -970,63 +970,6 @@ def to_bool(value): } -def parse_url(url): - if not ( - url.startswith("valkey://") - or url.startswith("valkeys://") - or url.startswith("unix://") - ): - raise ValueError( - "Valkey URL must specify one of the following " - "schemes (valkey://, valkeys://, unix://)" - ) - - url = urlparse(url) - kwargs = {} - - for name, value in parse_qs(url.query).items(): - if value and len(value) > 0: - value = unquote(value[0]) - parser = URL_QUERY_ARGUMENT_PARSERS.get(name) - if parser: - try: - kwargs[name] = parser(value) - except (TypeError, ValueError): - raise ValueError(f"Invalid value for `{name}` in connection URL.") - else: - kwargs[name] = value - - if url.username: - kwargs["username"] = unquote(url.username) - if url.password: - kwargs["password"] = unquote(url.password) - - # We only support valkey://, valkeys:// and unix:// schemes. - if url.scheme == "unix": - if url.path: - kwargs["path"] = unquote(url.path) - kwargs["connection_class"] = UnixDomainSocketConnection - - else: # implied: url.scheme in ("valkey", "valkeys"): - if url.hostname: - kwargs["host"] = unquote(url.hostname) - if url.port: - kwargs["port"] = int(url.port) - - # If there's a path argument, use it as the db argument if a - # querystring value wasn't specified - if url.path and "db" not in kwargs: - try: - kwargs["db"] = int(unquote(url.path).replace("/", "")) - except (AttributeError, ValueError): - pass - - if url.scheme == "valkeys": - kwargs["connection_class"] = SSLConnection - - return kwargs - - class ConnectionPool: """ Create a connection pool. ``If max_connections`` is set, then this @@ -1089,12 +1032,12 @@ class initializer. In the case of conflicting arguments, querystring return cls(**kwargs) def __init__( - self, - connection_class=Connection, - max_connections: Optional[int] = None, - **connection_kwargs, + self, + connection_class=Connection, + max_connections: Optional[int] = None, + **connection_kwargs, ): - max_connections = max_connections or 2**31 + max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') @@ -1347,12 +1290,12 @@ class BlockingConnectionPool(ConnectionPool): """ def __init__( - self, - max_connections=50, - timeout=20, - connection_class=Connection, - queue_class=LifoQueue, - **connection_kwargs, + self, + max_connections=50, + timeout=20, + connection_class=Connection, + queue_class=LifoQueue, + **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout